
Commit c25791c

Include test code in transformer translation tutorial
1 parent df58e7c commit c25791c

File tree

1 file changed: 7 additions, 5 deletions
  • Tutorials/09_translation_transformer


Tutorials/09_translation_transformer/test.py

Lines changed: 7 additions & 5 deletions
@@ -1,4 +1,5 @@
 import numpy as np
+import time

 from mltu.tokenizers import CustomTokenizer
 from mltu.inferenceModel import OnnxInferenceModel
@@ -12,6 +13,7 @@ def __init__(self, *args, **kwargs):
         self.detokenizer = CustomTokenizer.load(self.metadata["detokenizer"])

     def predict(self, sentence):
+        start = time.time()
         tokenized_sentence = self.tokenizer.texts_to_sequences([sentence])[0]
         encoder_input = np.pad(tokenized_sentence, (0, self.tokenizer.max_length - len(tokenized_sentence)), constant_values=0).astype(np.int64)

@@ -30,8 +32,7 @@ def predict(self, sentence):
                 break

         results = self.detokenizer.detokenize([tokenized_results])
-        return results[0]
-
+        return results[0], time.time() - start

 def read_files(path):
     with open(path, "r", encoding="utf-8") as f:
@@ -49,11 +50,12 @@ def read_files(path):
 max_lenght = 500
 val_examples = [[es_sentence, en_sentence] for es_sentence, en_sentence in zip(es_validation_data, en_validation_data) if len(es_sentence) <= max_lenght and len(en_sentence) <= max_lenght]

-translator = PtEnTranslator("Models/09_translation_transformer/202307241748/model.onnx")
+translator = PtEnTranslator("Models/09_translation_transformer/202308241514/model.onnx")

 val_dataset = []
 for es, en in val_examples:
-    results = translator.predict(es)
-    print(en)
+    results, duration = translator.predict(es)
+    print(en.lower())
     print(results)
+    print(duration)
     print()
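
Because predict now returns a (translation, duration) pair, the per-sentence timings can be aggregated into a summary figure instead of only being printed one by one. A minimal sketch of such an aggregation, reusing the translator and val_examples objects from test.py above; the averaging step itself is an illustration and not part of this commit:

durations = []
for es, en in val_examples:
    # predict returns the decoded translation and the elapsed inference time in seconds
    translation, duration = translator.predict(es)
    durations.append(duration)

# Hypothetical summary, not in the commit: mean ONNX inference time per sentence
print(f"Average inference time: {np.mean(durations):.3f} s over {len(durations)} sentences")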
