Skip to content

Commit 8b9383c

Browse files
authored
Merge pull request #99 from TensorSpeech/fix/beamsearch
Fix transducer beam search for longer sequence
2 parents 4842078 + 3d5c7f3 commit 8b9383c

File tree

1 file changed

+2
-5
lines changed

1 file changed

+2
-5
lines changed

tensorflow_asr/models/transducer.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -591,7 +591,7 @@ def initialize_beam(dynamic=False):
591591
B = BeamHypothesis(
592592
score=B.score.write(0, 0.0),
593593
indices=B.indices.write(0, self.text_featurizer.blank),
594-
prediction=B.prediction.write(0, tf.ones([total], dtype=tf.int32) * self.text_featurizer.blank),
594+
prediction=B.prediction.write(0, tf.ones([total * 2], dtype=tf.int32) * self.text_featurizer.blank),
595595
states=B.states.write(0, self.predict_net.get_initial_state())
596596
)
597597

@@ -673,10 +673,7 @@ def false_fn():
673673

674674
b_score, b_indices, b_prediction, b_states, \
675675
a_score, a_indices, a_prediction, a_states, A_i = tf.cond(
676-
tf.equal(pred, self.text_featurizer.blank),
677-
true_fn=true_fn,
678-
false_fn=false_fn
679-
)
676+
tf.equal(pred, self.text_featurizer.blank), true_fn=true_fn, false_fn=false_fn)
680677

681678
B = BeamHypothesis(score=b_score, indices=b_indices, prediction=b_prediction, states=b_states)
682679
A = BeamHypothesis(score=a_score, indices=a_indices, prediction=a_prediction, states=a_states)

0 commit comments

Comments
 (0)