from ..featurizers.text_featurizers import TextFeaturizer
from .layers.embedding import Embedding

-Hypothesis = collections.namedtuple(
-    "Hypothesis",
-    ("index", "prediction", "states")
-)
+Hypothesis = collections.namedtuple("Hypothesis", ("index", "prediction", "states"))

-BeamHypothesis = collections.namedtuple(
-    "BeamHypothesis",
-    ("score", "indices", "prediction", "states")
-)
+BeamHypothesis = collections.namedtuple("BeamHypothesis", ("score", "indices", "prediction", "states"))


class TransducerPrediction(tf.keras.Model):
@@ -233,6 +227,7 @@ def __init__(self,
            bias_regularizer=bias_regularizer,
            name=f"{name}_joint"
        )
+        self.time_reduction_factor = 1

    def _build(self, input_shape):
        inputs = tf.keras.Input(shape=input_shape, dtype=tf.float32)
@@ -369,6 +364,29 @@ def recognize_tflite(self, signal, predicted, states):
            hypothesis.states
        )

+    def recognize_tflite_with_timestamp(self, signal, predicted, states):
+        features = self.speech_featurizer.tf_extract(signal)
+        encoded = self.encoder_inference(features)
+        hypothesis = self._perform_greedy(encoded, tf.shape(encoded)[0], predicted, states)
+        indices = self.text_featurizer.normalize_indices(hypothesis.prediction)
+        upoints = tf.gather_nd(self.text_featurizer.upoints, tf.expand_dims(indices, axis=-1))  # [None, max_subword_length]
+
+        num_samples = tf.cast(tf.shape(signal)[0], dtype=tf.float32)
+        total_time_reduction_factor = self.time_reduction_factor * self.speech_featurizer.frame_step
+
+        stime = tf.range(0, num_samples, delta=total_time_reduction_factor, dtype=tf.float32)
+        stime /= tf.cast(self.speech_featurizer.sample_rate, dtype=tf.float32)
+
+        etime = tf.range(total_time_reduction_factor, num_samples, delta=total_time_reduction_factor, dtype=tf.float32)
+        etime /= tf.cast(self.speech_featurizer.sample_rate, dtype=tf.float32)
+
+        non_blank = tf.where(tf.not_equal(upoints, 0))
+        non_blank_transcript = tf.gather_nd(upoints, non_blank)
+        non_blank_stime = tf.gather_nd(tf.repeat(tf.expand_dims(stime, axis=-1), tf.shape(upoints)[-1], axis=-1), non_blank)
+        non_blank_etime = tf.gather_nd(tf.repeat(tf.expand_dims(etime, axis=-1), tf.shape(upoints)[-1], axis=-1), non_blank)
+
+        return non_blank_transcript, non_blank_stime, non_blank_etime, hypothesis.prediction, hypothesis.states
+
    def _perform_greedy_batch(self,
                              encoded: tf.Tensor,
                              encoded_length: tf.Tensor,
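Side note on the hunk above (not part of the patch): `recognize_tflite_with_timestamp` assigns each decoded frame a start and end time by cutting the raw signal into windows of `time_reduction_factor * frame_step` samples and dividing by the sample rate. A minimal NumPy sketch of that time grid, using illustrative values (16 kHz audio, hop of 160 samples, 4x encoder reduction; all assumptions, the real values come from the featurizer and encoder):

```python
import numpy as np

# Assumed illustrative values; the model reads these from its featurizer/encoder.
sample_rate = 16000            # samples per second
frame_step = 160               # featurizer hop size in samples
time_reduction_factor = 4      # encoder downsampling factor
num_samples = 16000            # a 1-second signal

# One decoded frame covers this many raw samples.
total_reduction = time_reduction_factor * frame_step

# Start/end boundaries (in seconds) of each decoded frame, as in the patch.
stime = np.arange(0, num_samples, total_reduction) / sample_rate
etime = np.arange(total_reduction, num_samples, total_reduction) / sample_rate

print(stime[:3], etime[:3])    # first frames span 0.00-0.04 s, 0.04-0.08 s, 0.08-0.12 s
```

Tokens whose Unicode code points are non-zero are then gathered together with the boundaries of the frame that emitted them.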
@@ -400,7 +418,7 @@ def body(batch, total, encoded, encoded_length, decoded):

        batch, total, _, _, decoded = tf.while_loop(
            condition, body,
-            loop_vars=(batch, total, encoded, encoded_length, decoded),
+            loop_vars=[batch, total, encoded, encoded_length, decoded],
            parallel_iterations=parallel_iterations,
            swap_memory=True,
        )
@@ -419,45 +437,43 @@ def _perform_greedy(self,
        total = encoded_length

        hypothesis = Hypothesis(
-            index=tf.constant(0, dtype=tf.int32),
-            prediction=tf.ones([total + 1], dtype=tf.int32) * self.text_featurizer.blank,
+            index=tf.constant(self.text_featurizer.blank, dtype=tf.int32),
+            prediction=tf.ones([total], dtype=tf.int32) * self.text_featurizer.blank,
            states=states
        )

        def condition(time, total, encoded, hypothesis): return tf.less(time, total)

        def body(time, total, encoded, hypothesis):
-            predicted = tf.gather_nd(hypothesis.prediction, tf.expand_dims(hypothesis.index, axis=-1))
-
-            ytu, new_states = self.decoder_inference(
+            ytu, states = self.decoder_inference(
                # avoid using [index] in tflite
                encoded=tf.gather_nd(encoded, tf.expand_dims(time, axis=-1)),
-                predicted=predicted,
+                predicted=hypothesis.index,
                states=hypothesis.states
            )
-            new_predicted = tf.argmax(ytu, axis=-1, output_type=tf.int32)  # => argmax []
+            predict = tf.argmax(ytu, axis=-1, output_type=tf.int32)  # => argmax []

-            index, new_predicted, new_states = tf.cond(
-                tf.equal(new_predicted, self.text_featurizer.blank),
-                true_fn=lambda: (hypothesis.index, predicted, hypothesis.states),
-                false_fn=lambda: (hypothesis.index + 1, new_predicted, new_states)
+            index, predict, states = tf.cond(
+                tf.equal(predict, self.text_featurizer.blank),
+                true_fn=lambda: (hypothesis.index, predict, hypothesis.states),
+                false_fn=lambda: (predict, predict, states)  # update if the new prediction is a non-blank
            )

            hypothesis = Hypothesis(
                index=index,
                prediction=tf.tensor_scatter_nd_update(
                    hypothesis.prediction,
-                    indices=tf.reshape(index, [1, 1]),
-                    updates=tf.expand_dims(new_predicted, axis=-1)
+                    indices=tf.reshape(time, [1, 1]),
+                    updates=tf.expand_dims(predict, axis=-1)
                ),
-                states=new_states
+                states=states
            )

            return time + 1, total, encoded, hypothesis

        time, total, encoded, hypothesis = tf.while_loop(
            condition, body,
-            loop_vars=(time, total, encoded, hypothesis),
+            loop_vars=[time, total, encoded, hypothesis],
            parallel_iterations=parallel_iterations,
            swap_memory=swap_memory
        )
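Reading of the greedy hunk above (my interpretation, not part of the patch): `hypothesis.index` now carries the last emitted non-blank token id (initialised to the blank id) instead of a write cursor, the prediction buffer is sized `[total]` rather than `[total + 1]`, each frame writes its argmax at position `time`, and the decoder states advance only on non-blank emissions. A runnable toy sketch of that per-frame rule, with `score_frame` as a stub standing in for `decoder_inference` plus the joint network (not the library API):

```python
import numpy as np

BLANK = 0

def score_frame(enc_frame, last_token, states):
    # Stub scorer: pretend the joint network favours token (frame index % 3).
    logits = np.zeros(5)
    logits[enc_frame % 3] = 1.0
    return logits, states

def greedy_decode(num_frames, states=None):
    prediction = np.full(num_frames, BLANK, dtype=np.int32)
    last_token = BLANK                   # plays the role of hypothesis.index in the patch
    for t in range(num_frames):
        logits, new_states = score_frame(t, last_token, states)
        tok = int(np.argmax(logits))
        if tok != BLANK:                 # only a non-blank updates the fed-back token and states
            last_token, states = tok, new_states
        prediction[t] = tok              # written at position t, blank or not
    return prediction

print(greedy_decode(6))                  # -> [0 1 2 0 1 2]
```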
@@ -512,7 +528,7 @@ def body(batch, total, encoded, encoded_length, decoded):

        batch, total, _, _, decoded = tf.while_loop(
            condition, body,
-            loop_vars=(batch, total, encoded, encoded_length, decoded),
+            loop_vars=[batch, total, encoded, encoded_length, decoded],
            parallel_iterations=parallel_iterations,
            swap_memory=True,
        )
@@ -626,23 +642,23 @@ def predict_body(pred, A, A_i, B):

                _, A, A_i, B = tf.while_loop(
                    predict_condition, predict_body,
-                    loop_vars=(0, A, A_i, B),
+                    loop_vars=[0, A, A_i, B],
                    parallel_iterations=parallel_iterations, swap_memory=swap_memory
                )

                return beam + 1, beam_width, A, A_i, B

            _, _, A, A_i, B = tf.while_loop(
                beam_condition, beam_body,
-                loop_vars=(0, beam_width, A, A_i, B),
+                loop_vars=[0, beam_width, A, A_i, B],
                parallel_iterations=parallel_iterations, swap_memory=swap_memory
            )

            return time + 1, total, B

        _, _, B = tf.while_loop(
            condition, body,
-            loop_vars=(0, total, B),
+            loop_vars=[0, total, B],
            parallel_iterations=parallel_iterations, swap_memory=swap_memory
        )

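The repeated `loop_vars=(...)` to `loop_vars=[...]` changes in this diff only swap the container from a tuple to a list; `tf.while_loop` accepts either, and the rest of each call is unchanged. For reference, a minimal self-contained call in the same style (illustrative values only):

```python
import tensorflow as tf

def cond(i, acc):
    return tf.less(i, 5)

def body(i, acc):
    return i + 1, acc + i

# loop_vars passed as a list, matching the pattern adopted in the patch
i, acc = tf.while_loop(
    cond, body,
    loop_vars=[tf.constant(0), tf.constant(0)],
    parallel_iterations=10,
    swap_memory=True,
)
print(int(acc))   # 10
```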
@@ -665,9 +681,10 @@ def predict_body(pred, A, A_i, B):

    # -------------------------------- TFLITE -------------------------------------

-    def make_tflite_function(self, greedy: bool = True):
+    def make_tflite_function(self, timestamp: bool = False):
+        tflite_func = self.recognize_tflite_with_timestamp if timestamp else self.recognize_tflite
        return tf.function(
-            self.recognize_tflite,
+            tflite_func,
            input_signature=[
                tf.TensorSpec([None], dtype=tf.float32),
                tf.TensorSpec([], dtype=tf.int32),
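The diff is cut off here, but the new `timestamp` flag simply selects which recognizer gets wrapped in the exported `tf.function`. For orientation, a hedged sketch of how such a function is typically converted to TFLite (assumes a built `model` instance exposing the method above; names and converter options are illustrative, not taken from the diff):

```python
import tensorflow as tf

# `model` is assumed to be a built Transducer instance with the method above.
concrete_func = model.make_tflite_function(timestamp=True).get_concrete_function()

converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,   # regular TFLite kernels
    tf.lite.OpsSet.SELECT_TF_OPS,     # fall back to TF ops that cannot be lowered
]
tflite_model = converter.convert()

with open("model.tflite", "wb") as f:
    f.write(tflite_model)
```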