@@ -591,7 +591,7 @@ def initialize_beam(dynamic=False):
591591 B = BeamHypothesis (
592592 score = B .score .write (0 , 0.0 ),
593593 indices = B .indices .write (0 , self .text_featurizer .blank ),
594- prediction = B .prediction .write (0 , tf .ones ([total ], dtype = tf .int32 ) * self .text_featurizer .blank ),
594+ prediction = B .prediction .write (0 , tf .ones ([total * 2 ], dtype = tf .int32 ) * self .text_featurizer .blank ),
595595 states = B .states .write (0 , self .predict_net .get_initial_state ())
596596 )
597597
@@ -673,10 +673,7 @@ def false_fn():
673673
674674 b_score , b_indices , b_prediction , b_states , \
675675 a_score , a_indices , a_prediction , a_states , A_i = tf .cond (
676- tf .equal (pred , self .text_featurizer .blank ),
677- true_fn = true_fn ,
678- false_fn = false_fn
679- )
676+ tf .equal (pred , self .text_featurizer .blank ), true_fn = true_fn , false_fn = false_fn )
680677
681678 B = BeamHypothesis (score = b_score , indices = b_indices , prediction = b_prediction , states = b_states )
682679 A = BeamHypothesis (score = a_score , indices = a_indices , prediction = a_prediction , states = a_states )
0 commit comments