import tensorflow as tf

try: [tf.config.experimental.set_memory_growth(gpu, True) for gpu in tf.config.experimental.list_physical_devices("GPU")]
except: pass
import numpy as np

from mltu.tokenizers import CustomTokenizer
from mltu.inferenceModel import OnnxInferenceModel  # note: assumed import path for the OnnxInferenceModel base class subclassed below

class PtEnTranslator(OnnxInferenceModel):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.new_inputs = self.model.get_inputs()
        # Load the source tokenizer and target detokenizer saved in the model metadata
        self.tokenizer = CustomTokenizer.load(self.metadata["tokenizer"])
        self.detokenizer = CustomTokenizer.load(self.metadata["detokenizer"])
        # self.eng_tokenizer = CustomTokenizer.load("Tutorials/09_transformers/eng_tokenizer.json")
        # self.pt_tokenizer = CustomTokenizer.load("Tutorials/09_transformers/pt_tokenizer.json")

    def predict(self, sentence):
        # Tokenize the source sentence and pad it to the encoder's fixed input length
        tokenized_sentence = self.tokenizer.texts_to_sequences([sentence])[0]
        encoder_input = np.pad(tokenized_sentence, (0, self.tokenizer.max_length - len(tokenized_sentence)), constant_values=0).astype(np.int64)

        # Greedy autoregressive decoding: start from the start token and predict one token per step
        tokenized_results = [self.detokenizer.start_token_index]
        for index in range(self.detokenizer.max_length - 1):
            decoder_input = np.pad(tokenized_results, (0, self.detokenizer.max_length - len(tokenized_results)), constant_values=0).astype(np.int64)
            input_dict = {
                self.model._inputs_meta[0].name: np.expand_dims(encoder_input, axis=0),
                self.model._inputs_meta[1].name: np.expand_dims(decoder_input, axis=0),
            }
            # Assumed ONNX Runtime call: self.model is treated as an onnxruntime.InferenceSession,
            # whose run() returns a list of outputs; the first output holds the logits for every decoder position
            preds = self.model.run(None, input_dict)[0]
            pred_results = np.argmax(preds, axis=2)
            tokenized_results.append(pred_results[0][index])

            if tokenized_results[-1] == self.detokenizer.end_token_index:
                break

        results = self.detokenizer.detokenize([tokenized_results])
        return results[0]


def read_files(path):
    with open(path, "r", encoding="utf-8") as f:
        dataset = f.read().split("\n")[:-1]
    return dataset

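# Note: read_files returns one list element per line of the file (the trailing empty string
# left by the file's final newline is dropped by the [:-1] slice), so en_validation_data and
# es_validation_data below line up index-for-index as parallel sentences.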
46+ # Path to dataset
47+ en_validation_data_path = "Datasets/en-es/opus.en-es-dev.en"
48+ es_validation_data_path = "Datasets/en-es/opus.en-es-dev.es"
49+
50+ en_validation_data = read_files (en_validation_data_path )
51+ es_validation_data = read_files (es_validation_data_path )

# Consider only sentence pairs where both sides are at most 500 characters long
max_length = 500
val_examples = [[es_sentence, en_sentence] for es_sentence, en_sentence in zip(es_validation_data, en_validation_data) if len(es_sentence) <= max_length and len(en_sentence) <= max_length]
# es_validation_data, en_validation_data = zip(*val_dataset)

translator = PtEnTranslator("Models/09_translation_transformer/202307241748/model.onnx")

val_dataset = []
for es, en in val_examples:
    results = translator.predict(es)
    print(en)
    print(results)
    print()
    # val_dataset.append([pt.numpy().decode('utf-8'), en.numpy().decode('utf-8')])
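    # Hedged sketch (not part of the original loop): the val_dataset list initialised above
    # could collect each (source, reference, prediction) triplet for later inspection, e.g.:
    # val_dataset.append([es, en, results])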