@@ -76,8 +76,8 @@ def call(self, inputs, training=False, **kwargs):
7676
7777 @tf .function
7878 def recognize (self , inputs : Dict [str , tf .Tensor ]):
79- logits = self (inputs [ "inputs" ] , training = False )
80- probs = tf .nn .softmax (logits )
79+ logits = self (inputs , training = False )
80+ probs = tf .nn .softmax (logits [ "logits" ] )
8181
8282 def map_fn (prob ): return tf .numpy_function (self ._perform_greedy , inp = [prob ], Tout = tf .string )
8383
@@ -102,7 +102,8 @@ def recognize_tflite(self, signal):
102102 input_length = shape_util .shape_list (features )[1 ]
103103 input_length = math_util .get_reduced_length (input_length , self .time_reduction_factor )
104104 input_length = tf .expand_dims (input_length , axis = 0 )
105- logits = self (features , training = False )
105+ logits = self .encoder (features , training = False )
106+ logits = self .decoder (logits , training = False )
106107 probs = tf .nn .softmax (logits )
107108 decoded = tf .keras .backend .ctc_decode (
108109 y_pred = probs , input_length = input_length , greedy = True
@@ -115,8 +116,8 @@ def recognize_tflite(self, signal):
115116
116117 @tf .function
117118 def recognize_beam (self , inputs : Dict [str , tf .Tensor ], lm : bool = False ):
118- logits = self (inputs [ "inputs" ] , training = False )
119- probs = tf .nn .softmax (logits )
119+ logits = self (inputs , training = False )
120+ probs = tf .nn .softmax (logits [ "logits" ] )
120121
121122 def map_fn (prob ): return tf .numpy_function (self ._perform_beam_search , inp = [prob , lm ], Tout = tf .string )
122123
@@ -148,7 +149,8 @@ def recognize_beam_tflite(self, signal):
148149 input_length = shape_util .shape_list (features )[1 ]
149150 input_length = math_util .get_reduced_length (input_length , self .time_reduction_factor )
150151 input_length = tf .expand_dims (input_length , axis = 0 )
151- logits = self (features , training = False )
152+ logits = self .encoder (features , training = False )
153+ logits = self .decoder (logits , training = False )
152154 probs = tf .nn .softmax (logits )
153155 decoded = tf .keras .backend .ctc_decode (
154156 y_pred = probs , input_length = input_length , greedy = False ,
0 commit comments