Commit 5a7c747

Translation transformer tutorial
1 parent 6efbd7b · commit 5a7c747

File tree: 4 files changed (+2743 / -34 lines)


Tutorials/09_translation_transformer/test.py

Lines changed: 33 additions & 20 deletions
@@ -3,30 +3,26 @@
 try: [tf.config.experimental.set_memory_growth(gpu, True) for gpu in tf.config.experimental.list_physical_devices("GPU")]
 except: pass
 import numpy as np
-import json
-import tensorflow_datasets as tfds
 
 from mltu.tokenizers import CustomTokenizer
 
-examples, metadata = tfds.load('ted_hrlr_translate/pt_to_en', with_info=True, as_supervised=True)
-
 class PtEnTranslator(OnnxInferenceModel):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
         self.new_inputs = self.model.get_inputs()
-        self.pt_tokenizer = CustomTokenizer.load(self.metadata["pt_tokenizer"])
-        self.eng_tokenizer = CustomTokenizer.load(self.metadata["eng_tokenizer"])
+        self.tokenizer = CustomTokenizer.load(self.metadata["tokenizer"])
+        self.detokenizer = CustomTokenizer.load(self.metadata["detokenizer"])
         # self.eng_tokenizer = CustomTokenizer.load("Tutorials/09_transformers/eng_tokenizer.json")
         # self.pt_tokenizer = CustomTokenizer.load("Tutorials/09_transformers/pt_tokenizer.json")
 
     def predict(self, sentence):
-        tokenized_sentence = self.pt_tokenizer.texts_to_sequences([sentence])[0]
-        encoder_input = np.pad(tokenized_sentence, (0, self.pt_tokenizer.max_length - len(tokenized_sentence)), constant_values=0).astype(np.int64)
+        tokenized_sentence = self.tokenizer.texts_to_sequences([sentence])[0]
+        encoder_input = np.pad(tokenized_sentence, (0, self.tokenizer.max_length - len(tokenized_sentence)), constant_values=0).astype(np.int64)
 
-        tokenized_results = [self.eng_tokenizer.start_token_index]
-        for index in range(self.eng_tokenizer.max_length - 1):
-            decoder_input = np.pad(tokenized_results, (0, self.eng_tokenizer.max_length - len(tokenized_results)), constant_values=0).astype(np.int64)
+        tokenized_results = [self.detokenizer.start_token_index]
+        for index in range(self.detokenizer.max_length - 1):
+            decoder_input = np.pad(tokenized_results, (0, self.detokenizer.max_length - len(tokenized_results)), constant_values=0).astype(np.int64)
             input_dict = {
                 self.model._inputs_meta[0].name: np.expand_dims(encoder_input, axis=0),
                 self.model._inputs_meta[1].name: np.expand_dims(decoder_input, axis=0),
@@ -35,24 +31,41 @@ def predict(self, sentence):
             pred_results = np.argmax(preds, axis=2)
             tokenized_results.append(pred_results[0][index])
 
-            if tokenized_results[-1] == self.eng_tokenizer.end_token_index:
+            if tokenized_results[-1] == self.detokenizer.end_token_index:
                 break
 
-        results = self.eng_tokenizer.detokenize([tokenized_results])
+        results = self.detokenizer.detokenize([tokenized_results])
         return results[0]
 
 
-translator = PtEnTranslator("Models/09_translation_transformer/202307101211/model.onnx")
+def read_files(path):
+    with open(path, "r", encoding="utf-8") as f:
+        en_train_dataset = f.read().split("\n")[:-1]
+    return en_train_dataset
+
+# Path to dataset
+en_validation_data_path = "Datasets/en-es/opus.en-es-dev.en"
+es_validation_data_path = "Datasets/en-es/opus.en-es-dev.es"
+
+en_validation_data = read_files(en_validation_data_path)
+es_validation_data = read_files(es_validation_data_path)
+
+# Consider only sentences with length <= 500
+max_lenght = 500
+val_examples = [[es_sentence, en_sentence] for es_sentence, en_sentence in zip(es_validation_data, en_validation_data) if len(es_sentence) <= max_lenght and len(en_sentence) <= max_lenght]
+# es_validation_data, en_validation_data = zip(*val_dataset)
+
+
+
 
+translator = PtEnTranslator("Models/09_translation_transformer/202307241748/model.onnx")
 
-train_examples, val_examples = examples['train'], examples['validation']
 
 val_dataset = []
-for pt, en in val_examples:
-    pt_sentence = pt.numpy().decode('utf-8')
-    en_sentence = en.numpy().decode('utf-8')
-    results = translator.predict(pt_sentence)
-    print(en_sentence)
+for es, en in val_examples:
+    results = translator.predict(es)
+    print(en)
     print(results)
     print()
 # val_dataset.append([pt.numpy().decode('utf-8'), en.numpy().decode('utf-8')])
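
The predict method above runs greedy autoregressive decoding: the source sentence is tokenized and zero-padded to the encoder's fixed max_length, then the decoder input is rebuilt and re-padded on every step until the end token appears. Below is a minimal, self-contained sketch of just the padding step with NumPy; the token ids and the max_length value are invented for illustration and do not come from the tutorial's tokenizers.

import numpy as np

# Hypothetical token ids for one sentence; a real CustomTokenizer would produce these.
tokenized_sentence = [2, 57, 913, 4, 3]
max_length = 10  # assumed fixed sequence length, standing in for tokenizer.max_length

# Right-pad with zeros to the fixed length and cast to int64, as the ONNX model expects.
encoder_input = np.pad(tokenized_sentence, (0, max_length - len(tokenized_sentence)), constant_values=0).astype(np.int64)
print(encoder_input)        # [  2  57 913   4   3   0   0   0   0   0]
print(encoder_input.shape)  # (10,)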

Tutorials/09_translation_transformer/train.py

Lines changed: 2 additions & 1 deletion
@@ -136,6 +136,7 @@ def preprocess_inputs(data_batch, label_batch):
         checkpoint,
         tb_callback,
         reduceLROnPlat,
-        model2onnx
+        model2onnx,
+        encDecSplitCallback
     ]
 )
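
The train.py change simply appends one more callback, encDecSplitCallback, to the list handed to Keras; its definition is not part of this diff. As an illustration only, the sketch below shows how any custom tf.keras.callbacks.Callback entry in such a list gets invoked during training, using a toy model and random data rather than the tutorial's transformer.

import tensorflow as tf

# Placeholder callback; the tutorial's encDecSplitCallback (not shown in this diff)
# plugs into the same Callback interface, whatever its actual behaviour is.
class ExampleCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        loss = logs.get("loss") if logs else None
        print(f"epoch {epoch} done, loss={loss}")

# Toy model and data so the snippet runs end to end.
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(optimizer="adam", loss="mse")
x = tf.random.normal((32, 4))
y = tf.random.normal((32, 1))

model.fit(x, y, epochs=2, verbose=0, callbacks=[ExampleCallback()])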
