Skip to content

Commit 8769192

Browse files
committed
✍️ update streaming transducer encoder recognize
1 parent 9a67d87 commit 8769192

File tree

2 files changed

+12
-7
lines changed

2 files changed

+12
-7
lines changed

tensorflow_asr/models/keras/streaming_transducer.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
from .transducer import Transducer
1919
from ..streaming_transducer import StreamingTransducerEncoder
20+
from ...utils.utils import shape_list
2021

2122

2223
class StreamingTransducer(Transducer):
@@ -113,7 +114,8 @@ def recognize(self,
113114
Returns:
114115
tf.Tensor: a batch of decoded transcripts
115116
"""
116-
encoded, _ = self.encoder.recognize(features, self.encoder.get_initial_state())
117+
batch_size, _, _, _ = shape_list(features)
118+
encoded, _ = self.encoder.recognize(features, self.encoder.get_initial_state(batch_size))
117119
return self._perform_greedy_batch(encoded, input_length,
118120
parallel_iterations=parallel_iterations, swap_memory=swap_memory)
119121

@@ -179,7 +181,8 @@ def recognize_beam(self,
179181
Returns:
180182
tf.Tensor: a batch of decoded transcripts
181183
"""
182-
encoded, _ = self.encoder.recognize(features, self.encoder.get_initial_state())
184+
batch_size, _, _, _ = shape_list(features)
185+
encoded, _ = self.encoder.recognize(features, self.encoder.get_initial_state(batch_size))
183186
return self._perform_beam_search_batch(encoded, input_length, lm,
184187
parallel_iterations=parallel_iterations, swap_memory=swap_memory)
185188

tensorflow_asr/models/streaming_transducer.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
from .layers.subsampling import TimeReduction
1919
from .transducer import Transducer
20-
from ..utils.utils import get_rnn, merge_two_last_dims
20+
from ..utils.utils import get_rnn, merge_two_last_dims, shape_list
2121

2222

2323
class Reshape(tf.keras.layers.Layer):
@@ -127,7 +127,7 @@ def __init__(self,
127127
reduction_factor = reductions.get(i, 0)
128128
if reduction_factor > 0: self.time_reduction_factor *= reduction_factor
129129

130-
def get_initial_state(self):
130+
def get_initial_state(self, batch_size=1):
131131
"""Get zeros states
132132
133133
Returns:
@@ -138,7 +138,7 @@ def get_initial_state(self):
138138
states.append(
139139
tf.stack(
140140
block.rnn.get_initial_state(
141-
tf.zeros([1, 1, 1], dtype=tf.float32)
141+
tf.zeros([batch_size, 1, 1], dtype=tf.float32)
142142
), axis=0
143143
)
144144
)
@@ -269,7 +269,8 @@ def recognize(self,
269269
Returns:
270270
tf.Tensor: a batch of decoded transcripts
271271
"""
272-
encoded, _ = self.encoder.recognize(features, self.encoder.get_initial_state())
272+
batch_size, _, _, _ = shape_list(features)
273+
encoded, _ = self.encoder.recognize(features, self.encoder.get_initial_state(batch_size))
273274
return self._perform_greedy_batch(encoded, input_length,
274275
parallel_iterations=parallel_iterations, swap_memory=swap_memory)
275276

@@ -335,7 +336,8 @@ def recognize_beam(self,
335336
Returns:
336337
tf.Tensor: a batch of decoded transcripts
337338
"""
338-
encoded, _ = self.encoder.recognize(features, self.encoder.get_initial_state())
339+
batch_size, _, _, _ = shape_list(features)
340+
encoded, _ = self.encoder.recognize(features, self.encoder.get_initial_state(batch_size))
339341
return self._perform_beam_search_batch(encoded, input_length, lm,
340342
parallel_iterations=parallel_iterations, swap_memory=swap_memory)
341343

0 commit comments

Comments
 (0)