3 files changed, +7 −6 lines (file tree: tensorflow_asr/models/keras)

@@ -58,8 +58,9 @@ model_config:
   prediction_rnn_type: lstm
   prediction_rnn_implementation: 2
   prediction_layer_norm: False
-  prediction_projection_units: 0
+  prediction_projection_units: 144
   joint_dim: 640
+  prejoint_linear: False
   joint_activation: tanh
 
 learning_config:
@@ -112,7 +113,7 @@ learning_config:
   running_config:
     batch_size: 2
     accumulation_steps: 4
-    num_epochs: 20
+    num_epochs: 50
     outdir: /mnt/Miscellanea/Models/local/conformer
     log_interval_steps: 300
     eval_interval_steps: 500
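Note on the config hunk above (my reading, not stated in the diff): a non-zero prediction_projection_units usually means the prediction network's recurrent output gets an extra linear projection, and prejoint_linear toggles whether encoder/prediction outputs pass through an additional linear layer before the joint network. The snippet below is a minimal, hypothetical sketch of that projection idea in plain Keras; build_prediction_net and its parameters are invented for illustration and are not TensorFlowASR's actual API.

import tensorflow as tf

def build_prediction_net(vocab_size=1000, embed_dim=320,
                         rnn_units=320, projection_units=144):
    """Toy transducer prediction network: embedding -> LSTM -> optional projection."""
    tokens = tf.keras.Input(shape=(None,), dtype=tf.int32)
    x = tf.keras.layers.Embedding(vocab_size, embed_dim)(tokens)
    x = tf.keras.layers.LSTM(rnn_units, return_sequences=True)(x)
    if projection_units > 0:  # 0, the old config value, would mean "no projection"
        x = tf.keras.layers.Dense(projection_units)(x)
    return tf.keras.Model(tokens, x, name="prediction_net")

print(build_prediction_net().output_shape)  # (None, None, 144) with projection_units=144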
@@ -47,7 +47,7 @@ def train_step(self, batch):
         scaled_gradients = tape.gradient(scaled_loss, self.trainable_weights)
         gradients = self.optimizer.get_unscaled_gradients(scaled_gradients)
         self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
-        return {"train_ctc_loss": loss}
+        return {"ctc_loss": loss}
 
     def test_step(self, batch):
         x, y_true = batch
@@ -57,4 +57,4 @@ def test_step(self, batch):
             'logit_length': get_reduced_length(x['input_length'], self.time_reduction_factor)
         }
         loss = self.loss(y_true, y_pred)
-        return {"val_ctc_loss": loss}
+        return {"ctc_loss": loss}
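Side note on the loss-key rename (my interpretation, not stated in the diff): Keras prefixes whatever test_step returns with "val_" when logging validation during fit(), so the old "val_ctc_loss" key would surface as "val_val_ctc_loss", while a shared "ctc_loss" key yields clean "ctc_loss" / "val_ctc_loss" entries. A self-contained toy model illustrating that behaviour (ToyModel and its dummy regression loss are invented for the demo):

import tensorflow as tf

class ToyModel(tf.keras.Model):
    """Dummy regression model; only the metric-key naming matters here."""
    def __init__(self):
        super().__init__()
        self.dense = tf.keras.layers.Dense(1)

    def call(self, inputs, training=False):
        return self.dense(inputs)

    def train_step(self, batch):
        x, y = batch
        with tf.GradientTape() as tape:
            loss = tf.reduce_mean(tf.square(self(x, training=True) - y))
        grads = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
        return {"ctc_loss": loss}  # stand-in key; logged as "ctc_loss"

    def test_step(self, batch):
        x, y = batch
        loss = tf.reduce_mean(tf.square(self(x, training=False) - y))
        return {"ctc_loss": loss}  # fit() logs this as "val_ctc_loss"

model = ToyModel()
model.compile(optimizer="adam")
x, y = tf.random.normal((32, 4)), tf.random.normal((32, 1))
history = model.fit(x, y, validation_data=(x, y), epochs=1, verbose=0)
print(sorted(history.history))  # ['ctc_loss', 'val_ctc_loss']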
@@ -73,7 +73,7 @@ def train_step(self, batch):
         scaled_gradients = tape.gradient(scaled_loss, self.trainable_weights)
         gradients = self.optimizer.get_unscaled_gradients(scaled_gradients)
         self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
-        return {"train_rnnt_loss": loss}
+        return {"rnnt_loss": loss}
 
     def test_step(self, batch):
         x, y_true = batch
@@ -84,4 +84,4 @@ def test_step(self, batch):
             "prediction_length": x["prediction_length"],
         }, training=False)
         loss = self.loss(y_true, y_pred)
-        return {"val_rnnt_loss": loss}
+        return {"rnnt_loss": loss}