Tutorials/09_translation_transformer — 3 files changed: +18 -32

@@ -72,15 +72,15 @@ def preprocess_inputs(data_batch, label_batch):
     train_dataset,
     batch_size=configs.batch_size,
     batch_postprocessors=[preprocess_inputs],
-    use_cache=True
+    use_cache=True,
     )
 
 # Create Validation Data Provider
 val_dataProvider = DataProvider(
     val_dataset,
     batch_size=configs.batch_size,
     batch_postprocessors=[preprocess_inputs],
-    use_cache=True
+    use_cache=True,
     )
 
 # Create TensorFlow Transformer Model
@@ -129,6 +129,7 @@ def preprocess_inputs(data_batch, label_batch):
     validation_data=val_dataProvider,
     epochs=configs.train_epochs,
     callbacks=[
+        warmupCosineDecay,
         checkpoint,
         tb_callback,
         reduceLROnPlat,
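Note: the `warmupCosineDecay` instance passed into `callbacks` is created elsewhere in the train script. A minimal sketch of what that instantiation presumably looks like, assuming the constructor takes the attributes the callback reads (`initial_lr`, `lr_after_warmup`, `final_lr`, `warmup_epochs`, `decay_epochs`, `verbose`); the `configs.*` names here are hypothetical:

# Hypothetical instantiation; argument names mirror the attributes used in
# on_epoch_begin, but the real values come from the tutorial's configs.
warmupCosineDecay = WarmupCosineDecay(
    initial_lr=configs.init_lr,               # lr at the first warmup epoch (assumed config name)
    lr_after_warmup=configs.lr_after_warmup,  # peak lr once warmup finishes
    final_lr=configs.final_lr,                # lr reached at the end of cosine decay
    warmup_epochs=configs.warmup_epochs,      # length of the linear warmup phase
    decay_epochs=configs.decay_epochs,        # length of the cosine decay phase
    verbose=True,                             # print the lr at each epoch start
)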
@@ -134,15 +134,25 @@ def __init__(
 
     def on_epoch_begin(self, epoch: int, logs: dict = None):
         """ Adjust learning rate at the beginning of each epoch """
+
+        if epoch >= self.warmup_epochs + self.decay_epochs:
+            return logs
+
         if epoch < self.warmup_epochs:
             lr = self.initial_lr + (self.lr_after_warmup - self.initial_lr) * (epoch + 1) / self.warmup_epochs
-        elif epoch < self.warmup_epochs + self.decay_epochs:
+        else:
             progress = (epoch - self.warmup_epochs) / self.decay_epochs
             lr = self.final_lr + 0.5 * (self.lr_after_warmup - self.final_lr) * (1 + tf.cos(tf.constant(progress) * 3.14159))
-        else:
-            return None  # No change to learning rate
 
         tf.keras.backend.set_value(self.model.optimizer.lr, lr)
 
         if self.verbose:
-            print(f"Epoch {epoch + 1} - Learning Rate: {lr}")
+            print(f"Epoch {epoch + 1} - Learning Rate: {lr}")
+
+    def on_epoch_end(self, epoch: int, logs: dict = None):
+        logs = logs or {}
+
+        # Log the learning rate value
+        logs["lr"] = self.model.optimizer.lr
+
+        return logs
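Note: with the early return added above, epochs past `warmup_epochs + decay_epochs` skip `set_value`, so the learning rate simply stays at `final_lr`. A standalone sketch of the schedule for sanity-checking; the formulas are copied from `on_epoch_begin`, while the numeric defaults below are made-up examples:

import math

def warmup_cosine_lr(epoch, initial_lr=1e-6, lr_after_warmup=1e-3,
                     final_lr=1e-5, warmup_epochs=5, decay_epochs=45):
    # Phase 1: linear ramp from initial_lr up to lr_after_warmup
    if epoch < warmup_epochs:
        return initial_lr + (lr_after_warmup - initial_lr) * (epoch + 1) / warmup_epochs
    # Phase 2: cosine decay from lr_after_warmup down to final_lr
    if epoch < warmup_epochs + decay_epochs:
        progress = (epoch - warmup_epochs) / decay_epochs
        return final_lr + 0.5 * (lr_after_warmup - final_lr) * (1 + math.cos(progress * 3.14159))
    # Phase 3: the callback returns early, leaving the lr at final_lr
    return final_lr

for epoch in (0, 4, 5, 27, 49, 50, 99):
    print(f"epoch {epoch}: lr = {warmup_cosine_lr(epoch):.2e}")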
@@ -76,29 +76,4 @@ def result(self) -> tf.Tensor:
         Returns:
             tf.Tensor: Masked accuracy.
         """
-        return self.total / self.count
-
-
-# def masked_accuracy(y_true: tf.Tensor, y_pred: tf.Tensor):
-#     """ Calculate masked accuracy.
-
-#     Args:
-#         y_true (tf.Tensor): True labels.
-#         y_pred (tf.Tensor): Predicted labels.
-
-#     Returns:
-#         tf.Tensor: Masked accuracy.
-#     """
-#     pred = tf.argmax(y_pred, axis=2)
-#     label = tf.cast(y_true, pred.dtype)
-#     match = label == pred
-
-#     mask = label != 0
-
-#     match = match & mask
-
-#     match = tf.cast(match, dtype=tf.float32)
-#     mask = tf.cast(mask, dtype=tf.float32)
-#     accuracy = tf.reduce_sum(match) / tf.reduce_sum(mask)
-
-#     return accuracy
+        return self.total / self.count
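Note: the deleted commented-out block duplicated what the `MaskedAccuracy` metric class already computes. For reference, a compacted sketch of that removed logic as a plain function (padding token id assumed to be 0, as in the original comment):

import tensorflow as tf

def masked_accuracy(y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
    # y_pred: (batch, seq_len, vocab) logits; y_true: (batch, seq_len) token ids
    pred = tf.argmax(y_pred, axis=2)
    label = tf.cast(y_true, pred.dtype)
    mask = label != 0                  # ignore padding positions
    match = (label == pred) & mask     # correct AND non-padding
    match = tf.cast(match, tf.float32)
    mask = tf.cast(mask, tf.float32)
    return tf.reduce_sum(match) / tf.reduce_sum(mask)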