Skip to content

Commit 1372623

Browse files
committed
🚀 Add streaming recurrent neural network transducer
1 parent 6729e44 commit 1372623

File tree

3 files changed

+364
-36
lines changed

3 files changed

+364
-36
lines changed

tensorflow_asr/models/layers/subsampling.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,24 @@
1717
from ...utils.utils import merge_two_last_dims
1818

1919

20+
class TimeReduction(tf.keras.layers.Layer):
    """Reduce the time axis by stacking ``factor`` consecutive frames.

    The time axis (axis 1) is zero-padded at the end up to the next multiple
    of ``factor``; each group of ``factor`` frames is then concatenated along
    the feature axis, so the output has ``ceil(T / factor)`` time steps and
    ``factor`` times as many features.

    Args:
        factor: Number of consecutive frames folded into one output frame.
        name: Keras layer name.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, factor: int, name: str = "TimeReduction", **kwargs):
        super(TimeReduction, self).__init__(name=name, **kwargs)
        self.factor = factor

    def call(self, inputs, **kwargs):
        """Pad the time axis and fold ``factor`` frames into the feature axis.

        Args:
            inputs: Tensor handed to ``merge_two_last_dims``.
                NOTE(review): presumably ``[batch, time, feat, channels]``;
                confirm against the helper in ``utils``.

        Returns:
            Tensor of shape ``[batch, ceil(time / factor), features * factor]``.
        """
        # Ref: https://github.com/noahchalifour/rnnt-speech-recognition/blob/master/model.py
        outputs = merge_two_last_dims(inputs)
        shape = tf.shape(outputs)
        # Frames missing to reach the next multiple of `factor`. floormod takes
        # the sign of the divisor, so floormod(-T, factor) lies in [0, factor)
        # and is 0 when T is already divisible. (floormod(T, factor) is the
        # remainder, which only equals the required padding when factor == 2.)
        pad_len = tf.math.floormod(-shape[1], self.factor)
        # Pad the merged rank-3 tensor (not the raw inputs) so that the
        # three-axis padding spec and the reshape below agree with `shape`.
        outputs = tf.pad(outputs, [[0, 0], [0, pad_len], [0, 0]])
        return tf.reshape(outputs, [shape[0], -1, shape[-1] * self.factor])

    def get_config(self):
        """Return the layer config including ``factor`` for serialization."""
        config = super(TimeReduction, self).get_config()
        config.update({"factor": self.factor})
        return config
36+
37+
2038
class VggSubsampling(tf.keras.layers.Layer):
2139
def __init__(self,
2240
filters: tuple or list = (32, 64),

tensorflow_asr/models/layers/time_reduction.py

Lines changed: 0 additions & 36 deletions
This file was deleted.

0 commit comments

Comments
 (0)