@@ -48,12 +48,14 @@ def call(self, inputs, training=False, **kwargs):
4848 "logit_length" : get_reduced_length (inputs ["input_length" ], self .time_reduction_factor )
4949 }
5050
51- def compile (self , optimizer , global_batch_size , blank = 0 ,
51+ def compile (self , optimizer , global_batch_size , blank = 0 , use_loss_scale = False ,
5252 loss_weights = None , weighted_metrics = None , run_eagerly = None , ** kwargs ):
5353 loss = RnntLoss (blank = blank , global_batch_size = global_batch_size )
54- optimizer_with_scale = mxp .experimental .LossScaleOptimizer (tf .keras .optimizers .get (optimizer ), 'dynamic' )
54+ self .use_loss_scale = use_loss_scale
55+ if self .use_loss_scale :
56+ optimizer = mxp .experimental .LossScaleOptimizer (tf .keras .optimizers .get (optimizer ), 'dynamic' )
5557 super (Transducer , self ).compile (
56- optimizer = optimizer_with_scale , loss = loss ,
58+ optimizer = optimizer , loss = loss ,
5759 loss_weights = loss_weights , weighted_metrics = weighted_metrics ,
5860 run_eagerly = run_eagerly ,
5961 ** kwargs
@@ -69,9 +71,13 @@ def train_step(self, batch):
6971 "prediction_length" : x ["prediction_length" ],
7072 }, training = True )
7173 loss = self .loss (y_true , y_pred )
72- scaled_loss = self .optimizer .get_scaled_loss (loss )
73- scaled_gradients = tape .gradient (scaled_loss , self .trainable_weights )
74- gradients = self .optimizer .get_unscaled_gradients (scaled_gradients )
74+ if self .use_loss_scale :
75+ scaled_loss = self .optimizer .get_scaled_loss (loss )
76+ if self .use_loss_scale :
77+ scaled_gradients = tape .gradient (scaled_loss , self .trainable_weights )
78+ gradients = self .optimizer .get_unscaled_gradients (scaled_gradients )
79+ else :
80+ gradients = tape .gradient (loss , self .trainable_weights )
7581 self .optimizer .apply_gradients (zip (gradients , self .trainable_variables ))
7682 return {"rnnt_loss" : loss }
7783
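For context, the mixed-precision loss-scaling pattern that the new use_loss_scale flag toggles looks roughly like the standalone sketch below. It is only an illustration, not code from this repository: it uses the non-experimental tf.keras.mixed_precision.LossScaleOptimizer API (TF 2.4+) instead of the mxp.experimental alias in the diff, and the tiny Dense model and random data are assumptions made for the sake of a runnable example.

import tensorflow as tf

# Wrap any Keras optimizer in a dynamic loss-scale optimizer.
opt = tf.keras.mixed_precision.LossScaleOptimizer(tf.keras.optimizers.Adam(1e-3))

# Toy model and data, purely illustrative.
model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
x = tf.random.normal((8, 4))
y = tf.random.normal((8, 1))

with tf.GradientTape() as tape:
    loss = tf.reduce_mean(tf.square(model(x) - y))
    # Scale the loss inside the tape so the scaling op is recorded.
    scaled_loss = opt.get_scaled_loss(loss)
scaled_grads = tape.gradient(scaled_loss, model.trainable_weights)
# Undo the scaling before applying the gradients.
grads = opt.get_unscaled_gradients(scaled_grads)
opt.apply_gradients(zip(grads, model.trainable_weights))

With the change in this diff applied, that behaviour is opted into per model at compile time, e.g. model.compile(optimizer='adam', global_batch_size=global_batch_size, blank=0, use_loss_scale=True); leaving the flag at its default keeps the plain, unscaled gradient path in train_step.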