@@ -417,73 +417,51 @@ def __perform_greedy(self,
417417 with tf .name_scope (f"{ self .name } _greedy" ):
418418 time = tf .constant (0 , dtype = tf .int32 )
419419 total = encoded_length
420- # Initialize prediction with a blank
421- # Prediction can not be longer than the encoded of audio plus blank
422- prediction = tf .TensorArray (
423- dtype = tf .int32 ,
424- size = (total + 1 ),
425- dynamic_size = False ,
426- element_shape = tf .TensorShape ([]),
427- clear_after_read = False
428- )
429420
430421 hypothesis = Hypothesis (
431422 index = tf .constant (0 , dtype = tf .int32 ),
432- prediction = prediction . write ( 0 , predicted ) ,
423+ prediction = tf . ones ([ total + 1 ], dtype = tf . int32 ) * self . text_featurizer . blank ,
433424 states = states
434425 )
435426
436427 def condition (time , total , encoded , hypothesis ): return tf .less (time , total )
437428
438429 def body (time , total , encoded , hypothesis ):
430+ predicted = tf .gather_nd (hypothesis .prediction , tf .expand_dims (hypothesis .index , axis = - 1 ))
431+
439432 ytu , new_states = self .decoder_inference (
440433 # avoid using [index] in tflite
441434 encoded = tf .gather_nd (encoded , tf .expand_dims (time , axis = - 1 )),
442- predicted = hypothesis . prediction . read ( hypothesis . index ) ,
435+ predicted = predicted ,
443436 states = hypothesis .states
444437 )
445- char = tf .argmax (ytu , axis = - 1 , output_type = tf .int32 ) # => argmax []
446-
447- index , char , new_states = tf .cond (
448- tf .equal (char , self .text_featurizer .blank ),
449- true_fn = lambda : (
450- hypothesis .index ,
451- hypothesis .prediction .read (hypothesis .index ),
452- hypothesis .states
453- ),
454- false_fn = lambda : (
455- hypothesis .index + 1 ,
456- char ,
457- new_states
458- )
438+ new_predicted = tf .argmax (ytu , axis = - 1 , output_type = tf .int32 ) # => argmax []
439+
440+ index , new_predicted , new_states = tf .cond (
441+ tf .equal (new_predicted , self .text_featurizer .blank ),
442+ true_fn = lambda : (hypothesis .index , predicted , hypothesis .states ),
443+ false_fn = lambda : (hypothesis .index + 1 , new_predicted , new_states )
459444 )
460445
461446 hypothesis = Hypothesis (
462447 index = index ,
463- prediction = hypothesis .prediction .write (index , char ),
448+ prediction = tf .tensor_scatter_nd_update (
449+ hypothesis .prediction ,
450+ indices = tf .reshape (index , [1 , 1 ]),
451+ updates = tf .expand_dims (new_predicted , axis = - 1 )
452+ ),
464453 states = new_states
465454 )
466455
467456 return time + 1 , total , encoded , hypothesis
468457
469458 time , total , encoded , hypothesis = tf .while_loop (
470- condition ,
471- body ,
459+ condition , body ,
472460 loop_vars = (time , total , encoded , hypothesis ),
473461 parallel_iterations = parallel_iterations ,
474462 swap_memory = swap_memory
475463 )
476464
477- # Gather predicted sequence
478- hypothesis = Hypothesis (
479- index = hypothesis .index ,
480- prediction = tf .gather_nd (
481- params = hypothesis .prediction .stack (),
482- indices = tf .expand_dims (tf .range (hypothesis .index + 1 ), axis = - 1 )
483- ),
484- states = hypothesis .states
485- )
486-
487465 return hypothesis
488466
489467 # -------------------------------- BEAM SEARCH -------------------------------------
0 commit comments