Skip to content

Commit e05a115

Browse files
committed
⚡ Update deepspeech2, add jasper
1 parent ee4c314 commit e05a115

File tree

5 files changed

+247
-164
lines changed

5 files changed

+247
-164
lines changed

examples/deepspeech2/model.py

Lines changed: 0 additions & 148 deletions
Original file line numberDiff line numberDiff line change
@@ -1,148 +0,0 @@
1-
# Copyright 2020 Huy Le Nguyen (@usimarit)
2-
#
3-
# Licensed under the Apache License, Version 2.0 (the "License");
4-
# you may not use this file except in compliance with the License.
5-
# You may obtain a copy of the License at
6-
#
7-
# http://www.apache.org/licenses/LICENSE-2.0
8-
#
9-
# Unless required by applicable law or agreed to in writing, software
10-
# distributed under the License is distributed on an "AS IS" BASIS,
11-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12-
# See the License for the specific language governing permissions and
13-
# limitations under the License.
14-
"""
15-
Read https://www.tensorflow.org/api_docs/python/tf/keras/layers/LSTM
16-
to use cuDNN-LSTM
17-
"""
18-
import numpy as np
19-
import tensorflow as tf
20-
21-
from tensorflow_asr.utils.utils import append_default_keys_dict, get_rnn
22-
from tensorflow_asr.models.layers.row_conv_1d import RowConv1D
23-
from tensorflow_asr.models.layers.sequence_wise_bn import SequenceBatchNorm
24-
from tensorflow_asr.models.layers.transpose_time_major import TransposeTimeMajor
25-
from tensorflow_asr.models.layers.merge_two_last_dims import Merge2LastDims
26-
from tensorflow_asr.models.ctc import CtcModel
27-
28-
# Default hyper-parameters for each stage of the DeepSpeech2 network.
# create_ds2() merges the user supplied "conv_conf" / "rnn_conf" / "fc_conf"
# sub-dictionaries with these tables through append_default_keys_dict().

# Convolutional front-end: kernels, strides and filters carry one entry per
# conv layer and must all have the same length.
DEFAULT_CONV = dict(
    conv_type=2,  # 2 -> tf.keras.layers.Conv2D, 1 -> tf.keras.layers.Conv1D
    conv_kernels=((11, 41), (11, 21), (11, 21)),
    conv_strides=((2, 2), (1, 2), (1, 2)),
    conv_filters=(32, 32, 96),
    conv_dropout=0.2,
)

# Recurrent stack.
DEFAULT_RNN = dict(
    rnn_layers=3,
    rnn_type="gru",  # create_ds2 accepts "lstm", "gru" or "rnn"
    rnn_units=350,
    rnn_activation="tanh",
    rnn_bidirectional=True,
    rnn_rowconv=False,
    rnn_rowconv_context=2,  # forwarded as RowConv1D(future_context=...)
    rnn_dropout=0.2,
)

# Hidden fully connected stack; a falsy "fc_units" skips the stack entirely.
DEFAULT_FC = dict(
    fc_units=(1024,),
    fc_dropout=0.2,
)
51-
52-
53-
def create_ds2(input_shape: list, arch_config: dict, name: str = "deepspeech2"):
    """Assemble the DeepSpeech2 base network: conv stack -> RNN stack -> FC stack.

    Args:
        input_shape: Shape forwarded to ``tf.keras.Input(shape=...)`` for the
            ``"features"`` input.
        arch_config: Architecture description. Its optional ``"conv_conf"``,
            ``"rnn_conf"`` and ``"fc_conf"`` sub-dictionaries are combined with
            ``DEFAULT_CONV``, ``DEFAULT_RNN`` and ``DEFAULT_FC`` through
            ``append_default_keys_dict``.
        name: Name given to the returned Keras model.

    Returns:
        A ``tf.keras.Model`` mapping the ``"features"`` input to the output of
        the last layer that was built.

    Raises:
        AssertionError: If the conv kernel/stride/filter lists differ in
            length, ``conv_type`` is not 1 or 2, ``rnn_type`` is not one of
            ``"lstm"``, ``"gru"``, ``"rnn"``, a dropout rate is negative, or
            ``conv_type == 1`` is combined with non 1-D conv settings.
    """
    conv_conf = append_default_keys_dict(DEFAULT_CONV, arch_config.get("conv_conf", {}))
    rnn_conf = append_default_keys_dict(DEFAULT_RNN, arch_config.get("rnn_conf", {}))
    fc_conf = append_default_keys_dict(DEFAULT_FC, arch_config.get("fc_conf", {}))

    # ---- configuration sanity checks ----
    strides = conv_conf["conv_strides"]
    filters = conv_conf["conv_filters"]
    kernels = conv_conf["conv_kernels"]
    assert len(strides) == len(filters) == len(kernels)
    assert conv_conf["conv_type"] in (1, 2)
    rnn_type = rnn_conf["rnn_type"]
    assert rnn_type in ("lstm", "gru", "rnn")
    assert conv_conf["conv_dropout"] >= 0.0
    assert rnn_conf["rnn_dropout"] >= 0.0

    inputs = tf.keras.Input(shape=input_shape, name="features")
    x = inputs

    # ---- convolutional front-end ----
    use_conv2d = conv_conf["conv_type"] == 2
    if use_conv2d:
        conv_cls = tf.keras.layers.Conv2D
    else:
        x = Merge2LastDims("conv1d_features")(x)
        conv_cls = tf.keras.layers.Conv1D
        # Conv1D needs scalar kernel/stride/filter entries, i.e. flat lists.
        shapes = [np.shape(kernels), np.shape(strides), np.shape(filters)]
        assert all(len(shape) == 1 for shape in shapes)

    for i, n_filters in enumerate(filters):
        x = conv_cls(
            filters=n_filters,
            kernel_size=kernels[i],
            strides=strides[i],
            padding="same",
            activation=None,
            dtype=tf.float32,
            name=f"cnn_{i}",
        )(x)
        x = tf.keras.layers.BatchNormalization(name=f"cnn_bn_{i}")(x)
        x = tf.keras.layers.ReLU(name=f"cnn_relu_{i}")(x)
        x = tf.keras.layers.Dropout(conv_conf["conv_dropout"], name=f"cnn_dropout_{i}")(x)

    if use_conv2d:
        x = Merge2LastDims("reshape_conv2d_to_rnn")(x)

    # ---- recurrent stack ----
    rnn_cls = get_rnn(rnn_type)
    bidirectional = rnn_conf["rnn_bidirectional"]

    # The bidirectional layers are created with time_major=True, so the tensor
    # is transposed before the stack and transposed back after it.
    if bidirectional:
        x = TransposeTimeMajor("transpose_to_time_major")(x)

    for i in range(rnn_conf["rnn_layers"]):
        if not bidirectional:
            x = rnn_cls(
                rnn_conf["rnn_units"],
                activation=rnn_conf["rnn_activation"],
                dropout=rnn_conf["rnn_dropout"],
                return_sequences=True,
                use_bias=True,
                name=f"{rnn_type}_{i}",
            )(x)
            x = SequenceBatchNorm(time_major=False, name=f"sequence_wise_bn_{i}")(x)
            # NOTE(review): row convolution is kept inside the unidirectional
            # path; nesting was reconstructed from a source whose indentation
            # was lost -- confirm against the original file.
            if rnn_conf["rnn_rowconv"]:
                x = RowConv1D(
                    filters=rnn_conf["rnn_units"],
                    future_context=rnn_conf["rnn_rowconv_context"],
                    name=f"row_conv_{i}",
                )(x)
            continue

        x = tf.keras.layers.Bidirectional(
            rnn_cls(
                rnn_conf["rnn_units"],
                activation=rnn_conf["rnn_activation"],
                time_major=True,
                dropout=rnn_conf["rnn_dropout"],
                return_sequences=True,
                use_bias=True,
            ),
            name=f"b{rnn_type}_{i}",
        )(x)
        x = SequenceBatchNorm(time_major=True, name=f"sequence_wise_bn_{i}")(x)

    if bidirectional:
        x = TransposeTimeMajor("transpose_to_batch_major")(x)

    # ---- optional hidden fully connected stack ----
    fc_units = fc_conf["fc_units"]
    if fc_units:
        assert fc_conf["fc_dropout"] >= 0.0

        for idx, units in enumerate(fc_units):
            x = tf.keras.layers.Dense(
                units=units, activation=None, use_bias=True, name=f"hidden_fc_{idx}"
            )(x)
            x = tf.keras.layers.BatchNormalization(name=f"hidden_fc_bn_{idx}")(x)
            x = tf.keras.layers.ReLU(name=f"hidden_fc_relu_{idx}")(x)
            x = tf.keras.layers.Dropout(
                fc_conf["fc_dropout"], name=f"hidden_fc_dropout_{idx}"
            )(x)

    return tf.keras.Model(inputs=inputs, outputs=x, name=name)
131-
132-
133-
class DeepSpeech2(CtcModel):
    """CTC model whose base network is the DeepSpeech2 stack from ``create_ds2``.

    Attributes:
        time_reduction_factor: Product of the first (or only) stride component
            of every conv layer, taken from the conv configuration.
    """

    def __init__(self,
                 input_shape: list,
                 arch_config: dict,
                 num_classes: int,
                 name: str = "deepspeech2"):
        """Build the base network and wrap it in the CTC head.

        Args:
            input_shape: Forwarded to ``create_ds2``.
            arch_config: Architecture description forwarded to ``create_ds2``;
                ``"conv_conf"`` and its ``"conv_strides"`` entry are optional.
            num_classes: Forwarded to ``CtcModel`` as ``num_classes``.
            name: Name of the base model; the CTC wrapper is named
                ``f"{name}_ctc"``.
        """
        super(DeepSpeech2, self).__init__(
            base_model=create_ds2(input_shape=input_shape,
                                  arch_config=arch_config,
                                  name=name),
            num_classes=num_classes,
            name=f"{name}_ctc"
        )
        # create_ds2 treats "conv_conf" as optional, so do not index it
        # unconditionally here: fall back to the same defaults instead of
        # raising KeyError when the section or the key is missing.
        conv_conf = arch_config.get("conv_conf") or {}
        strides = conv_conf.get("conv_strides", DEFAULT_CONV["conv_strides"])

        self.time_reduction_factor = 1
        for stride in strides:
            # Conv2D strides are pairs, of which the first component is used;
            # with conv_type == 1 create_ds2 requires scalar strides, which
            # cannot be subscripted.
            if isinstance(stride, (list, tuple)):
                stride = stride[0]
            self.time_reduction_factor *= stride

tensorflow_asr/models/ctc.py

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -17,46 +17,37 @@
1717

1818
from ctc_decoders import ctc_greedy_decoder, ctc_beam_search_decoder
1919

20+
from . import Model
2021
from ..featurizers.speech_featurizers import TFSpeechFeaturizer
2122
from ..featurizers.text_featurizers import TextFeaturizer
2223
from ..utils.utils import shape_list
2324

2425

25-
class CtcModel(tf.keras.Model):
26+
class CtcModel(Model):
2627
def __init__(self,
27-
base_model: tf.keras.Model,
28-
num_classes: int,
28+
vocabulary_size: int,
2929
name="ctc_model",
3030
**kwargs):
3131
super(CtcModel, self).__init__(name=name, **kwargs)
32-
self.base_model = base_model
3332
# Fully connected layer
34-
self.fc = tf.keras.layers.Dense(units=num_classes, activation="linear",
33+
self.fc = tf.keras.layers.Dense(units=vocabulary_size, activation="linear",
3534
use_bias=True, name=f"{name}_fc")
3635

3736
def _build(self, input_shape):
3837
features = tf.keras.Input(input_shape, dtype=tf.float32)
3938
self(features, training=False)
4039

41-
def summary(self, line_length=None, **kwargs):
42-
self.base_model.summary(line_length=line_length, **kwargs)
43-
super(CtcModel, self).summary(line_length, **kwargs)
44-
4540
def add_featurizers(self,
4641
speech_featurizer: TFSpeechFeaturizer,
4742
text_featurizer: TextFeaturizer):
4843
self.speech_featurizer = speech_featurizer
4944
self.text_featurizer = text_featurizer
5045

51-
def call(self, inputs, training=False, **kwargs):
52-
outputs = self.base_model(inputs, training=training)
53-
outputs = self.fc(outputs, training=training)
54-
return outputs
46+
def call(self, inputs, training=False):
47+
return self.fc(inputs, training=training)
5548

5649
def get_config(self):
57-
config = self.base_model.get_config()
58-
config.update(self.fc.get_config())
59-
return config
50+
return self.fc.get_config()
6051

6152
# -------------------------------- GREEDY -------------------------------------
6253

0 commit comments

Comments
 (0)