Commit 0149976

Updating transformer training code

1 parent cb8665f commit 0149976
9 files changed: 80 additions, 102 deletions

Tutorials/05_sound_to_text/inferenceModel.py

Lines changed: 1 addition & 1 deletion

@@ -36,7 +36,7 @@ def predict(self, data: np.ndarray):
         spectrogram = WavReader.get_spectrogram(wav_path, frame_length=configs.frame_length, frame_step=configs.frame_step, fft_length=configs.fft_length)
         # WavReader.plot_raw_audio(wav_path, label)

-        padded_spectrogram = np.pad(spectrogram, ((configs.max_spectrogram_length - spectrogram.shape[0], 0),(0,0)), mode="constant", constant_values=0)
+        padded_spectrogram = np.pad(spectrogram, ((0, configs.max_spectrogram_length - spectrogram.shape[0]),(0,0)), mode="constant", constant_values=0)

         # WavReader.plot_spectrogram(spectrogram, label)

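The one-line change above moves the zero padding from the front of the time axis to the back, so the real spectrogram frames stay at index 0 and only trailing frames are padding (the same fix is applied to mltu/transformers.py further down). A small standalone sketch, not part of the commit, showing the difference with a made-up 3x2 spectrogram padded to 5 frames:

    import numpy as np

    spectrogram = np.ones((3, 2))  # toy spectrogram: 3 time frames, 2 frequency bins
    max_len = 5

    # old behaviour: zeros are prepended, real frames end up at the back
    left_padded = np.pad(spectrogram, ((max_len - spectrogram.shape[0], 0), (0, 0)), mode="constant", constant_values=0)

    # new behaviour: zeros are appended, real frames stay at the front
    right_padded = np.pad(spectrogram, ((0, max_len - spectrogram.shape[0]), (0, 0)), mode="constant", constant_values=0)

    print(left_padded[:, 0])   # [0. 0. 1. 1. 1.]
    print(right_padded[:, 0])  # [1. 1. 1. 0. 0.]
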
Tutorials/09_translation_transformer/configs.py

Lines changed: 4 additions & 4 deletions

@@ -14,13 +14,13 @@ def __init__(self):
         self.num_layers = 4
         self.d_model = 128
         self.num_heads = 8
-        self.dff = 128
+        self.dff = 512
         self.dropout_rate = 0.1
-        self.batch_size = 32
-        self.train_epochs = 20
+        self.batch_size = 16
+        self.train_epochs = 50
         # CustomSchedule parameters
         self.init_lr = 0.00001
         self.lr_after_warmup = 0.0005
         self.final_lr = 0.0001
         self.warmup_epochs = 2
-        self.decay_epochs = 9
+        self.decay_epochs = 18
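
With these values the learning-rate schedule now spans the first warmup_epochs + decay_epochs = 20 of the 50 training epochs; afterwards the new WarmupCosineDecay callback (added in mltu/tensorflow/callbacks.py below) simply stops updating the rate, leaving it near final_lr. A small standalone sketch, not part of the commit, that mirrors that per-epoch rule with the updated config values:

    import math

    init_lr, lr_after_warmup, final_lr = 0.00001, 0.0005, 0.0001
    warmup_epochs, decay_epochs = 2, 18

    def epoch_lr(epoch: int) -> float:
        # same piecewise rule as the WarmupCosineDecay callback further down
        if epoch < warmup_epochs:  # linear warmup towards lr_after_warmup
            return init_lr + (lr_after_warmup - init_lr) * (epoch + 1) / warmup_epochs
        if epoch < warmup_epochs + decay_epochs:  # cosine decay towards final_lr
            progress = (epoch - warmup_epochs) / decay_epochs
            return final_lr + 0.5 * (lr_after_warmup - final_lr) * (1 + math.cos(progress * math.pi))
        return final_lr  # schedule window is over; the callback leaves the rate unchanged

    for epoch in (0, 1, 2, 10, 19, 20, 49):
        print(epoch, round(epoch_lr(epoch), 6))
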
Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+beautifulsoup4

Tutorials/09_translation_transformer/train.py

Lines changed: 10 additions & 14 deletions

@@ -5,14 +5,13 @@
 except: pass

 from keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau, TensorBoard
-from mltu.tensorflow.callbacks import Model2onnx
+from mltu.tensorflow.callbacks import Model2onnx, WarmupCosineDecay

 from mltu.tensorflow.dataProvider import DataProvider
 from mltu.tokenizers import CustomTokenizer

 from mltu.tensorflow.transformer.utils import MaskedAccuracy, MaskedLoss
 from mltu.tensorflow.transformer.callbacks import EncDecSplitCallback
-from mltu.tensorflow.schedules import CustomSchedule

 from model import Transformer
 from configs import ModelConfigs
@@ -42,7 +41,7 @@ def read_files(path):
 es_training_data, en_training_data = zip(*train_dataset)
 es_validation_data, en_validation_data = zip(*val_dataset)

-# prepare portuguese tokenizer, this is the input language
+# prepare spanish tokenizer, this is the input language
 tokenizer = CustomTokenizer(char_level=True)
 tokenizer.fit_on_texts(es_training_data)
 tokenizer.save(configs.model_path + "/tokenizer.json")
@@ -99,17 +98,7 @@ def preprocess_inputs(data_batch, label_batch):

 transformer.summary()

-# Define learning rate schedule
-learning_rate = CustomSchedule(
-    steps_per_epoch=len(train_dataProvider),
-    init_lr=configs.init_lr,
-    lr_after_warmup=configs.lr_after_warmup,
-    final_lr=configs.final_lr,
-    warmup_epochs=configs.warmup_epochs,
-    decay_epochs=configs.decay_epochs,
-)
-
-optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate, beta_1=0.9, beta_2=0.98, epsilon=1e-9)
+optimizer = tf.keras.optimizers.Adam(learning_rate=configs.init_lr, beta_1=0.9, beta_2=0.98, epsilon=1e-9)

 # Compile the model
 transformer.compile(
@@ -120,6 +109,13 @@ def preprocess_inputs(data_batch, label_batch):
 )

 # Define callbacks
+warmupCosineDecay = WarmupCosineDecay(
+    lr_after_warmup=configs.lr_after_warmup,
+    final_lr=configs.final_lr,
+    warmup_epochs=configs.warmup_epochs,
+    decay_epochs=configs.decay_epochs,
+    initial_lr=configs.init_lr,
+)
 earlystopper = EarlyStopping(monitor="val_masked_accuracy", patience=5, verbose=1, mode="max")
 checkpoint = ModelCheckpoint(f"{configs.model_path}/model.h5", monitor="val_masked_accuracy", verbose=1, save_best_only=True, mode="max", save_weights_only=False)
 tb_callback = TensorBoard(f"{configs.model_path}/logs")

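The net effect of these hunks: the per-step CustomSchedule (and mltu/tensorflow/schedules.py with it) is gone, the Adam optimizer is created with the plain float configs.init_lr, and the warmup-plus-cosine schedule is now applied per epoch by the new WarmupCosineDecay callback. The model.fit call itself is outside the changed hunks, so the snippet below is only an assumed usage sketch; the validation provider name and the exact callback list are guesses:

    # assumed usage - the actual fit() call is not shown in this diff
    transformer.fit(
        train_dataProvider,
        validation_data=val_dataProvider,  # name assumed
        epochs=configs.train_epochs,
        callbacks=[
            warmupCosineDecay,  # new per-epoch LR schedule, replaces CustomSchedule
            earlystopper,
            checkpoint,
            tb_callback,
            # plus Model2onnx / EncDecSplitCallback as configured in the full script
        ],
    )
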
mltu/tensorflow/callbacks.py

Lines changed: 45 additions & 1 deletion

@@ -101,4 +101,48 @@ def __init__(self, log_path: str, log_file: str="logs.log", logLevel=logging.INF
     def on_epoch_end(self, epoch: int, logs: dict=None):
         epoch_message = f"Epoch {epoch}; "
         logs_message = "; ".join([f"{key}: {value}" for key, value in logs.items()])
-        self.logger.info(epoch_message + logs_message)
+        self.logger.info(epoch_message + logs_message)
+
+
+class WarmupCosineDecay(Callback):
+    """ Cosine decay learning rate scheduler with warmup
+
+    Args:
+        lr_after_warmup (float): Learning rate after warmup
+        final_lr (float): Final learning rate
+        warmup_epochs (int): Number of warmup epochs
+        decay_epochs (int): Number of decay epochs
+        initial_lr (float, optional): Initial learning rate. Defaults to 0.0.
+        verbose (bool, optional): Whether to print learning rate. Defaults to False.
+    """
+    def __init__(
+        self,
+        lr_after_warmup: float,
+        final_lr: float,
+        warmup_epochs: int,
+        decay_epochs: int,
+        initial_lr: float=0.0,
+        verbose=False
+    ) -> None:
+        super(WarmupCosineDecay, self).__init__()
+        self.lr_after_warmup = lr_after_warmup
+        self.final_lr = final_lr
+        self.warmup_epochs = warmup_epochs
+        self.decay_epochs = decay_epochs
+        self.initial_lr = initial_lr
+        self.verbose = verbose
+
+    def on_epoch_begin(self, epoch: int, logs: dict=None):
+        """ Adjust learning rate at the beginning of each epoch """
+        if epoch < self.warmup_epochs:
+            lr = self.initial_lr + (self.lr_after_warmup - self.initial_lr) * (epoch + 1) / self.warmup_epochs
+        elif epoch < self.warmup_epochs + self.decay_epochs:
+            progress = (epoch - self.warmup_epochs) / self.decay_epochs
+            lr = self.final_lr + 0.5 * (self.lr_after_warmup - self.final_lr) * (1 + tf.cos(tf.constant(progress) * 3.14159))
+        else:
+            return None # No change to learning rate
+
+        tf.keras.backend.set_value(self.model.optimizer.lr, lr)
+
+        if self.verbose:
+            print(f"Epoch {epoch + 1} - Learning Rate: {lr}")

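WarmupCosineDecay ramps the learning rate linearly over warmup_epochs, cosine-decays it from lr_after_warmup to final_lr over decay_epochs, and then stops touching it. Because it writes the rate with tf.keras.backend.set_value(self.model.optimizer.lr, ...), the optimizer has to be built with a plain float learning rate (as train.py now does) rather than a LearningRateSchedule object. A minimal smoke test with a throwaway model, not part of the commit:

    import tensorflow as tf
    from mltu.tensorflow.callbacks import WarmupCosineDecay

    # throwaway model and data, only here to exercise the callback
    model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5), loss="mse")

    x = tf.random.normal((32, 4))
    y = tf.random.normal((32, 1))

    lr_callback = WarmupCosineDecay(
        lr_after_warmup=5e-4, final_lr=1e-4,
        warmup_epochs=2, decay_epochs=18,
        initial_lr=1e-5, verbose=True,  # verbose=True prints the rate set for each epoch
    )

    model.fit(x, y, epochs=25, callbacks=[lr_callback], verbose=0)
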
mltu/tensorflow/schedules.py

Lines changed: 0 additions & 80 deletions
This file was deleted.

mltu/tensorflow/transformer/layers.py

Lines changed: 16 additions & 0 deletions

@@ -204,6 +204,14 @@ def __init__(self, num_layers: int, d_model: int, num_heads: int, dff: int, voca
             for _ in range(num_layers)]
         self.dropout = tf.keras.layers.Dropout(dropout_rate)

+    def get_config(self):
+        config = super().get_config()
+        config.update({
+            'd_model': self.d_model,
+            'num_layers': self.num_layers,
+        })
+        return config
+
     def call(self, x: tf.Tensor) -> tf.Tensor:
         """
         The call function that performs the forward pass of the layer.
@@ -323,6 +331,14 @@ def __init__(self, num_layers: int, d_model: int, num_heads: int, dff: int, voca

         self.last_attn_scores = None

+    def get_config(self):
+        config = super().get_config()
+        config.update({
+            'd_model': self.d_model,
+            'num_layers': self.num_layers,
+        })
+        return config
+
     def call(self, x: tf.Tensor, context: tf.Tensor) -> tf.Tensor:
         """
         The call function that performs the forward pass of the layer.

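get_config is what Keras calls when a model containing these custom Encoder/Decoder layers is serialized (for example by the ModelCheckpoint in train.py), so the d_model and num_layers arguments are now recorded in the saved config. A minimal, self-contained illustration of the pattern using a hypothetical layer, since the real Encoder/Decoder take more constructor arguments:

    import tensorflow as tf

    # hypothetical layer, only to show the get_config / from_config round trip
    class TinyBlock(tf.keras.layers.Layer):
        def __init__(self, d_model: int, num_layers: int, **kwargs):
            super().__init__(**kwargs)
            self.d_model = d_model
            self.num_layers = num_layers

        def get_config(self):
            config = super().get_config()
            config.update({"d_model": self.d_model, "num_layers": self.num_layers})
            return config

        def call(self, x):
            return x

    block = TinyBlock(d_model=128, num_layers=4)
    rebuilt = TinyBlock.from_config(block.get_config())  # works because get_config covers all __init__ args
    print(rebuilt.d_model, rebuilt.num_layers)
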
mltu/tokenizers.py

Lines changed: 2 additions & 1 deletion

@@ -181,7 +181,8 @@ def save(self, path: str, type: str="json"):
         """
         serialised_dict = self.dict()
        if type == "json":
-            os.makedirs(os.path.dirname(path), exist_ok=True)
+            if os.path.dirname(path):
+                os.makedirs(os.path.dirname(path), exist_ok=True)
             with open(path, "w") as f:
                 json.dump(serialised_dict, f)

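The old code always called os.makedirs(os.path.dirname(path), ...), which raises when path is a bare filename because os.path.dirname then returns an empty string; the added guard skips directory creation in that case. A tiny illustration with made-up paths:

    import os

    for path in ("tokenizer.json", "some_folder/tokenizer.json"):  # made-up paths
        directory = os.path.dirname(path)
        print(repr(directory))  # '' for a bare filename, 'some_folder' otherwise
        if directory:  # the added guard: only create a directory when there is one
            os.makedirs(directory, exist_ok=True)
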
mltu/transformers.py

Lines changed: 1 addition & 1 deletion

@@ -159,7 +159,7 @@ def __init__(
         self.padding_value = padding_value

     def __call__(self, spectrogram: np.ndarray, label: np.ndarray):
-        padded_spectrogram = np.pad(spectrogram, ((self.max_spectrogram_length - spectrogram.shape[0], 0),(0,0)), mode="constant", constant_values=self.padding_value)
+        padded_spectrogram = np.pad(spectrogram, ((0, self.max_spectrogram_length - spectrogram.shape[0]),(0,0)), mode="constant", constant_values=self.padding_value)

         return padded_spectrogram, label
