55except : pass
66
77from keras .callbacks import EarlyStopping , ModelCheckpoint , ReduceLROnPlateau , TensorBoard
8- from mltu .tensorflow .callbacks import Model2onnx
8+ from mltu .tensorflow .callbacks import Model2onnx , WarmupCosineDecay
99
1010from mltu .tensorflow .dataProvider import DataProvider
1111from mltu .tokenizers import CustomTokenizer
1212
1313from mltu .tensorflow .transformer .utils import MaskedAccuracy , MaskedLoss
1414from mltu .tensorflow .transformer .callbacks import EncDecSplitCallback
15- from mltu .tensorflow .schedules import CustomSchedule
1615
1716from model import Transformer
1817from configs import ModelConfigs
@@ -42,7 +41,7 @@ def read_files(path):
# Split the (source, target) pairs into parallel sequences.
es_training_data, en_training_data = zip(*train_dataset)
es_validation_data, en_validation_data = zip(*val_dataset)

# Prepare the Spanish tokenizer (input language): fit on the training
# split only, then persist it alongside the model artifacts.
tokenizer = CustomTokenizer(char_level=True)
tokenizer.fit_on_texts(es_training_data)
tokenizer.save(configs.model_path + "/tokenizer.json")
@@ -99,17 +98,7 @@ def preprocess_inputs(data_batch, label_batch):
9998
# Print the model architecture for a quick sanity check.
transformer.summary()

# Adam with the Transformer-paper moments (beta_2=0.98, epsilon=1e-9).
# The learning rate starts at a constant init_lr; warmup/decay is handled
# per-epoch by the WarmupCosineDecay callback rather than a schedule object.
optimizer = tf.keras.optimizers.Adam(
    learning_rate=configs.init_lr,
    beta_1=0.9,
    beta_2=0.98,
    epsilon=1e-9,
)
113102
114103# Compile the model
115104transformer .compile (
@@ -120,6 +109,13 @@ def preprocess_inputs(data_batch, label_batch):
120109 )
121110
# Define callbacks.
# Learning-rate warmup followed by cosine decay, applied per epoch.
warmupCosineDecay = WarmupCosineDecay(
    lr_after_warmup=configs.lr_after_warmup,
    final_lr=configs.final_lr,
    warmup_epochs=configs.warmup_epochs,
    decay_epochs=configs.decay_epochs,
    initial_lr=configs.init_lr,
)
# Stop when masked accuracy on validation stops improving for 5 epochs.
earlystopper = EarlyStopping(monitor="val_masked_accuracy", patience=5, verbose=1, mode="max")
# Keep only the best full model (weights + graph) by validation masked accuracy.
checkpoint = ModelCheckpoint(f"{configs.model_path}/model.h5", monitor="val_masked_accuracy", verbose=1, save_best_only=True, mode="max", save_weights_only=False)
# Training curves for TensorBoard.
tb_callback = TensorBoard(f"{configs.model_path}/logs")
0 commit comments