Skip to content

Commit 8e02e88

Browse files
committed
⚡ add max length to predict network
1 parent b10cb03 commit 8e02e88

File tree

2 files changed

+3
-2
lines changed

2 files changed

+3
-2
lines changed

tensorflow_asr/datasets/asr_dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def update_lengths(self, max_lengths_prefix: str = None):
9090
# -------------------------------- ENTRIES -------------------------------------
9191

9292
def read_entries(self):
93-
if hasattr(self, 'entries') and len(self.entries) > 0: return
93+
if hasattr(self, "entries") and len(self.entries) > 0: return
9494
self.entries = []
9595
for file_path in self.data_paths:
9696
print(f"Reading {file_path} ...")

tensorflow_asr/models/transducer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,10 +93,11 @@ def call(self, inputs, training=False, **kwargs):
9393
# inputs has shape [B, U]
9494
# use tf.gather_nd instead of tf.gather for tflite conversion
9595
outputs, prediction_length = inputs
96+
if not hasattr(self, "max_length"): self.max_length = shape_list(outputs)[-1]
9697
outputs = self.embed(outputs, training=training)
9798
outputs = self.do(outputs, training=training)
9899
for rnn in self.rnns:
99-
mask = tf.sequence_mask(prediction_length)
100+
mask = tf.sequence_mask(prediction_length, maxlen=self.max_length)
100101
outputs = rnn["rnn"](outputs, training=training, mask=mask)
101102
outputs = outputs[0]
102103
if rnn["ln"] is not None:

0 commit comments

Comments
 (0)