
Commit d8a130a (parent: feea7e6)

✍️ add option to use loss scale in keras compile

2 files changed: +24 −12 lines

tensorflow_asr/models/keras/ctc.py
Lines changed: 12 additions & 6 deletions
```diff
@@ -23,12 +23,14 @@
 class CtcModel(BaseCtcModel):
     """ Keras CTC Model Warper """
 
-    def compile(self, optimizer, global_batch_size, blank=0,
+    def compile(self, optimizer, global_batch_size, blank=0, use_loss_scale=False,
                 loss_weights=None, weighted_metrics=None, run_eagerly=None, **kwargs):
         loss = CtcLoss(blank=blank, global_batch_size=global_batch_size)
-        optimizer_with_scale = mxp.experimental.LossScaleOptimizer(tf.keras.optimizers.get(optimizer), 'dynamic')
+        self.use_loss_scale = use_loss_scale
+        if self.use_loss_scale:
+            optimizer = mxp.experimental.LossScaleOptimizer(tf.keras.optimizers.get(optimizer), 'dynamic')
         super(CtcModel, self).compile(
-            optimizer=optimizer_with_scale, loss=loss,
+            optimizer=optimizer, loss=loss,
             loss_weights=loss_weights, weighted_metrics=weighted_metrics,
             run_eagerly=run_eagerly,
             **kwargs
@@ -43,9 +45,13 @@ def train_step(self, batch):
                 'logit_length': get_reduced_length(x['input_length'], self.time_reduction_factor)
             }
             loss = self.loss(y_true, y_pred)
-            scaled_loss = self.optimizer.get_scaled_loss(loss)
-        scaled_gradients = tape.gradient(scaled_loss, self.trainable_weights)
-        gradients = self.optimizer.get_unscaled_gradients(scaled_gradients)
+            if self.use_loss_scale:
+                scaled_loss = self.optimizer.get_scaled_loss(loss)
+        if self.use_loss_scale:
+            scaled_gradients = tape.gradient(scaled_loss, self.trainable_weights)
+            gradients = self.optimizer.get_unscaled_gradients(scaled_gradients)
+        else:
+            gradients = tape.gradient(loss, self.trainable_weights)
         self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
         return {"ctc_loss": loss}
```
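With this change, loss scaling is opt-in at compile time. A minimal usage sketch, assuming `model` is an already-built `CtcModel` from this repo; the optimizer choice, `global_batch_size`, and `blank` index below are illustrative only:

```python
import tensorflow as tf

# `model` is assumed to be a built CtcModel instance from TensorFlowASR.
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    global_batch_size=16,
    blank=0,
    use_loss_scale=True,  # wraps the optimizer in a dynamic LossScaleOptimizer
)
```

When `use_loss_scale=False` (the new default), the optimizer is passed through unchanged, so training without mixed precision no longer requires a loss-scale wrapper.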

tensorflow_asr/models/keras/transducer.py
Lines changed: 12 additions & 6 deletions
```diff
@@ -48,12 +48,14 @@ def call(self, inputs, training=False, **kwargs):
             "logit_length": get_reduced_length(inputs["input_length"], self.time_reduction_factor)
         }
 
-    def compile(self, optimizer, global_batch_size, blank=0,
+    def compile(self, optimizer, global_batch_size, blank=0, use_loss_scale=False,
                 loss_weights=None, weighted_metrics=None, run_eagerly=None, **kwargs):
         loss = RnntLoss(blank=blank, global_batch_size=global_batch_size)
-        optimizer_with_scale = mxp.experimental.LossScaleOptimizer(tf.keras.optimizers.get(optimizer), 'dynamic')
+        self.use_loss_scale = use_loss_scale
+        if self.use_loss_scale:
+            optimizer = mxp.experimental.LossScaleOptimizer(tf.keras.optimizers.get(optimizer), 'dynamic')
         super(Transducer, self).compile(
-            optimizer=optimizer_with_scale, loss=loss,
+            optimizer=optimizer, loss=loss,
             loss_weights=loss_weights, weighted_metrics=weighted_metrics,
             run_eagerly=run_eagerly,
             **kwargs
@@ -69,9 +71,13 @@ def train_step(self, batch):
                 "prediction_length": x["prediction_length"],
             }, training=True)
             loss = self.loss(y_true, y_pred)
-            scaled_loss = self.optimizer.get_scaled_loss(loss)
-        scaled_gradients = tape.gradient(scaled_loss, self.trainable_weights)
-        gradients = self.optimizer.get_unscaled_gradients(scaled_gradients)
+            if self.use_loss_scale:
+                scaled_loss = self.optimizer.get_scaled_loss(loss)
+        if self.use_loss_scale:
+            scaled_gradients = tape.gradient(scaled_loss, self.trainable_weights)
+            gradients = self.optimizer.get_unscaled_gradients(scaled_gradients)
+        else:
+            gradients = tape.gradient(loss, self.trainable_weights)
         self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
         return {"rnnt_loss": loss}
```
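Both `train_step` hunks implement the same control flow: when loss scaling is enabled, the loss is scaled inside the `GradientTape`, gradients are taken of the scaled loss, and then unscaled before being applied. A standalone sketch of that pattern (not code from this repo), assuming a generic `tf.keras` model, a loss function, and an optimizer that is already a `LossScaleOptimizer` whenever `use_loss_scale` is True:

```python
import tensorflow as tf

def train_step_sketch(model, optimizer, loss_fn, x, y_true, use_loss_scale):
    with tf.GradientTape() as tape:
        y_pred = model(x, training=True)
        loss = loss_fn(y_true, y_pred)
        if use_loss_scale:
            # `optimizer` must be a mixed-precision LossScaleOptimizer here,
            # so the loss is multiplied by the current (dynamic) loss scale.
            scaled_loss = optimizer.get_scaled_loss(loss)
    if use_loss_scale:
        # Differentiate the scaled loss, then divide the scale back out.
        scaled_gradients = tape.gradient(scaled_loss, model.trainable_weights)
        gradients = optimizer.get_unscaled_gradients(scaled_gradients)
    else:
        # Plain path: no scaling when the option is off.
        gradients = tape.gradient(loss, model.trainable_weights)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return loss
```

The scaling step has to happen inside the tape so that the gradients of the scaled loss are what get recorded; unscaling afterwards restores the true gradient magnitudes before `apply_gradients`.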
