1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15+ from typing import Optional
1516import numpy as np
1617import tensorflow as tf
1718
1819from . import Model
1920from ..featurizers .speech_featurizers import TFSpeechFeaturizer
2021from ..featurizers .text_featurizers import TextFeaturizer
21- from ..utils .utils import shape_list
22+ from ..utils .utils import shape_list , get_reduced_length
2223
2324
2425class CtcModel (Model ):
@@ -41,20 +42,15 @@ def call(self, inputs, training=False, **kwargs):
4142 # -------------------------------- GREEDY -------------------------------------
4243
4344 @tf .function
44- def recognize (self , signals ):
45-
46- def extract_fn (signal ): return self .speech_featurizer .tf_extract (signal )
47-
48- features = tf .map_fn (extract_fn , signals ,
49- fn_output_signature = tf .TensorSpec (self .speech_featurizer .shape , dtype = tf .float32 ))
45+ def recognize (self , features : tf .Tensor , input_length : Optional [tf .Tensor ]):
5046 logits = self (features , training = False )
5147 probs = tf .nn .softmax (logits )
5248
53- def map_fn (prob ): return tf .numpy_function (self .perform_greedy , inp = [prob ], Tout = tf .string )
49+ def map_fn (prob ): return tf .numpy_function (self .__perform_greedy , inp = [prob ], Tout = tf .string )
5450
5551 return tf .map_fn (map_fn , probs , fn_output_signature = tf .TensorSpec ([], dtype = tf .string ))
5652
57- def perform_greedy (self , probs : np .ndarray ):
53+ def __perform_greedy (self , probs : np .ndarray ):
5854 from ctc_decoders import ctc_greedy_decoder
5955 decoded = ctc_greedy_decoder (probs , vocabulary = self .text_featurizer .vocab_array )
6056 return tf .convert_to_tensor (decoded , dtype = tf .string )
@@ -71,7 +67,7 @@ def recognize_tflite(self, signal):
7167 features = self .speech_featurizer .tf_extract (signal )
7268 features = tf .expand_dims (features , axis = 0 )
7369 input_length = shape_list (features )[1 ]
74- input_length = input_length // self .base_model .time_reduction_factor
70+ input_length = get_reduced_length ( input_length , self .base_model .time_reduction_factor )
7571 input_length = tf .expand_dims (input_length , axis = 0 )
7672 logits = self (features , training = False )
7773 probs = tf .nn .softmax (logits )
@@ -85,25 +81,20 @@ def recognize_tflite(self, signal):
8581 # -------------------------------- BEAM SEARCH -------------------------------------
8682
8783 @tf .function
88- def recognize_beam (self , signals , lm = False ):
89-
90- def extract_fn (signal ): return self .speech_featurizer .tf_extract (signal )
91-
92- features = tf .map_fn (extract_fn , signals ,
93- fn_output_signature = tf .TensorSpec (self .speech_featurizer .shape , dtype = tf .float32 ))
84+ def recognize_beam (self , features : tf .Tensor , input_length : Optional [tf .Tensor ], lm : bool = False ):
9485 logits = self (features , training = False )
9586 probs = tf .nn .softmax (logits )
9687
97- def map_fn (prob ): return tf .numpy_function (self .perform_beam_search , inp = [prob , lm ], Tout = tf .string )
88+ def map_fn (prob ): return tf .numpy_function (self .__perform_beam_search , inp = [prob , lm ], Tout = tf .string )
9889
9990 return tf .map_fn (map_fn , probs , dtype = tf .string )
10091
101- def perform_beam_search (self , probs : np .ndarray , lm : bool = False ):
92+ def __perform_beam_search (self , probs : np .ndarray , lm : bool = False ):
10293 from ctc_decoders import ctc_beam_search_decoder
10394 decoded = ctc_beam_search_decoder (
10495 probs_seq = probs ,
10596 vocabulary = self .text_featurizer .vocab_array ,
106- beam_size = self .text_featurizer .decoder_config [ " beam_width" ] ,
97+ beam_size = self .text_featurizer .decoder_config . beam_width ,
10798 ext_scoring_func = self .text_featurizer .scorer if lm else None
10899 )
109100 decoded = decoded [0 ][- 1 ]
@@ -122,13 +113,13 @@ def recognize_beam_tflite(self, signal):
122113 features = self .speech_featurizer .tf_extract (signal )
123114 features = tf .expand_dims (features , axis = 0 )
124115 input_length = shape_list (features )[1 ]
125- input_length = input_length // self .base_model .time_reduction_factor
116+ input_length = get_reduced_length ( input_length , self .base_model .time_reduction_factor )
126117 input_length = tf .expand_dims (input_length , axis = 0 )
127118 logits = self (features , training = False )
128119 probs = tf .nn .softmax (logits )
129120 decoded = tf .keras .backend .ctc_decode (
130121 y_pred = probs , input_length = input_length , greedy = False ,
131- beam_width = self .text_featurizer .decoder_config [ " beam_width" ]
122+ beam_width = self .text_featurizer .decoder_config . beam_width
132123 )
133124 decoded = tf .cast (decoded [0 ][0 ][0 ], dtype = tf .int32 )
134125 transcript = self .text_featurizer .indices2upoints (decoded )
0 commit comments