@@ -462,7 +462,7 @@ def execute(signal: tf.Tensor):
462462 return tf .map_fn (execute , signals , fn_output_signature = tf .TensorSpec ([], dtype = tf .string ))
463463
464464 def perform_beam_search (self , encoded , lm = False ):
465- with tf .name_scope (f"{ self .name } _beam_search" ):
465+ with tf .device ( "/CPU:0" ), tf . name_scope (f"{ self .name } _beam_search" ):
466466 beam_width = tf .cond (
467467 tf .less (self .text_featurizer .decoder_config .beam_width , self .text_featurizer .num_classes ),
468468 true_fn = lambda : self .text_featurizer .decoder_config .beam_width ,
@@ -520,9 +520,9 @@ def beam_condition(beam, beam_width, A, A_i, B): return tf.less(beam, beam_width
520520 def beam_body (beam , beam_width , A , A_i , B ):
521521 y_hat_score , y_hat_score_index = tf .math .top_k (A .score .stack (), k = 1 )
522522 y_hat_score = y_hat_score [0 ]
523- y_hat_index = tf .gather_nd (A .indices .stack (), tf . expand_dims ( y_hat_score_index [ 0 ], axis = - 1 ) )
524- y_hat_prediction = tf .gather_nd (A .prediction .stack (), tf . expand_dims ( y_hat_score_index [ 0 ], axis = - 1 ) )
525- y_hat_states = tf .gather_nd (A .states .stack (), tf . expand_dims ( y_hat_score_index [ 0 ], axis = - 1 ) )
523+ y_hat_index = tf .gather_nd (A .indices .stack (), y_hat_score_index )
524+ y_hat_prediction = tf .gather_nd (A .prediction .stack (), y_hat_score_index )
525+ y_hat_states = tf .gather_nd (A .states .stack (), y_hat_score_index )
526526
527527 ytu , new_states = self .decoder_inference (encoded = encoded_t , predicted = y_hat_index , states = y_hat_states )
528528
@@ -571,11 +571,16 @@ def predict_body(pred, A, A_i, B):
571571
572572 _ , _ , B = tf .while_loop (condition , body , loop_vars = (0 , total , B ))
573573
574- y_hat_score , y_hat_score_index = tf .math .top_k (B .score .stack (), k = 1 )
574+ scores = B .score .stack ()
575+ if self .text_featurizer .decoder_config .norm_score :
576+ prediction_lengths = tf .strings .length (B .prediction .stack (), unit = "UTF8_CHAR" )
577+ scores /= tf .cast (prediction_lengths , dtype = scores .dtype )
578+
579+ y_hat_score , y_hat_score_index = tf .math .top_k (scores , k = 1 )
575580 y_hat_score = y_hat_score [0 ]
576- y_hat_index = tf .gather_nd (B .indices .stack (), tf . expand_dims ( y_hat_score_index [ 0 ], axis = - 1 ) )
577- y_hat_prediction = tf .gather_nd (B .prediction .stack (), tf . expand_dims ( y_hat_score_index [ 0 ], axis = - 1 ) )
578- y_hat_states = tf .gather_nd (B .states .stack (), tf . expand_dims ( y_hat_score_index [ 0 ], axis = - 1 ) )
581+ y_hat_index = tf .gather_nd (B .indices .stack (), y_hat_score_index )
582+ y_hat_prediction = tf .gather_nd (B .prediction .stack (), y_hat_score_index )
583+ y_hat_states = tf .gather_nd (B .states .stack (), y_hat_score_index )
579584
580585 return Hypothesis (
581586 index = y_hat_index ,
0 commit comments