Skip to content

Commit feea7e6

Browse files
committed
✍️ update keras return dict
1 parent fd024a3 commit feea7e6

File tree

3 files changed

+7
-6
lines changed

3 files changed

+7
-6
lines changed

examples/conformer/config.yml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,9 @@ model_config:
5858
prediction_rnn_type: lstm
5959
prediction_rnn_implementation: 2
6060
prediction_layer_norm: False
61-
prediction_projection_units: 0
61+
prediction_projection_units: 144
6262
joint_dim: 640
63+
prejoint_linear: False
6364
joint_activation: tanh
6465

6566
learning_config:
@@ -112,7 +113,7 @@ learning_config:
112113
running_config:
113114
batch_size: 2
114115
accumulation_steps: 4
115-
num_epochs: 20
116+
num_epochs: 50
116117
outdir: /mnt/Miscellanea/Models/local/conformer
117118
log_interval_steps: 300
118119
eval_interval_steps: 500

tensorflow_asr/models/keras/ctc.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def train_step(self, batch):
4747
scaled_gradients = tape.gradient(scaled_loss, self.trainable_weights)
4848
gradients = self.optimizer.get_unscaled_gradients(scaled_gradients)
4949
self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
50-
return {"train_ctc_loss": loss}
50+
return {"ctc_loss": loss}
5151

5252
def test_step(self, batch):
5353
x, y_true = batch
@@ -57,4 +57,4 @@ def test_step(self, batch):
5757
'logit_length': get_reduced_length(x['input_length'], self.time_reduction_factor)
5858
}
5959
loss = self.loss(y_true, y_pred)
60-
return {"val_ctc_loss": loss}
60+
return {"ctc_loss": loss}

tensorflow_asr/models/keras/transducer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def train_step(self, batch):
7373
scaled_gradients = tape.gradient(scaled_loss, self.trainable_weights)
7474
gradients = self.optimizer.get_unscaled_gradients(scaled_gradients)
7575
self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
76-
return {"train_rnnt_loss": loss}
76+
return {"rnnt_loss": loss}
7777

7878
def test_step(self, batch):
7979
x, y_true = batch
@@ -84,4 +84,4 @@ def test_step(self, batch):
8484
"prediction_length": x["prediction_length"],
8585
}, training=False)
8686
loss = self.loss(y_true, y_pred)
87-
return {"val_rnnt_loss": loss}
87+
return {"rnnt_loss": loss}

0 commit comments

Comments
 (0)