Commit 44e5400

Merge branch 'develop'
2 parents a6acf47 + c25791c

File tree: 3 files changed, 12 additions and 6 deletions
Lines changed: 3 additions & 1 deletion

@@ -1 +1,3 @@
-beautifulsoup4
+beautifulsoup4
+tf2onnx==1.14.0
+onnx==1.12.0
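The newly pinned tf2onnx and onnx packages support the Keras-to-ONNX export step used by these tutorials (the Model2onnx callback in train.py). A minimal sketch of such an export, assuming a plain Keras model; the file paths below are placeholders for illustration, not taken from this commit:

import tensorflow as tf
import tf2onnx

# Load a trained Keras model (placeholder path)
model = tf.keras.models.load_model("model.h5")
# Convert the Keras graph to ONNX and write it next to the .h5 file
tf2onnx.convert.from_keras(model, output_path="model.onnx")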

Tutorials/09_translation_transformer/test.py

Lines changed: 7 additions & 5 deletions
@@ -1,4 +1,5 @@
 import numpy as np
+import time

 from mltu.tokenizers import CustomTokenizer
 from mltu.inferenceModel import OnnxInferenceModel
@@ -12,6 +13,7 @@ def __init__(self, *args, **kwargs):
         self.detokenizer = CustomTokenizer.load(self.metadata["detokenizer"])

     def predict(self, sentence):
+        start = time.time()
         tokenized_sentence = self.tokenizer.texts_to_sequences([sentence])[0]
         encoder_input = np.pad(tokenized_sentence, (0, self.tokenizer.max_length - len(tokenized_sentence)), constant_values=0).astype(np.int64)

@@ -30,8 +32,7 @@ def predict(self, sentence):
             break

         results = self.detokenizer.detokenize([tokenized_results])
-        return results[0]
-
+        return results[0], time.time() - start

 def read_files(path):
     with open(path, "r", encoding="utf-8") as f:
@@ -49,11 +50,12 @@ def read_files(path):
 max_lenght = 500
 val_examples = [[es_sentence, en_sentence] for es_sentence, en_sentence in zip(es_validation_data, en_validation_data) if len(es_sentence) <= max_lenght and len(en_sentence) <= max_lenght]

-translator = PtEnTranslator("Models/09_translation_transformer/202307241748/model.onnx")
+translator = PtEnTranslator("Models/09_translation_transformer/202308241514/model.onnx")

 val_dataset = []
 for es, en in val_examples:
-    results = translator.predict(es)
-    print(en)
+    results, duration = translator.predict(es)
+    print(en.lower())
     print(results)
+    print(duration)
     print()
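With this change, predict() returns a (translation, seconds) pair instead of just the translation, so callers can measure per-sentence inference latency. A small sketch of aggregating the new return value (variable names are illustrative, not part of this commit):

durations = []
for es, en in val_examples:
    results, duration = translator.predict(es)  # duration is wall-clock seconds
    durations.append(duration)

# Average latency across the validation sentences
print(f"average inference time: {sum(durations) / len(durations):.3f}s")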

Tutorials/09_translation_transformer/train.py

Lines changed: 2 additions & 0 deletions
@@ -123,6 +123,8 @@ def preprocess_inputs(data_batch, label_batch):
 model2onnx = Model2onnx(f"{configs.model_path}/model.h5", metadata={"tokenizer": tokenizer.dict(), "detokenizer": detokenizer.dict()}, save_on_epoch_end=False)
 encDecSplitCallback = EncDecSplitCallback(configs.model_path, encoder_metadata={"tokenizer": tokenizer.dict()}, decoder_metadata={"detokenizer": detokenizer.dict()})

+configs.save()
+
 # Train the model
 transformer.fit(
     train_dataProvider,
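The added configs.save() call persists the run configuration before training begins; in mltu, a BaseModelConfigs subclass typically serializes its fields to a YAML file inside model_path, which inference scripts later reload. A hedged sketch of the pattern, assuming an mltu-style configs class (the class name and field values here are assumptions, not shown in this diff):

from mltu.configs import BaseModelConfigs

class ModelConfigs(BaseModelConfigs):
    def __init__(self):
        super().__init__()
        # Placeholder values for illustration only
        self.model_path = "Models/09_translation_transformer/example_run"
        self.batch_size = 16

configs = ModelConfigs()
configs.save()  # writes the fields above to a YAML file under model_path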
