
Commit 449fc95

Fixed learning rate logging in the warmupCosineDecay callback

1 parent adb3830 · commit 449fc95

File tree

3 files changed (+18, -32 lines)

Tutorials/09_translation_transformer/train.py (3 additions, 2 deletions)

```diff
@@ -72,15 +72,15 @@ def preprocess_inputs(data_batch, label_batch):
     train_dataset,
     batch_size=configs.batch_size,
     batch_postprocessors=[preprocess_inputs],
-    use_cache=True
+    use_cache=True,
 )
 
 # Create Validation Data Provider
 val_dataProvider = DataProvider(
     val_dataset,
     batch_size=configs.batch_size,
     batch_postprocessors=[preprocess_inputs],
-    use_cache=True
+    use_cache=True,
 )
 
 # Create TensorFlow Transformer Model
@@ -129,6 +129,7 @@ def preprocess_inputs(data_batch, label_batch):
     validation_data=val_dataProvider,
     epochs=configs.train_epochs,
     callbacks=[
+        warmupCosineDecay,
         checkpoint,
         tb_callback,
         reduceLROnPlat,
```
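
The substantive fix here is registering warmupCosineDecay in the callbacks list (the trailing commas are style only): a Keras callback that is never passed to model.fit simply never runs, so the schedule was silently inactive. A minimal, self-contained sketch of that point; PrintLR and the toy model are hypothetical stand-ins, not the repository's transformer:

```python
# A minimal sketch showing why the fix matters: a Keras callback only fires if
# it is actually passed in the callbacks list. PrintLR and the toy model are
# hypothetical stand-ins, not the repository's code.
import numpy as np
import tensorflow as tf

class PrintLR(tf.keras.callbacks.Callback):
    def on_epoch_begin(self, epoch, logs=None):
        lr = tf.keras.backend.get_value(self.model.optimizer.lr)
        print(f"Epoch {epoch + 1} starts with lr={lr:.2e}")

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(optimizer=tf.keras.optimizers.Adam(1e-3), loss="mse")

x, y = np.random.rand(32, 4), np.random.rand(32, 1)
model.fit(x, y, epochs=2, verbose=0, callbacks=[PrintLR()])  # hook fires
model.fit(x, y, epochs=2, verbose=0)                         # omitted: silent
```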

mltu/tensorflow/callbacks.py (14 additions, 4 deletions)

```diff
@@ -134,15 +134,25 @@ def __init__(
 
     def on_epoch_begin(self, epoch: int, logs: dict=None):
         """ Adjust learning rate at the beginning of each epoch """
+
+        if epoch >= self.warmup_epochs + self.decay_epochs:
+            return logs
+
         if epoch < self.warmup_epochs:
             lr = self.initial_lr + (self.lr_after_warmup - self.initial_lr) * (epoch + 1) / self.warmup_epochs
-        elif epoch < self.warmup_epochs + self.decay_epochs:
+        else:
             progress = (epoch - self.warmup_epochs) / self.decay_epochs
             lr = self.final_lr + 0.5 * (self.lr_after_warmup - self.final_lr) * (1 + tf.cos(tf.constant(progress) * 3.14159))
-        else:
-            return None # No change to learning rate
 
         tf.keras.backend.set_value(self.model.optimizer.lr, lr)
 
         if self.verbose:
-            print(f"Epoch {epoch + 1} - Learning Rate: {lr}")
+            print(f"Epoch {epoch + 1} - Learning Rate: {lr}")
+
+    def on_epoch_end(self, epoch: int, logs: dict=None):
+        logs = logs or {}
+
+        # Log the learning rate value
+        logs["lr"] = self.model.optimizer.lr
+
+        return logs
```
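
The rewritten on_epoch_begin handles the three phases in order: past the end of the schedule it returns early, during warmup it ramps linearly from initial_lr to lr_after_warmup, and afterwards it follows a cosine decay down to final_lr. The new on_epoch_end writes the current value into logs["lr"], which is what makes the learning rate visible to downstream consumers of the logs dict such as TensorBoard. A self-contained sketch of the schedule itself, with illustrative parameter values rather than the repository's defaults:

```python
# A self-contained sketch of the warmup + cosine-decay schedule the callback
# implements. Parameter values are illustrative, not the repository's defaults,
# and math.pi replaces the truncated 3.14159 constant used in the source.
import math

def warmup_cosine_decay(epoch: int, initial_lr: float = 1e-6,
                        lr_after_warmup: float = 1e-3, final_lr: float = 1e-5,
                        warmup_epochs: int = 5, decay_epochs: int = 45) -> float:
    if epoch >= warmup_epochs + decay_epochs:
        return final_lr  # schedule exhausted; hold the final value
    if epoch < warmup_epochs:
        # Linear warmup from initial_lr up to lr_after_warmup
        return initial_lr + (lr_after_warmup - initial_lr) * (epoch + 1) / warmup_epochs
    # Cosine decay from lr_after_warmup down to final_lr
    progress = (epoch - warmup_epochs) / decay_epochs
    return final_lr + 0.5 * (lr_after_warmup - final_lr) * (1 + math.cos(math.pi * progress))

for epoch in (0, 4, 5, 27, 49):
    print(f"Epoch {epoch + 1} - Learning Rate: {warmup_cosine_decay(epoch):.2e}")
```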

mltu/tensorflow/transformer/utils.py (1 addition, 26 deletions)

```diff
@@ -76,29 +76,4 @@ def result(self) -> tf.Tensor:
         Returns:
             tf.Tensor: Masked accuracy.
         """
-        return self.total / self.count
-
-
-# def masked_accuracy(y_true: tf.Tensor, y_pred: tf.Tensor):
-#     """ Calculate masked accuracy.
-
-#     Args:
-#         y_true (tf.Tensor): True labels.
-#         y_pred (tf.Tensor): Predicted labels.
-
-#     Returns:
-#         tf.Tensor: Masked accuracy.
-#     """
-#     pred = tf.argmax(y_pred, axis=2)
-#     label = tf.cast(y_true, pred.dtype)
-#     match = label == pred
-
-#     mask = label != 0
-
-#     match = match & mask
-
-#     match = tf.cast(match, dtype=tf.float32)
-#     mask = tf.cast(mask, dtype=tf.float32)
-#     accuracy = tf.reduce_sum(match) / tf.reduce_sum(mask)
-
-#     return accuracy
+        return self.total / self.count
```
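
This change is dead-code cleanup: the commented-out masked_accuracy function duplicated what the metric class above already computes. For reference, a self-contained sketch of that computation, reconstructed from the deleted comments:

```python
# Masked accuracy as described by the deleted commented-out function:
# token id 0 is treated as padding and excluded from the average.
import tensorflow as tf

def masked_accuracy(y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
    pred = tf.argmax(y_pred, axis=2)        # (batch, seq) predicted token ids
    label = tf.cast(y_true, pred.dtype)
    match = (label == pred) & (label != 0)  # correct AND not padding
    match = tf.cast(match, tf.float32)
    mask = tf.cast(label != 0, tf.float32)
    return tf.reduce_sum(match) / tf.reduce_sum(mask)

# Toy batch: one sequence of 4 tokens, vocabulary of 3; the last token is padding.
y_true = tf.constant([[1, 2, 1, 0]])
y_pred = tf.one_hot([[1, 2, 2, 0]], depth=3)   # 2 of the 3 non-pad tokens correct
print(masked_accuracy(y_true, y_pred).numpy())  # ~0.6667
```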
