 import tensorflow as tf

 from . import Model
-from ..utils.utils import get_rnn, shape_list, count_non_blank
+from ..utils.utils import get_rnn, shape_list, count_non_blank, pad_prediction_tfarray
 from ..featurizers.speech_featurizers import SpeechFeaturizer
 from ..featurizers.text_featurizers import TextFeaturizer
 from .layers.embedding import Embedding
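Note on the new import: `pad_prediction_tfarray` is used throughout the hunks below to equalize the lengths of per-utterance predictions before stacking, but its implementation is not part of this diff. As a rough sketch of what such a helper could look like (an assumption, not the repository's actual code), it would pad every element of an int32 TensorArray with the blank index up to the longest element:

import tensorflow as tf

def pad_prediction_tfarray_sketch(predictions: tf.TensorArray, blank: int) -> tf.TensorArray:
    # Illustrative guess at the helper's behavior: pad each variable-length prediction
    # with `blank` so that .stack() becomes possible. Assumes clear_after_read=False,
    # as in the TensorArrays created by the hunks below.
    total = predictions.size()

    def max_len_body(i, current):
        return i + 1, tf.maximum(current, tf.shape(predictions.read(i))[0])

    _, max_length = tf.while_loop(
        lambda i, _: tf.less(i, total), max_len_body,
        loop_vars=[tf.constant(0, tf.int32), tf.constant(0, tf.int32)]
    )

    padded = tf.TensorArray(
        dtype=tf.int32, size=total, dynamic_size=False,
        clear_after_read=False, element_shape=tf.TensorShape([None])
    )

    def pad_body(i, arr):
        item = predictions.read(i)
        item = tf.pad(item, [[0, max_length - tf.shape(item)[0]]], constant_values=blank)
        return i + 1, arr.write(i, item)

    _, padded = tf.while_loop(
        lambda i, _: tf.less(i, total), pad_body,
        loop_vars=[tf.constant(0, tf.int32), padded]
    )
    return padded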
@@ -400,42 +400,40 @@ def _perform_greedy_batch(self,
                                encoded: tf.Tensor,
                                encoded_length: tf.Tensor,
                                parallel_iterations: int = 10,
-                               swap_memory: bool = False):
-        total_batch, total_time, _ = shape_list(encoded)
-        batch = tf.constant(0, dtype=tf.int32)
+                               swap_memory: bool = False,
+                               version: str = 'v1'):
+        with tf.name_scope(f"{self.name}_perform_greedy_batch"):
+            total_batch = tf.shape(encoded)[0]
+            batch = tf.constant(0, dtype=tf.int32)

-        decoded = tf.TensorArray(
-            dtype=tf.int32, size=total_batch, dynamic_size=False,
-            clear_after_read=False, element_shape=None
-        )
-
-        def condition(batch, _): return tf.less(batch, total_batch)
+            greedy_fn = self._perform_greedy if version == 'v1' else self._perform_greedy_v2

-        def body(batch, decoded):
-            hypothesis = self._perform_greedy(
-                encoded=encoded[batch],
-                encoded_length=encoded_length[batch],
-                predicted=tf.constant(self.text_featurizer.blank, dtype=tf.int32),
-                states=self.predict_net.get_initial_state(),
-                parallel_iterations=parallel_iterations,
-                swap_memory=swap_memory
+            decoded = tf.TensorArray(
+                dtype=tf.int32, size=total_batch, dynamic_size=False,
+                clear_after_read=False, element_shape=tf.TensorShape([None])
             )
-            prediction = tf.pad(
-                hypothesis.prediction,
-                paddings=[[0, 2 * (total_time - encoded_length[batch])]],
-                mode="CONSTANT", constant_values=self.text_featurizer.blank
+
+            def condition(batch, _): return tf.less(batch, total_batch)
+
+            def body(batch, decoded):
+                hypothesis = greedy_fn(
+                    encoded=encoded[batch],
+                    encoded_length=encoded_length[batch],
+                    predicted=tf.constant(self.text_featurizer.blank, dtype=tf.int32),
+                    states=self.predict_net.get_initial_state(),
+                    parallel_iterations=parallel_iterations,
+                    swap_memory=swap_memory
+                )
+                decoded = decoded.write(batch, hypothesis.prediction)
+                return batch + 1, decoded
+
+            batch, decoded = tf.while_loop(
+                condition, body, loop_vars=[batch, decoded],
+                parallel_iterations=parallel_iterations, swap_memory=True,
             )
-            decoded = decoded.write(batch, prediction)
-            return batch + 1, decoded
-
-        batch, decoded = tf.while_loop(
-            condition, body,
-            loop_vars=[batch, decoded],
-            parallel_iterations=parallel_iterations,
-            swap_memory=True,
-        )

-        return self.text_featurizer.iextract(decoded.stack())
+            decoded = pad_prediction_tfarray(decoded, blank=self.text_featurizer.blank)
+            return self.text_featurizer.iextract(decoded.stack())

     def _perform_greedy(self,
                         encoded: tf.Tensor,
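For context on the change above: a TensorArray declared with `element_shape=tf.TensorShape([None])` accepts writes of different lengths, which is why the batch loop no longer pads each hypothesis to `2 * (total_time - encoded_length[batch])` inside `body` and instead pads everything once at the end. A small self-contained demonstration of that behavior (not taken from the repository) is:

import tensorflow as tf

@tf.function
def ragged_writes_demo():
    # Each element may have a different length because element_shape is [None];
    # stacking is only valid after the elements are padded to a common length,
    # which is what pad_prediction_tfarray does after the batch loop.
    ta = tf.TensorArray(dtype=tf.int32, size=2, dynamic_size=False,
                        clear_after_read=False, element_shape=tf.TensorShape([None]))
    ta = ta.write(0, tf.constant([3, 5, 7], dtype=tf.int32))
    ta = ta.write(1, tf.constant([2], dtype=tf.int32))
    return ta.read(0), ta.read(1)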
@@ -457,12 +455,12 @@ def _perform_greedy(self,
             states=states
         )

-        def condition(_time, _total, _encoded, _hypothesis): return tf.less(_time, _total)
+        def condition(_time, _hypothesis): return tf.less(_time, total)

-        def body(_time, _total, _encoded, _hypothesis):
+        def body(_time, _hypothesis):
             ytu, _states = self.decoder_inference(
                 # avoid using [index] in tflite
-                encoded=tf.gather_nd(_encoded, tf.reshape(_time, shape=[1])),
+                encoded=tf.gather_nd(encoded, tf.reshape(_time, shape=[1])),
                 predicted=_hypothesis.index,
                 states=_hypothesis.states
             )
@@ -480,13 +478,11 @@ def body(_time, _total, _encoded, _hypothesis):
             _prediction = _hypothesis.prediction.write(_time, _predict)
             _hypothesis = Hypothesis(index=_index, prediction=_prediction, states=_states)

-            return _time + 1, _total, _encoded, _hypothesis
+            return _time + 1, _hypothesis

-        _, _, _, hypothesis = tf.while_loop(
+        time, hypothesis = tf.while_loop(
             condition, body,
-            loop_vars=[time, total, encoded, hypothesis],
-            parallel_iterations=parallel_iterations,
-            swap_memory=swap_memory
+            loop_vars=[time, hypothesis], parallel_iterations=parallel_iterations, swap_memory=swap_memory
         )

         return Hypothesis(index=hypothesis.index, prediction=hypothesis.prediction.stack(), states=hypothesis.states)
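The rewritten loop above drops `_total` and `_encoded` from `loop_vars`: both are loop-invariant, so `condition` and `body` can simply close over `total` and `encoded` from the enclosing scope instead of shuttling unchanged tensors through every iteration. A minimal sketch of that pattern outside the model code (the function and names below are illustrative only, assuming float32 `encoded`):

import tensorflow as tf

def sum_frames_sketch(encoded: tf.Tensor, total: tf.Tensor) -> tf.Tensor:
    # `encoded` ([T, D], float32) and `total` (scalar int32) are captured by closure
    # rather than threaded through loop_vars, mirroring the refactor in the hunk above.
    def condition(time, _):
        return tf.less(time, total)

    def body(time, acc):
        frame = tf.gather_nd(encoded, tf.reshape(time, shape=[1]))  # same tflite-friendly indexing
        return time + 1, acc + tf.reduce_sum(frame)

    _, acc = tf.while_loop(condition, body,
                           loop_vars=[tf.constant(0, tf.int32), tf.constant(0.0)])
    return acc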
@@ -512,12 +508,12 @@ def _perform_greedy_v2(self,
             states=states
         )

-        def condition(_time, _total, _encoded, _hypothesis): return tf.less(_time, _total)
+        def condition(_time, _hypothesis): return tf.less(_time, total)

-        def body(_time, _total, _encoded, _hypothesis):
+        def body(_time, _hypothesis):
             ytu, _states = self.decoder_inference(
                 # avoid using [index] in tflite
-                encoded=tf.gather_nd(_encoded, tf.reshape(_time, shape=[1])),
+                encoded=tf.gather_nd(encoded, tf.reshape(_time, shape=[1])),
                 predicted=_hypothesis.index,
                 states=_hypothesis.states
             )
@@ -531,13 +527,11 @@ def body(_time, _total, _encoded, _hypothesis):
             _prediction = _hypothesis.prediction.write(_time, _predict)
             _hypothesis = Hypothesis(index=_index, prediction=_prediction, states=_states)

-            return _time, _total, _encoded, _hypothesis
+            return _time, _hypothesis

-        _, _, _, hypothesis = tf.while_loop(
+        time, hypothesis = tf.while_loop(
             condition, body,
-            loop_vars=[time, total, encoded, hypothesis],
-            parallel_iterations=parallel_iterations,
-            swap_memory=swap_memory
+            loop_vars=[time, hypothesis], parallel_iterations=parallel_iterations, swap_memory=swap_memory
         )

         return Hypothesis(index=hypothesis.index, prediction=hypothesis.prediction.stack(), states=hypothesis.states)
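Since `_perform_greedy_v2` is now reachable through the `version` argument added to `_perform_greedy_batch`, a call site could select it as in the hypothetical snippet below (`model`, `encoded`, and `encoded_length` are assumptions for illustration, not part of this diff):

# encoded: [B, T, D] encoder outputs, encoded_length: [B] valid frame counts (assumed to exist).
transcripts = model._perform_greedy_batch(
    encoded=encoded,
    encoded_length=encoded_length,
    version='v2',  # any value other than 'v1' dispatches to _perform_greedy_v2
)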
@@ -570,37 +564,32 @@ def _perform_beam_search_batch(self,
                                    lm: bool = False,
                                    parallel_iterations: int = 10,
                                    swap_memory: bool = False):
-        total_batch, total_time, _ = shape_list(encoded)
-        batch = tf.constant(0, dtype=tf.int32)
+        with tf.name_scope(f"{self.name}_perform_beam_search_batch"):
+            total_batch = tf.shape(encoded)[0]
+            batch = tf.constant(0, dtype=tf.int32)

-        decoded = tf.TensorArray(
-            dtype=tf.int32, size=total_batch, dynamic_size=False,
-            clear_after_read=False, element_shape=None
-        )
+            decoded = tf.TensorArray(
+                dtype=tf.int32, size=total_batch, dynamic_size=False,
+                clear_after_read=False, element_shape=None
+            )

-        def condition(batch, _): return tf.less(batch, total_batch)
+            def condition(batch, _): return tf.less(batch, total_batch)

-        def body(batch, decoded):
-            hypothesis = self._perform_beam_search(
-                encoded[batch], encoded_length[batch], lm,
-                parallel_iterations=parallel_iterations, swap_memory=swap_memory
-            )
-            prediction = tf.pad(
-                hypothesis.prediction,
-                paddings=[[0, 2 * (total_time - encoded_length[batch])]],
-                mode="CONSTANT", constant_values=self.text_featurizer.blank
+            def body(batch, decoded):
+                hypothesis = self._perform_beam_search(
+                    encoded[batch], encoded_length[batch], lm,
+                    parallel_iterations=parallel_iterations, swap_memory=swap_memory
+                )
+                decoded = decoded.write(batch, hypothesis.prediction)
+                return batch + 1, decoded
+
+            batch, decoded = tf.while_loop(
+                condition, body, loop_vars=[batch, decoded],
+                parallel_iterations=parallel_iterations, swap_memory=True,
             )
-            decoded = decoded.write(batch, prediction)
-            return batch + 1, decoded
-
-        batch, decoded = tf.while_loop(
-            condition, body,
-            loop_vars=[batch, decoded],
-            parallel_iterations=parallel_iterations,
-            swap_memory=True,
-        )

-        return self.text_featurizer.iextract(decoded.stack())
+            decoded = pad_prediction_tfarray(decoded, blank=self.text_featurizer.blank)
+            return self.text_featurizer.iextract(decoded.stack())

     def _perform_beam_search(self,
                              encoded: tf.Tensor,
@@ -640,7 +629,7 @@ def initialize_beam(dynamic=False):
         B = BeamHypothesis(
             score=B.score.write(0, 0.0),
             indices=B.indices.write(0, self.text_featurizer.blank),
-            prediction=B.prediction.write(0, tf.ones([total * 2], dtype=tf.int32) * self.text_featurizer.blank),
+            prediction=B.prediction.write(0, tf.ones([total], dtype=tf.int32) * self.text_featurizer.blank),
             states=B.states.write(0, self.predict_net.get_initial_state())
         )

@@ -651,7 +640,8 @@ def body(time, total, B):
             A = BeamHypothesis(
                 score=A.score.unstack(B.score.stack()),
                 indices=A.indices.unstack(B.indices.stack()),
-                prediction=A.prediction.unstack(B.prediction.stack()),
+                prediction=A.prediction.unstack(
+                    pad_prediction_tfarray(B.prediction, blank=self.text_featurizer.blank).stack()),
                 states=A.states.unstack(B.states.stack()),
             )
             A_i = tf.constant(0, tf.int32)
@@ -666,7 +656,8 @@ def beam_body(beam, beam_width, A, A_i, B):
                 y_hat_score, y_hat_score_index = tf.math.top_k(A.score.stack(), k=1, sorted=True)
                 y_hat_score = y_hat_score[0]
                 y_hat_index = tf.gather_nd(A.indices.stack(), y_hat_score_index)
-                y_hat_prediction = tf.gather_nd(A.prediction.stack(), y_hat_score_index)
+                y_hat_prediction = tf.gather_nd(
+                    pad_prediction_tfarray(A.prediction, blank=self.text_featurizer.blank).stack(), y_hat_score_index)
                 y_hat_states = tf.gather_nd(A.states.stack(), y_hat_score_index)

                 # remove y_hat from A
@@ -676,7 +667,8 @@ def beam_body(beam, beam_width, A, A_i, B):
                 A = BeamHypothesis(
                     score=A.score.unstack(tf.gather_nd(A.score.stack(), remain_indices)),
                     indices=A.indices.unstack(tf.gather_nd(A.indices.stack(), remain_indices)),
-                    prediction=A.prediction.unstack(tf.gather_nd(A.prediction.stack(), remain_indices)),
+                    prediction=A.prediction.unstack(tf.gather_nd(
+                        pad_prediction_tfarray(A.prediction, blank=self.text_featurizer.blank).stack(), remain_indices)),
                     states=A.states.unstack(tf.gather_nd(A.states.stack(), remain_indices)),
                 )
                 A_i = tf.cond(tf.equal(A_i, 0), true_fn=lambda: A_i, false_fn=lambda: A_i - 1)
@@ -752,14 +744,15 @@ def false_fn():
         )

         scores = B.score.stack()
+        prediction = pad_prediction_tfarray(B.prediction, blank=self.text_featurizer.blank).stack()
         if self.text_featurizer.decoder_config.norm_score:
-            prediction_lengths = count_non_blank(B.prediction.stack(), blank=self.text_featurizer.blank, axis=1)
+            prediction_lengths = count_non_blank(prediction, blank=self.text_featurizer.blank, axis=1)
             scores /= tf.cast(prediction_lengths, dtype=scores.dtype)

         y_hat_score, y_hat_score_index = tf.math.top_k(scores, k=1)
         y_hat_score = y_hat_score[0]
         y_hat_index = tf.gather_nd(B.indices.stack(), y_hat_score_index)
-        y_hat_prediction = tf.gather_nd(B.prediction.stack(), y_hat_score_index)
+        y_hat_prediction = tf.gather_nd(prediction, y_hat_score_index)
         y_hat_states = tf.gather_nd(B.states.stack(), y_hat_score_index)

         return Hypothesis(index=y_hat_index, prediction=y_hat_prediction, states=y_hat_states)
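`count_non_blank`, like `pad_prediction_tfarray`, comes from `..utils.utils` and is not shown in this diff. Given how it is called in the hunk above, a plausible stand-in with the same signature (an assumption, not the repository's implementation) would count the entries that differ from the blank index along the given axis:

import tensorflow as tf

def count_non_blank_sketch(tensor: tf.Tensor, blank: int = 0, axis=None) -> tf.Tensor:
    # Number of labels != blank, e.g. per beam when axis=1 on a [beam, length] matrix.
    return tf.reduce_sum(tf.cast(tf.not_equal(tensor, blank), tf.int32), axis=axis)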