Skip to content

Commit ade7891

Browse files
committed
Merge branch 'dev/rnnt' into main
2 parents 72cd5d2 + 6c107e3 commit ade7891

File tree

11 files changed

+893
-93
lines changed

11 files changed

+893
-93
lines changed

examples/conformer/train_ga_subword_conformer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@
8282
config["decoder_config"],
8383
corpus_files=args.subwords_corpus
8484
)
85-
text_featurizer.subwords.save_to_file(args.subwords_prefix)
85+
text_featurizer.save_to_file(args.subwords)
8686

8787
if args.tfrecords:
8888
train_dataset = ASRTFRecordDataset(

examples/conformer/train_subword_conformer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@
8282
config["decoder_config"],
8383
corpus_files=args.subwords_corpus
8484
)
85-
text_featurizer.subwords.save_to_file(args.subwords_prefix)
85+
text_featurizer.save_to_file(args.subwords)
8686

8787
if args.tfrecords:
8888
train_dataset = ASRTFRecordDataset(

examples/streaming_transducer/config.yml

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -33,19 +33,21 @@ decoder_config:
3333

3434
model_config:
3535
name: streaming_transducer
36-
reduction_factor: 2
37-
reduction_positions: [1]
38-
encoder_dim: 320
39-
encoder_units: 1024
40-
encoder_layers: 8
36+
encoder_reductions:
37+
0: 3
38+
1: 2
39+
encoder_dmodel: 320
40+
encoder_rnn_type: lstm
41+
encoder_rnn_units: 1024
42+
encoder_nlayers: 8
4143
encoder_layer_norm: True
42-
encoder_type: lstm
43-
embed_dim: 320
44-
embed_dropout: 0.1
45-
num_rnns: 1
46-
rnn_units: 320
47-
rnn_type: lstm
48-
layer_norm: True
44+
prediction_embed_dim: 320
45+
prediction_embed_dropout: 0.0
46+
prediction_num_rnns: 2
47+
prediction_rnn_units: 1024
48+
prediction_rnn_type: lstm
49+
prediction_projection_units: 320
50+
prediction_layer_norm: True
4951
joint_dim: 320
5052

5153
learning_config:

examples/streaming_transducer/train_ga_subword_streaming_transducer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@
8080
config["decoder_config"],
8181
corpus_files=args.subwords_corpus
8282
)
83-
text_featurizer.subwords.save_to_file(args.subwords_prefix)
83+
text_featurizer.save_to_file(args.subwords)
8484

8585
if args.tfrecords:
8686
train_dataset = ASRTFRecordDataset(

examples/streaming_transducer/train_subword_streaming_transducer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@
8080
config["decoder_config"],
8181
corpus_files=args.subwords_corpus
8282
)
83-
text_featurizer.subwords.save_to_file(args.subwords_prefix)
83+
text_featurizer.save_to_file(args.subwords)
8484

8585
if args.tfrecords:
8686
train_dataset = ASRTFRecordDataset(

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737

3838
setuptools.setup(
3939
name="TensorFlowASR",
40-
version="0.2.7",
40+
version="0.2.8",
4141
author="Huy Le Nguyen",
4242
author_email="[email protected]",
4343
description="Almost State-of-the-art Automatic Speech Recognition using Tensorflow 2",

tensorflow_asr/featurizers/text_featurizers.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,13 @@ def load_from_file(cls, decoder_config: dict, filename: str = None):
236236
subwords = tds.features.text.SubwordTextEncoder.load_from_file(filename_prefix)
237237
return cls(decoder_config, subwords)
238238

239+
def save_to_file(self, filename: str = None):
240+
if filename is not None:
241+
filename_prefix = os.path.splitext(preprocess_paths(filename))[0]
242+
else:
243+
filename_prefix = self.decoder_config.get("vocabulary", None)
244+
return self.subwords.save_to_file(filename_prefix)
245+
239246
def extract(self, text: str) -> tf.Tensor:
240247
"""
241248
Convert string to a list of integers

tensorflow_asr/models/streaming_transducer.py

Lines changed: 68 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -16,57 +16,55 @@
1616
import collections
1717
import tensorflow as tf
1818

19-
from .layers.merge_two_last_dims import Merge2LastDims
2019
from .layers.subsampling import TimeReduction
2120
from .transducer import Transducer, BeamHypothesis
22-
from ..utils.utils import get_rnn, get_shape_invariants
21+
from ..utils.utils import get_rnn, get_shape_invariants, merge_two_last_dims
2322

2423
Hypothesis = collections.namedtuple(
2524
"Hypothesis",
2625
("index", "prediction", "encoder_states", "prediction_states")
2726
)
2827

2928

29+
class Reshape(tf.keras.layers.Layer):
30+
def call(self, inputs): return merge_two_last_dims(inputs)
31+
32+
3033
class StreamingTransducerBlock(tf.keras.Model):
3134
def __init__(self,
32-
reduction_factor: int = 3,
33-
apply_reduction: bool = False,
34-
encoder_dim: int = 320,
35-
encoder_type: str = "lstm",
36-
encoder_units: int = 1024,
37-
encoder_layer_norm: bool = True,
38-
apply_projection: bool = True,
35+
reduction_factor: int = 0,
36+
dmodel: int = 640,
37+
rnn_type: str = "lstm",
38+
rnn_units: int = 2048,
39+
layer_norm: bool = True,
3940
kernel_regularizer=None,
4041
bias_regularizer=None,
4142
**kwargs):
4243
super(StreamingTransducerBlock, self).__init__(**kwargs)
4344

44-
if apply_reduction:
45+
if reduction_factor > 0:
4546
self.reduction = TimeReduction(reduction_factor, name=f"{self.name}_reduction")
4647
else:
4748
self.reduction = None
4849

49-
RNN = get_rnn(encoder_type)
50+
RNN = get_rnn(rnn_type)
5051
self.rnn = RNN(
51-
units=encoder_units, return_sequences=True,
52+
units=rnn_units, return_sequences=True,
5253
name=f"{self.name}_rnn", return_state=True,
5354
kernel_regularizer=kernel_regularizer,
5455
bias_regularizer=bias_regularizer
5556
)
5657

57-
if encoder_layer_norm:
58+
if layer_norm:
5859
self.ln = tf.keras.layers.LayerNormalization(name=f"{self.name}_ln")
5960
else:
6061
self.ln = None
6162

62-
if apply_projection:
63-
self.projection = tf.keras.layers.Dense(
64-
encoder_dim, name=f"{self.name}_projection",
65-
kernel_regularizer=kernel_regularizer,
66-
bias_regularizer=bias_regularizer
67-
)
68-
else:
69-
self.projection = None
63+
self.projection = tf.keras.layers.Dense(
64+
dmodel, name=f"{self.name}_projection",
65+
kernel_regularizer=kernel_regularizer,
66+
bias_regularizer=bias_regularizer
67+
)
7068

7169
def call(self, inputs, training=False):
7270
outputs = inputs
@@ -76,8 +74,7 @@ def call(self, inputs, training=False):
7674
outputs = outputs[0]
7775
if self.ln is not None:
7876
outputs = self.ln(outputs, training=training)
79-
if self.projection is not None:
80-
outputs = self.projection(outputs, training=training)
77+
outputs = self.projection(outputs, training=training)
8178
return outputs
8279

8380
def recognize(self, inputs, states):
@@ -89,12 +86,11 @@ def recognize(self, inputs, states):
8986
outputs = outputs[0]
9087
if self.ln is not None:
9188
outputs = self.ln(outputs, training=False)
92-
if self.projection is not None:
93-
outputs = self.projection(outputs, training=False)
89+
outputs = self.projection(outputs, training=False)
9490
return outputs, new_states
9591

9692
def get_config(self):
97-
conf = super(StreamingTransducerBlock, self).get_config()
93+
conf = {}
9894
if self.reduction is not None:
9995
conf.update(self.reduction.get_config())
10096
conf.update(self.rnn.get_config())
@@ -106,38 +102,36 @@ def get_config(self):
106102

107103
class StreamingTransducerEncoder(tf.keras.Model):
108104
def __init__(self,
109-
reduction_factor: int = 3,
110-
reduction_positions: list = [1],
111-
encoder_dim: int = 320,
112-
encoder_layers: int = 8,
113-
encoder_type: str = "lstm",
114-
encoder_units: int = 1024,
115-
encoder_layer_norm: bool = True,
105+
reductions: dict = {0: 3, 1: 2},
106+
dmodel: int = 640,
107+
nlayers: int = 8,
108+
rnn_type: str = "lstm",
109+
rnn_units: int = 2048,
110+
layer_norm: bool = True,
116111
kernel_regularizer=None,
117112
bias_regularizer=None,
118113
**kwargs):
119114
super(StreamingTransducerEncoder, self).__init__(**kwargs)
120115

121-
self.merge = Merge2LastDims(name=f"{self.name}_merge")
116+
self.reshape = Reshape(name=f"{self.name}_reshape")
122117

123118
self.blocks = [
124119
StreamingTransducerBlock(
125-
reduction_factor=reduction_factor,
126-
apply_reduction=(i in reduction_positions),
127-
apply_projection=(i != encoder_layers - 1),
128-
encoder_dim=encoder_dim,
129-
encoder_type=encoder_type,
130-
encoder_units=encoder_units,
131-
encoder_layer_norm=encoder_layer_norm,
120+
reduction_factor=reductions.get(i, 0), # key is index, value is the factor
121+
dmodel=dmodel,
122+
rnn_type=rnn_type,
123+
rnn_units=rnn_units,
124+
layer_norm=layer_norm,
132125
kernel_regularizer=kernel_regularizer,
133126
bias_regularizer=bias_regularizer,
134-
name=f"{self.name}_{i}"
135-
) for i in range(encoder_layers)
127+
name=f"{self.name}_block_{i}"
128+
) for i in range(nlayers)
136129
]
137130

138131
self.time_reduction_factor = 1
139-
for i in range(encoder_layers):
140-
if i in reduction_positions: self.time_reduction_factor *= reduction_factor
132+
for i in range(nlayers):
133+
reduction_factor = reductions.get(i, 0)
134+
if reduction_factor > 0: self.time_reduction_factor *= reduction_factor
141135

142136
def get_initial_state(self):
143137
"""Get zeros states
@@ -157,7 +151,7 @@ def get_initial_state(self):
157151
return tf.stack(states, axis=0)
158152

159153
def call(self, inputs, training=False):
160-
outputs = self.merge(inputs)
154+
outputs = self.reshape(inputs)
161155
for block in self.blocks:
162156
outputs = block(outputs, training=training)
163157
return outputs
@@ -173,60 +167,60 @@ def recognize(self, inputs, states):
173167
tf.Tensor: outputs with shape [1, T, E]
174168
tf.Tensor: new states with shape [num_lstms, 1 or 2, 1, P]
175169
"""
176-
outputs = self.merge(inputs)
170+
outputs = self.reshape(inputs)
177171
new_states = []
178172
for i, block in enumerate(self.blocks):
179173
outputs, block_states = block.recognize(outputs, states=tf.unstack(states[i], axis=0))
180174
new_states.append(block_states)
181175
return outputs, tf.stack(new_states, axis=0)
182176

183177
def get_config(self):
184-
conf = {}
178+
conf = self.reshape.get_config()
185179
for block in self.blocks: conf.update(block.get_config())
186180
return conf
187181

188182

189183
class StreamingTransducer(Transducer):
190184
def __init__(self,
191185
vocabulary_size: int,
192-
reduction_factor: int = 2,
193-
reduction_positions: list = [1],
194-
encoder_dim: int = 320,
195-
encoder_layers: int = 8,
196-
encoder_type: str = "lstm",
197-
encoder_units: int = 1024,
186+
encoder_reductions: dict = {0: 3, 1: 2},
187+
encoder_dmodel: int = 640,
188+
encoder_nlayers: int = 8,
189+
encoder_rnn_type: str = "lstm",
190+
encoder_rnn_units: int = 2048,
198191
encoder_layer_norm: bool = True,
199-
embed_dim: int = 320,
200-
embed_dropout: float = 0,
201-
num_rnns: int = 2,
202-
rnn_units: int = 1024,
203-
rnn_type: str = "lstm",
204-
layer_norm: bool = True,
205-
joint_dim: int = 320,
192+
prediction_embed_dim: int = 320,
193+
prediction_embed_dropout: float = 0,
194+
prediction_num_rnns: int = 2,
195+
prediction_rnn_units: int = 2048,
196+
prediction_rnn_type: str = "lstm",
197+
prediction_layer_norm: bool = True,
198+
prediction_projection_units: int = 640,
199+
joint_dim: int = 640,
206200
kernel_regularizer = None,
207201
bias_regularizer = None,
208202
name = "StreamingTransducer",
209203
**kwargs):
210204
super(StreamingTransducer, self).__init__(
211205
encoder=StreamingTransducerEncoder(
212-
reduction_factor=reduction_factor,
213-
reduction_positions=reduction_positions,
214-
encoder_dim=encoder_dim,
215-
encoder_layers=encoder_layers,
216-
encoder_type=encoder_type,
217-
encoder_units=encoder_units,
218-
encoder_layer_norm=encoder_layer_norm,
206+
reductions=encoder_reductions,
207+
dmodel=encoder_dmodel,
208+
nlayers=encoder_nlayers,
209+
rnn_type=encoder_rnn_type,
210+
rnn_units=encoder_rnn_units,
211+
layer_norm=encoder_layer_norm,
219212
kernel_regularizer=kernel_regularizer,
220213
bias_regularizer=bias_regularizer,
221214
name=f"{name}_encoder"
222215
),
223216
vocabulary_size=vocabulary_size,
224-
embed_dim=embed_dim,
225-
embed_dropout=embed_dropout,
226-
num_rnns=num_rnns,
227-
rnn_units=rnn_units,
228-
rnn_type=rnn_type,
229-
layer_norm=layer_norm,
217+
embed_dim=prediction_embed_dim,
218+
embed_dropout=prediction_embed_dropout,
219+
num_rnns=prediction_num_rnns,
220+
rnn_units=prediction_rnn_units,
221+
rnn_type=prediction_rnn_type,
222+
layer_norm=prediction_layer_norm,
223+
projection_units=prediction_projection_units,
230224
joint_dim=joint_dim,
231225
kernel_regularizer=kernel_regularizer,
232226
bias_regularizer=bias_regularizer,

0 commit comments

Comments
 (0)