1414""" https://arxiv.org/pdf/1811.06621.pdf """
1515
1616import collections
17- from typing import Optional
1817import tensorflow as tf
1918
2019from . import Model
@@ -285,22 +284,16 @@ def call(self, inputs, training=False, **kwargs):
285284 outputs = self .joint_net ([enc , pred ], training = training , ** kwargs )
286285 return outputs
287286
288- def encoder_inference (self ,
289- features : tf .Tensor ,
290- input_length : Optional [tf .Tensor ] = None ,
291- with_batch : Optional [bool ] = False ):
287+ def encoder_inference (self , features : tf .Tensor ):
292288 """Infer function for encoder (or encoders)
293289
294290 Args:
295291 features (tf.Tensor): features with shape [T, F, C]
296- input_length (tf.Tensor): optional features length with shape []
297- with_batch (bool): indicates whether the features included batch dim or not
298292
299293 Returns:
300294 tf.Tensor: output of encoders with shape [T, E]
301295 """
302296 with tf .name_scope (f"{ self .name } _encoder" ):
303- if with_batch : return self .encoder (features , training = False )
304297 outputs = tf .expand_dims (features , axis = 0 )
305298 outputs = self .encoder (outputs , training = False )
306299 return tf .squeeze (outputs , axis = 0 )
@@ -321,7 +314,7 @@ def decoder_inference(self, encoded: tf.Tensor, predicted: tf.Tensor, states: tf
321314 predicted = tf .reshape (predicted , [1 , 1 ]) # [] => [1, 1]
322315 y , new_states = self .predict_net .recognize (predicted , states ) # [1, 1, P], states
323316 ytu = tf .nn .log_softmax (self .joint_net ([encoded , y ], training = False )) # [1, 1, V]
324- ytu = tf .squeeze (ytu , axis = None ) # [1, 1, V] => [V]
317+ ytu = tf .reshape (ytu , shape = [ - 1 ] ) # [1, 1, V] => [V]
325318 return ytu , new_states
326319
327320 def get_config (self ):
@@ -347,7 +340,7 @@ def recognize(self,
347340 Returns:
348341 tf.Tensor: a batch of decoded transcripts
349342 """
350- encoded = self .encoder_inference (features , input_length , with_batch = True )
343+ encoded = self .encoder (features , training = True )
351344 return self ._perform_greedy_batch (encoded , input_length ,
352345 parallel_iterations = parallel_iterations , swap_memory = swap_memory )
353346
@@ -368,11 +361,7 @@ def recognize_tflite(self, signal, predicted, states):
368361 encoded = self .encoder_inference (features )
369362 hypothesis = self ._perform_greedy (encoded , tf .shape (encoded )[0 ], predicted , states )
370363 transcript = self .text_featurizer .indices2upoints (hypothesis .prediction )
371- return (
372- transcript ,
373- hypothesis .prediction [- 1 ],
374- hypothesis .states
375- )
364+ return transcript , hypothesis .index , hypothesis .states
376365
377366 def recognize_tflite_with_timestamp (self , signal , predicted , states ):
378367 features = self .speech_featurizer .tf_extract (signal )
@@ -395,7 +384,7 @@ def recognize_tflite_with_timestamp(self, signal, predicted, states):
395384 non_blank_stime = tf .gather_nd (tf .repeat (tf .expand_dims (stime , axis = - 1 ), tf .shape (upoints )[- 1 ], axis = - 1 ), non_blank )
396385 non_blank_etime = tf .gather_nd (tf .repeat (tf .expand_dims (etime , axis = - 1 ), tf .shape (upoints )[- 1 ], axis = - 1 ), non_blank )
397386
398- return non_blank_transcript , non_blank_stime , non_blank_etime , hypothesis .prediction , hypothesis .states
387+ return non_blank_transcript , non_blank_stime , non_blank_etime , hypothesis .index , hypothesis .states
399388
400389 def _perform_greedy_batch (self ,
401390 encoded : tf .Tensor ,
@@ -450,48 +439,47 @@ def _perform_greedy(self,
450439 total = encoded_length
451440
452441 hypothesis = Hypothesis (
453- index = tf .constant (self .text_featurizer .blank , dtype = tf .int32 ),
454- prediction = tf .ones ([total ], dtype = tf .int32 ) * self .text_featurizer .blank ,
442+ index = predicted ,
443+ prediction = tf .TensorArray (
444+ dtype = tf .int32 , size = total , dynamic_size = False ,
445+ clear_after_read = False , element_shape = tf .TensorShape ([])
446+ ),
455447 states = states
456448 )
457449
458- def condition (time , total , encoded , hypothesis ): return tf .less (time , total )
450+ def condition (_time , _total , _encoded , _hypothesis ): return tf .less (_time , _total )
459451
460- def body (time , total , encoded , hypothesis ):
461- ytu , states = self .decoder_inference (
452+ def body (_time , _total , _encoded , _hypothesis ):
453+ ytu , _states = self .decoder_inference (
462454 # avoid using [index] in tflite
463- encoded = tf .gather_nd (encoded , tf .expand_dims ( time , axis = - 1 )),
464- predicted = hypothesis .index ,
465- states = hypothesis .states
455+ encoded = tf .gather_nd (_encoded , tf .reshape ( _time , shape = [ 1 ] )),
456+ predicted = _hypothesis .index ,
457+ states = _hypothesis .states
466458 )
467- predict = tf .argmax (ytu , axis = - 1 , output_type = tf .int32 ) # => argmax []
459+ _predict = tf .argmax (ytu , axis = - 1 , output_type = tf .int32 ) # => argmax []
468460
469- index , predict , states = tf .cond (
470- tf .equal (predict , self .text_featurizer .blank ),
471- true_fn = lambda : (hypothesis .index , predict , hypothesis .states ),
472- false_fn = lambda : (predict , predict , states ) # update if the new prediction is a non-blank
473- )
461+ # something is wrong with tflite that drop support for tf.cond
462+ # def equal_blank_fn(): return _hypothesis.index, _hypothesis.states
463+ # def non_equal_blank_fn(): return _predict, _states # update if the new prediction is a non-blank
464+ # _index, _states = tf.cond(tf.equal(_predict, blank), equal_blank_fn, non_equal_blank_fn)
474465
475- hypothesis = Hypothesis (
476- index = index ,
477- prediction = tf .tensor_scatter_nd_update (
478- hypothesis .prediction ,
479- indices = tf .reshape (time , [1 , 1 ]),
480- updates = tf .expand_dims (predict , axis = - 1 )
481- ),
482- states = states
483- )
466+ _equal = tf .equal (_predict , self .text_featurizer .blank )
467+ _index = tf .where (_equal , _hypothesis .index , _predict )
468+ _states = tf .where (_equal , _hypothesis .states , _states )
469+
470+ _prediction = _hypothesis .prediction .write (_time , _predict )
471+ _hypothesis = Hypothesis (index = _index , prediction = _prediction , states = _states )
484472
485- return time + 1 , total , encoded , hypothesis
473+ return _time + 1 , _total , _encoded , _hypothesis
486474
487- time , total , encoded , hypothesis = tf .while_loop (
475+ _ , _ , _ , hypothesis = tf .while_loop (
488476 condition , body ,
489477 loop_vars = [time , total , encoded , hypothesis ],
490478 parallel_iterations = parallel_iterations ,
491479 swap_memory = swap_memory
492480 )
493481
494- return hypothesis
482+ return Hypothesis ( index = hypothesis . index , prediction = hypothesis . prediction . stack (), states = hypothesis . states )
495483
496484 # -------------------------------- BEAM SEARCH -------------------------------------
497485
@@ -511,7 +499,7 @@ def recognize_beam(self,
511499 Returns:
512500 tf.Tensor: a batch of decoded transcripts
513501 """
514- encoded = self .encoder_inference (features , input_length , with_batch = True )
502+ encoded = self .encoder (features , training = True )
515503 return self ._perform_beam_search_batch (encoded , input_length , lm ,
516504 parallel_iterations = parallel_iterations , swap_memory = swap_memory )
517505
0 commit comments