Skip to content

Commit 0fab117

Browse files
committed
Transformer training code
1 parent 449fc95 commit 0fab117

File tree

2 files changed

+3
-15
lines changed

2 files changed

+3
-15
lines changed

Tutorials/09_translation_transformer/test.py

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,7 @@
1-
from mltu.inferenceModel import OnnxInferenceModel
2-
import tensorflow as tf
3-
try: [tf.config.experimental.set_memory_growth(gpu, True) for gpu in tf.config.experimental.list_physical_devices("GPU")]
4-
except: pass
51
import numpy as np
62

73
from mltu.tokenizers import CustomTokenizer
4+
from mltu.inferenceModel import OnnxInferenceModel
85

96
class PtEnTranslator(OnnxInferenceModel):
107
def __init__(self, *args, **kwargs):
@@ -13,8 +10,6 @@ def __init__(self, *args, **kwargs):
1310
self.new_inputs = self.model.get_inputs()
1411
self.tokenizer = CustomTokenizer.load(self.metadata["tokenizer"])
1512
self.detokenizer = CustomTokenizer.load(self.metadata["detokenizer"])
16-
# self.eng_tokenizer = CustomTokenizer.load("Tutorials/09_transformers/eng_tokenizer.json")
17-
# self.pt_tokenizer = CustomTokenizer.load("Tutorials/09_transformers/pt_tokenizer.json")
1813

1914
def predict(self, sentence):
2015
tokenized_sentence = self.tokenizer.texts_to_sequences([sentence])[0]
@@ -53,19 +48,12 @@ def read_files(path):
5348
# Consider only sentences with length <= 500
5449
max_lenght = 500
5550
val_examples = [[es_sentence, en_sentence] for es_sentence, en_sentence in zip(es_validation_data, en_validation_data) if len(es_sentence) <= max_lenght and len(en_sentence) <= max_lenght]
56-
# es_validation_data, en_validation_data = zip(*val_dataset)
57-
58-
59-
60-
6151

6252
translator = PtEnTranslator("Models/09_translation_transformer/202307241748/model.onnx")
6353

64-
6554
val_dataset = []
6655
for es, en in val_examples:
6756
results = translator.predict(es)
6857
print(en)
6958
print(results)
70-
print()
71-
# val_dataset.append([pt.numpy().decode('utf-8'), en.numpy().decode('utf-8')])
59+
print()

mltu/transformers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ def __init__(
159159
self.padding_value = padding_value
160160

161161
def __call__(self, spectrogram: np.ndarray, label: np.ndarray):
162-
padded_spectrogram = np.pad(spectrogram, (0, (self.max_spectrogram_length - spectrogram.shape[0]),(0,0)), mode="constant", constant_values=self.padding_value)
162+
padded_spectrogram = np.pad(spectrogram, ((0, self.max_spectrogram_length - spectrogram.shape[0]),(0,0)), mode="constant", constant_values=self.padding_value)
163163

164164
return padded_spectrogram, label
165165

0 commit comments

Comments
 (0)