Skip to content

Commit 96b482d

Browse files
committed
✍️ update transducer prediction and scripts
1 parent 8a005d2 commit 96b482d

File tree

3 files changed

+3
-4
lines changed

3 files changed

+3
-4
lines changed

examples/conformer/train_tpu_keras_subword_conformer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232

3333
parser.add_argument("--bs", type=int, default=None, help="Batch size per replica")
3434

35-
parser.add_argument("--spx", type=int, default=50, help="Steps per execution for maximizing TPU performance")
35+
parser.add_argument("--spx", type=int, default=1, help="Steps per execution for maximizing TPU performance")
3636

3737
parser.add_argument("--tpu_address", type=str, default=None, help="TPU address. Leave None on Colab")
3838

examples/contextnet/train_tpu_keras_subword_contextnet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232

3333
parser.add_argument("--bs", type=int, default=None, help="Batch size per replica")
3434

35-
parser.add_argument("--spx", type=int, default=50, help="Steps per execution for maximizing TPU performance")
35+
parser.add_argument("--spx", type=int, default=1, help="Steps per execution for maximizing TPU performance")
3636

3737
parser.add_argument("--tpu_address", type=str, default=None, help="TPU address. Leave None on Colab")
3838

tensorflow_asr/models/transducer.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,11 +93,10 @@ def call(self, inputs, training=False, **kwargs):
9393
# inputs has shape [B, U]
9494
# use tf.gather_nd instead of tf.gather for tflite conversion
9595
outputs, prediction_length = inputs
96-
if not hasattr(self, "max_length"): self.max_length = shape_list(outputs)[-1]
9796
outputs = self.embed(outputs, training=training)
9897
outputs = self.do(outputs, training=training)
9998
for rnn in self.rnns:
100-
mask = tf.sequence_mask(prediction_length, maxlen=self.max_length)
99+
mask = tf.sequence_mask(prediction_length, maxlen=tf.shape(outputs)[1])
101100
outputs = rnn["rnn"](outputs, training=training, mask=mask)
102101
outputs = outputs[0]
103102
if rnn["ln"] is not None:

0 commit comments

Comments
 (0)