@@ -90,48 +90,39 @@ def compile(self,
 class TransducerTrainerGA(TransducerTrainer):
     """ Transducer Trainer that uses Gradients Accumulation """
 
-    @tf.function(experimental_relax_shapes=True)
-    def _train_step(self, batch):
-        _, bfeatures, binput_length, blabels, blabel_length, bpred_inp = batch
-
+    @tf.function
+    def _train_function(self, iterator):
+        for _ in range(self.config.accumulation_steps):
+            batch = next(iterator)
+            self.strategy.run(self._train_step, args=(batch,))
+        self.strategy.run(self._apply_gradients, args=())
+
+    @tf.function
+    def _apply_gradients(self):
+        self.optimizer.apply_gradients(
+            zip(self.accumulation.gradients, self.model.trainable_variables))
         self.accumulation.reset()
 
-        for accum_step in range(self.config.accumulation_steps):
+    @tf.function(experimental_relax_shapes=True)
+    def _train_step(self, batch):
+        _, features, input_length, labels, label_length, pred_inp = batch
 
-            indices = tf.expand_dims(
-                tf.range(
-                    accum_step * self.accumulation_bs,
-                    (accum_step + 1) * self.accumulation_bs,
-                    dtype=tf.int32
-                ),
-                axis=-1
+        with tf.GradientTape() as tape:
+            logits = self.model([features, pred_inp], training=True)
+            tape.watch(logits)
+            per_train_loss = rnnt_loss(
+                logits=logits, labels=labels, label_length=label_length,
+                logit_length=(input_length // self.model.time_reduction_factor),
+                blank=self.text_featurizer.blank
+            )
+            train_loss = tf.nn.compute_average_loss(
+                per_train_loss,
+                global_batch_size=self.global_batch_size
            )
 
-            features = tf.gather_nd(bfeatures, indices)
-            input_length = tf.gather_nd(binput_length, indices)
-            labels = tf.gather_nd(blabels, indices)
-            label_length = tf.gather_nd(blabel_length, indices)
-            pred_inp = tf.gather_nd(bpred_inp, indices)
-
-            with tf.GradientTape() as tape:
-                logits = self.model([features, pred_inp], training=True)
-                tape.watch(logits)
-                per_train_loss = rnnt_loss(
-                    logits=logits, labels=labels, label_length=label_length,
-                    logit_length=(input_length // self.model.time_reduction_factor),
-                    blank=self.text_featurizer.blank
-                )
-                train_loss = tf.nn.compute_average_loss(
-                    per_train_loss,
-                    global_batch_size=self.global_batch_size
-                )
-
-            step_gradients = tape.gradient(train_loss, self.model.trainable_variables)
-            self.accumulation.accumulate(step_gradients)
-            self.train_metrics["transducer_loss"].update_state(per_train_loss)
-
-        self.optimizer.apply_gradients(
-            zip(self.accumulation.gradients, self.model.trainable_variables))
+        gradients = tape.gradient(train_loss, self.model.trainable_variables)
+        self.accumulation.accumulate(gradients)
+        self.train_metrics["transducer_loss"].update_state(per_train_loss)
 
     def compile(self,
         model: Transducer,
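
In short, the commit replaces in-graph batch slicing with an iterator-driven loop: `_train_function` pulls `accumulation_steps` micro-batches from the dataset iterator, runs `_train_step` on each replica to accumulate gradients, and applies the summed gradients once in `_apply_gradients`. Since each micro-batch now comes straight from the iterator, replicas see tensors with the dataset's native shapes instead of slices carved out of an oversized batch with `tf.gather_nd`. Note also that `tf.nn.compute_average_loss` divides the per-example losses by `global_batch_size`, so the gradients summed over the micro-batches end up scaled as one average over the full effective batch. A rough sketch of an outer loop that could drive the new method (`train_dataset`, `steps_per_epoch`, and `trainer` are stand-in names for illustration, not part of this commit):

    # Hypothetical driver; train_dataset, steps_per_epoch, and trainer
    # are assumed names, not defined in this commit.
    train_iterator = iter(
        trainer.strategy.experimental_distribute_dataset(train_dataset))

    for _ in range(steps_per_epoch):
        # Each call consumes accumulation_steps micro-batches, accumulates
        # their gradients, then performs a single optimizer step.
        trainer._train_function(train_iterator)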
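Both versions also depend on a gradient-accumulation helper (`self.accumulation`) defined outside this hunk. A minimal sketch of an object exposing the `accumulate`/`gradients`/`reset` interface used above, offered as an assumption about its shape rather than the repository's actual implementation:

    import tensorflow as tf

    class GradientAccumulator:
        """Sums gradients across micro-batches in non-trainable variables.

        A sketch only; the helper in this repository may differ.
        """

        def __init__(self, trainable_variables):
            # One zero-initialized slot per model variable.
            self._slots = [
                tf.Variable(tf.zeros_like(v), trainable=False)
                for v in trainable_variables
            ]

        @property
        def gradients(self):
            # Current accumulated sums, in the same order as the variables.
            return [slot.value() for slot in self._slots]

        def accumulate(self, step_gradients):
            # Add one micro-batch's gradients into the running sums.
            for slot, grad in zip(self._slots, step_gradients):
                if grad is not None:
                    slot.assign_add(grad)

        def reset(self):
            # Zero the sums after the optimizer has consumed them.
            for slot in self._slots:
                slot.assign(tf.zeros_like(slot))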