# Copyright 2020 Huy Le Nguyen (@usimarit)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
See https://www.tensorflow.org/api_docs/python/tf/keras/layers/LSTM
for the requirements to use the fused cuDNN LSTM kernel.
"""
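# Per the TF docs linked above, Keras only dispatches LSTM/GRU layers to the fused
# cuDNN kernel when activation="tanh", recurrent_activation="sigmoid",
# recurrent_dropout=0, unroll=False, use_bias=True, and inputs are unmasked or
# strictly right padded; otherwise it falls back to the slower generic implementation.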
import numpy as np
import tensorflow as tf

from tensorflow_asr.utils.utils import append_default_keys_dict, get_rnn
from tensorflow_asr.models.layers.row_conv_1d import RowConv1D
from tensorflow_asr.models.layers.sequence_wise_bn import SequenceBatchNorm
from tensorflow_asr.models.layers.transpose_time_major import TransposeTimeMajor
from tensorflow_asr.models.layers.merge_two_last_dims import Merge2LastDims
from tensorflow_asr.models.ctc import CtcModel

DEFAULT_CONV = {
    "conv_type": 2,
    "conv_kernels": ((11, 41), (11, 21), (11, 21)),
    "conv_strides": ((2, 2), (1, 2), (1, 2)),
    "conv_filters": (32, 32, 96),
    "conv_dropout": 0.2
}

DEFAULT_RNN = {
    "rnn_layers": 3,
    "rnn_type": "gru",
    "rnn_units": 350,
    "rnn_activation": "tanh",
    "rnn_bidirectional": True,
    "rnn_rowconv": False,
    "rnn_rowconv_context": 2,
    "rnn_dropout": 0.2
}

DEFAULT_FC = {
    "fc_units": (1024,),
    "fc_dropout": 0.2
}
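# Example (hypothetical) arch_config: any key omitted from a sub-config falls back
# to the defaults above via append_default_keys_dict, e.g.
#
#     arch_config = {
#         "conv_conf": {"conv_dropout": 0.1},                  # other conv_* keys keep their defaults
#         "rnn_conf": {"rnn_type": "lstm", "rnn_units": 512},
#         "fc_conf": {"fc_units": None},                       # a falsy value disables the FC block
#     }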
def create_ds2(input_shape: list, arch_config: dict, name: str = "deepspeech2"):
    conv_conf = append_default_keys_dict(DEFAULT_CONV, arch_config.get("conv_conf", {}))
    rnn_conf = append_default_keys_dict(DEFAULT_RNN, arch_config.get("rnn_conf", {}))
    fc_conf = append_default_keys_dict(DEFAULT_FC, arch_config.get("fc_conf", {}))
    assert len(conv_conf["conv_strides"]) == \
        len(conv_conf["conv_filters"]) == len(conv_conf["conv_kernels"])
    assert conv_conf["conv_type"] in [1, 2]
    assert rnn_conf["rnn_type"] in ["lstm", "gru", "rnn"]
    assert conv_conf["conv_dropout"] >= 0.0 and rnn_conf["rnn_dropout"] >= 0.0

    features = tf.keras.Input(shape=input_shape, name="features")
    layer = features

    if conv_conf["conv_type"] == 2:
        conv = tf.keras.layers.Conv2D
    else:
        layer = Merge2LastDims("conv1d_features")(layer)
        conv = tf.keras.layers.Conv1D
        ker_shape = np.shape(conv_conf["conv_kernels"])
        stride_shape = np.shape(conv_conf["conv_strides"])
        filter_shape = np.shape(conv_conf["conv_filters"])
        assert len(ker_shape) == 1 and len(stride_shape) == 1 and len(filter_shape) == 1

    # CONV layers
    for i, fil in enumerate(conv_conf["conv_filters"]):
        layer = conv(filters=fil, kernel_size=conv_conf["conv_kernels"][i],
                     strides=conv_conf["conv_strides"][i], padding="same",
                     activation=None, dtype=tf.float32, name=f"cnn_{i}")(layer)
        layer = tf.keras.layers.BatchNormalization(name=f"cnn_bn_{i}")(layer)
        layer = tf.keras.layers.ReLU(name=f"cnn_relu_{i}")(layer)
        layer = tf.keras.layers.Dropout(conv_conf["conv_dropout"],
                                        name=f"cnn_dropout_{i}")(layer)

    if conv_conf["conv_type"] == 2:
        layer = Merge2LastDims("reshape_conv2d_to_rnn")(layer)

    rnn = get_rnn(rnn_conf["rnn_type"])

    # To time major
    if rnn_conf["rnn_bidirectional"]:
        layer = TransposeTimeMajor("transpose_to_time_major")(layer)

    # RNN layers
    for i in range(rnn_conf["rnn_layers"]):
        if rnn_conf["rnn_bidirectional"]:
            layer = tf.keras.layers.Bidirectional(
                rnn(rnn_conf["rnn_units"], activation=rnn_conf["rnn_activation"],
                    time_major=True, dropout=rnn_conf["rnn_dropout"],
                    return_sequences=True, use_bias=True),
                name=f"b{rnn_conf['rnn_type']}_{i}")(layer)
            layer = SequenceBatchNorm(time_major=True, name=f"sequence_wise_bn_{i}")(layer)
        else:
            layer = rnn(rnn_conf["rnn_units"], activation=rnn_conf["rnn_activation"],
                        dropout=rnn_conf["rnn_dropout"], return_sequences=True, use_bias=True,
                        name=f"{rnn_conf['rnn_type']}_{i}")(layer)
            layer = SequenceBatchNorm(time_major=False, name=f"sequence_wise_bn_{i}")(layer)
            if rnn_conf["rnn_rowconv"]:
                layer = RowConv1D(filters=rnn_conf["rnn_units"],
                                  future_context=rnn_conf["rnn_rowconv_context"],
                                  name=f"row_conv_{i}")(layer)

    # To batch major
    if rnn_conf["rnn_bidirectional"]:
        layer = TransposeTimeMajor("transpose_to_batch_major")(layer)

    # FC layers
    if fc_conf["fc_units"]:
        assert fc_conf["fc_dropout"] >= 0.0

        for idx, units in enumerate(fc_conf["fc_units"]):
            layer = tf.keras.layers.Dense(units=units, activation=None,
                                          use_bias=True, name=f"hidden_fc_{idx}")(layer)
            layer = tf.keras.layers.BatchNormalization(name=f"hidden_fc_bn_{idx}")(layer)
            layer = tf.keras.layers.ReLU(name=f"hidden_fc_relu_{idx}")(layer)
            layer = tf.keras.layers.Dropout(fc_conf["fc_dropout"],
                                            name=f"hidden_fc_dropout_{idx}")(layer)

    return tf.keras.Model(inputs=features, outputs=layer, name=name)
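# A quick way to inspect the encoder graph built above (the feature shape
# [None, 80, 1] is an illustrative assumption, not a value taken from this
# repository's configs):
#
#     create_ds2(input_shape=[None, 80, 1], arch_config={}).summary()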
class DeepSpeech2(CtcModel):
    def __init__(self,
                 input_shape: list,
                 arch_config: dict,
                 num_classes: int,
                 name: str = "deepspeech2"):
        super(DeepSpeech2, self).__init__(
            base_model=create_ds2(input_shape=input_shape,
                                  arch_config=arch_config,
                                  name=name),
            num_classes=num_classes,
            name=f"{name}_ctc"
        )
        # Time downsampling comes from the conv strides along the time axis.  Use the
        # merged config so missing keys fall back to DEFAULT_CONV, as in create_ds2.
        conv_conf = append_default_keys_dict(DEFAULT_CONV, arch_config.get("conv_conf", {}))
        self.time_reduction_factor = 1
        for s in conv_conf["conv_strides"]:
            # conv_type 2 uses (time, freq) stride pairs; conv_type 1 uses scalar strides
            self.time_reduction_factor *= s[0] if isinstance(s, (tuple, list)) else s
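

# Minimal usage sketch (not part of the original module): the 80 feature bins, single
# channel, and 29 output classes below are illustrative assumptions, not values taken
# from this repository's configs.
if __name__ == "__main__":
    model = DeepSpeech2(
        input_shape=[None, 80, 1],  # [time, feature_bins, channels]; time is dynamic
        arch_config={"rnn_conf": {"rnn_type": "lstm"}},  # everything else keeps the defaults
        num_classes=29,
    )
    print(model.time_reduction_factor)  # 2: only the first default conv stride (2, 2) reduces time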