1717
1818from .layers .subsampling import TimeReduction
1919from .transducer import Transducer
20- from ..utils .utils import get_rnn , merge_two_last_dims
20+ from ..utils .utils import get_rnn , merge_two_last_dims , shape_list
2121
2222
2323class Reshape (tf .keras .layers .Layer ):
@@ -127,7 +127,7 @@ def __init__(self,
127127 reduction_factor = reductions .get (i , 0 )
128128 if reduction_factor > 0 : self .time_reduction_factor *= reduction_factor
129129
130- def get_initial_state (self ):
130+ def get_initial_state (self , batch_size = 1 ):
131131 """Get zeros states
132132
133133 Returns:
@@ -138,7 +138,7 @@ def get_initial_state(self):
138138 states .append (
139139 tf .stack (
140140 block .rnn .get_initial_state (
141- tf .zeros ([1 , 1 , 1 ], dtype = tf .float32 )
141+ tf .zeros ([batch_size , 1 , 1 ], dtype = tf .float32 )
142142 ), axis = 0
143143 )
144144 )
@@ -269,7 +269,8 @@ def recognize(self,
269269 Returns:
270270 tf.Tensor: a batch of decoded transcripts
271271 """
272- encoded , _ = self .encoder .recognize (features , self .encoder .get_initial_state ())
272+ batch_size , _ , _ , _ = shape_list (features )
273+ encoded , _ = self .encoder .recognize (features , self .encoder .get_initial_state (batch_size ))
273274 return self ._perform_greedy_batch (encoded , input_length ,
274275 parallel_iterations = parallel_iterations , swap_memory = swap_memory )
275276
@@ -335,7 +336,8 @@ def recognize_beam(self,
335336 Returns:
336337 tf.Tensor: a batch of decoded transcripts
337338 """
338- encoded , _ = self .encoder .recognize (features , self .encoder .get_initial_state ())
339+ batch_size , _ , _ , _ = shape_list (features )
340+ encoded , _ = self .encoder .recognize (features , self .encoder .get_initial_state (batch_size ))
339341 return self ._perform_beam_search_batch (encoded , input_length , lm ,
340342 parallel_iterations = parallel_iterations , swap_memory = swap_memory )
341343
0 commit comments