
Commit d18aac8 (parent: e05a115)

⚡ Update Ctc and DeepSpeech2, Supported Jasper

23 files changed: +591 −682 lines

examples/deepspeech2/README.md
15 additions, 36 deletions
@@ -6,22 +6,19 @@ References: [https://arxiv.org/abs/1512.02595](https://arxiv.org/abs/1512.02595)
 
 ```yaml
 model_config:
-  conv_conf:
-    conv_type: 2
-    conv_kernels: [[11, 41], [11, 21], [11, 11]]
-    conv_strides: [[2, 2], [1, 2], [1, 2]]
-    conv_filters: [32, 32, 96]
-    conv_dropout: 0
-  rnn_conf:
-    rnn_layers: 5
-    rnn_type: lstm
-    rnn_units: 512
-    rnn_bidirectional: True
-    rnn_rowconv: False
-    rnn_dropout: 0
-  fc_conf:
-    fc_units: [1024]
-    fc_dropout: 0
+  conv_type: conv2d
+  conv_kernels: [[11, 41], [11, 21], [11, 11]]
+  conv_strides: [[2, 2], [1, 2], [1, 2]]
+  conv_filters: [32, 32, 96]
+  conv_dropout: 0.1
+  rnn_nlayers: 5
+  rnn_type: lstm
+  rnn_units: 512
+  rnn_bidirectional: True
+  rnn_rowconv: 0
+  rnn_dropout: 0.1
+  fc_nlayers: 0
+  fc_units: 1024
 ```
 
 ## Architecture
@@ -30,24 +27,6 @@ model_config:
 
 ## Training and Testing
 
-See `python examples/deepspeech2/run_ds2.py --help`
+See `python examples/deepspeech2/train_ds2.py --help`
 
-## Results on VIVOS Dataset
-
-* Features: Spectrogram with `80` frequency channels
-* KenLM: `alpha = 2.0` and `beta = 1.0`
-* Epochs: `20`
-* Train set split ratio: `90:10`
-* Augmentation: `None`
-* Model architecture: same as [vivos.yaml](./configs/vivos.yml)
-
-**CTC Loss**
-
-<img src="./figs/ds2_vivos_ctc_loss.svg" alt="ds2_vivos_ctc_loss" width="300px" />
-
-**Error rates**
-
-|                 | WER (%)        | CER (%)        |
-| :-------------- | :------------: | :------------: |
-| *BeamSearch*    | 43.75243       | 17.991581      |
-| *BeamSearch LM* | **20.7561836** | **11.0304441** |
+See `python examples/deepspeech2/test_ds2.py --help`
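
For context, a minimal sketch (not part of this commit) of how the flattened `model_config` is consumed after this change. It mirrors the constructor call the commit introduces in `train_ds2.py` and `test_ds2.py`; the config path is a placeholder, and passing the same file to `UserConfig` as both default and user config is an assumption made here for brevity.

```python
import os

from tensorflow_asr.configs.user_config import UserConfig
from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer
from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer
from tensorflow_asr.models.deepspeech2 import DeepSpeech2

# Placeholder path; the scripts in this commit resolve it relative to __file__.
CONFIG_PATH = os.path.join("examples", "deepspeech2", "config.yml")
config = UserConfig(CONFIG_PATH, CONFIG_PATH, learning=True)

speech_featurizer = TFSpeechFeaturizer(config["speech_config"])
text_featurizer = CharFeaturizer(config["decoder_config"])

# The flat keys (conv_type, rnn_nlayers, fc_units, ...) map one-to-one onto
# constructor keyword arguments; only the vocabulary size is supplied separately.
ds2_model = DeepSpeech2(**config["model_config"], vocabulary_size=text_featurizer.num_classes)
ds2_model._build(speech_featurizer.shape)
ds2_model.summary(line_length=120)
```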

examples/deepspeech2/configs/vivos.yml renamed to examples/deepspeech2/config.yml
15 additions, 16 deletions
@@ -24,7 +24,7 @@ speech_config:
   normalize_per_feature: False
 
 decoder_config:
-  vocabulary: /mnt/Projects/asrk16/TiramisuASR/vocabularies/vietnamese.txt
+  vocabulary: ./vocabularies/vietnamese.characters
   blank_at_zero: False
   beam_width: 500
   lm_config:
@@ -33,21 +33,20 @@ decoder_config:
     beta: 1.0
 
 model_config:
-  conv_conf:
-    conv_type: 2
-    conv_kernels: [[11, 41], [11, 21], [11, 11]]
-    conv_strides: [[2, 2], [1, 2], [1, 2]]
-    conv_filters: [32, 32, 96]
-    conv_dropout: 0
-  rnn_conf:
-    rnn_layers: 5
-    rnn_type: lstm
-    rnn_units: 512
-    rnn_bidirectional: True
-    rnn_rowconv: False
-    rnn_dropout: 0
-  fc_conf:
-    fc_units: null
+  name: deepspeech2
+  conv_type: conv2d
+  conv_kernels: [[11, 41], [11, 21], [11, 11]]
+  conv_strides: [[2, 2], [1, 2], [1, 2]]
+  conv_filters: [32, 32, 96]
+  conv_dropout: 0.1
+  rnn_nlayers: 5
+  rnn_type: lstm
+  rnn_units: 512
+  rnn_bidirectional: True
+  rnn_rowconv: 0
+  rnn_dropout: 0.1
+  fc_nlayers: 0
+  fc_units: 1024
 
 learning_config:
   augmentations: null
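
As a quick sanity check when migrating an old `vivos.yml`, one could verify that `model_config` now carries only the flat keys shown above before unpacking it into the constructor. This is an illustrative snippet, not part of the repository; it assumes the file is plain YAML readable with PyYAML.

```python
import yaml

# Flat keys introduced by this commit (taken from the config above).
EXPECTED_KEYS = {
    "name", "conv_type", "conv_kernels", "conv_strides", "conv_filters",
    "conv_dropout", "rnn_nlayers", "rnn_type", "rnn_units", "rnn_bidirectional",
    "rnn_rowconv", "rnn_dropout", "fc_nlayers", "fc_units",
}

with open("examples/deepspeech2/config.yml") as f:
    model_config = yaml.safe_load(f)["model_config"]

# Legacy nested blocks (conv_conf, rnn_conf, fc_conf) or misspelled keys would
# otherwise surface as unexpected keyword arguments at model construction time.
unknown = set(model_config) - EXPECTED_KEYS
assert not unknown, f"unexpected model_config keys: {sorted(unknown)}"
```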

examples/deepspeech2/figs/ds2_vivos_ctc_loss.svg
0 additions, 1 deletion (this file was deleted)

examples/deepspeech2/model.py

Whitespace-only changes.

examples/deepspeech2/test_ds2.py
4 additions, 7 deletions
@@ -19,7 +19,7 @@
 setup_environment()
 import tensorflow as tf
 
-DEFAULT_YAML = os.path.join(os.path.abspath(os.path.dirname(__file__)), "configs", "vivos.yml")
+DEFAULT_YAML = os.path.join(os.path.abspath(os.path.dirname(__file__)), "config.yml")
 
 tf.keras.backend.clear_session()
 
@@ -54,7 +54,7 @@
 from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer
 from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer
 from tensorflow_asr.runners.base_runners import BaseTester
-from model import DeepSpeech2
+from tensorflow_asr.models.deepspeech2 import DeepSpeech2
 
 tf.random.set_seed(0)
 assert args.export
@@ -63,13 +63,10 @@
 speech_featurizer = TFSpeechFeaturizer(config["speech_config"])
 text_featurizer = CharFeaturizer(config["decoder_config"])
 # Build DS2 model
-ds2_model = DeepSpeech2(input_shape=speech_featurizer.shape,
-                        arch_config=config["model_config"],
-                        num_classes=text_featurizer.num_classes,
-                        name="deepspeech2")
+ds2_model = DeepSpeech2(**config["model_config"], vocabulary_size=text_featurizer.num_classes)
 ds2_model._build(speech_featurizer.shape)
 ds2_model.load_weights(args.saved, by_name=True)
-ds2_model.summary(line_length=150)
+ds2_model.summary(line_length=120)
 ds2_model.add_featurizers(speech_featurizer, text_featurizer)
 
 if args.tfrecords:
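
The constructor change above is plain keyword-argument unpacking: with the flat config, every YAML key can become a named parameter and the dict is splatted straight into the model. A toy illustration only, not the real `DeepSpeech2` signature:

```python
# Toy stand-in for the new call style; the key names mirror the YAML above,
# the builder function itself is hypothetical.
model_config = {"rnn_nlayers": 5, "rnn_type": "lstm", "rnn_units": 512}

def build_model(rnn_nlayers: int, rnn_type: str, rnn_units: int, vocabulary_size: int) -> str:
    return f"{rnn_nlayers} x {rnn_type}({rnn_units}) -> {vocabulary_size} output classes"

# Equivalent in spirit to DeepSpeech2(**config["model_config"], vocabulary_size=...)
print(build_model(**model_config, vocabulary_size=95))
```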

examples/deepspeech2/train_ds2.py
4 additions, 7 deletions
@@ -19,7 +19,7 @@
 setup_environment()
 import tensorflow as tf
 
-DEFAULT_YAML = os.path.join(os.path.abspath(os.path.dirname(__file__)), "configs", "vivos.yml")
+DEFAULT_YAML = os.path.join(os.path.abspath(os.path.dirname(__file__)), "config.yml")
 
 tf.keras.backend.clear_session()
 
@@ -60,7 +60,7 @@
 from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer
 from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer
 from tensorflow_asr.runners.ctc_runners import CTCTrainer
-from model import DeepSpeech2
+from tensorflow_asr.models.deepspeech2 import DeepSpeech2
 
 config = UserConfig(DEFAULT_YAML, args.config, learning=True)
 speech_featurizer = TFSpeechFeaturizer(config["speech_config"])
@@ -100,12 +100,9 @@
 ctc_trainer = CTCTrainer(text_featurizer, config["learning_config"]["running_config"])
 # Build DS2 model
 with ctc_trainer.strategy.scope():
-    ds2_model = DeepSpeech2(input_shape=speech_featurizer.shape,
-                            arch_config=config["model_config"],
-                            num_classes=text_featurizer.num_classes,
-                            name="deepspeech2")
+    ds2_model = DeepSpeech2(**config["model_config"], vocabulary_size=text_featurizer.num_classes)
     ds2_model._build(speech_featurizer.shape)
-    ds2_model.summary(line_length=150)
+    ds2_model.summary(line_length=120)
 # Compile
 ctc_trainer.compile(ds2_model, config["learning_config"]["optimizer_config"],
                     max_to_keep=args.max_ckpts)
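
Note that `train_ds2.py` keeps building the model inside `ctc_trainer.strategy.scope()`. A generic `tf.distribute` sketch of why that matters (assumed rationale, independent of this repo): variables created inside the scope are placed and replicated according to the strategy, so a model built outside it would not be tracked correctly when training on multiple devices.

```python
import tensorflow as tf

# Default strategy on a single device; CTCTrainer would typically hold a
# MirroredStrategy (or similar) when several GPUs are available.
strategy = tf.distribute.get_strategy()

with strategy.scope():
    # Variables created here are managed by the strategy, which is the same
    # reason DeepSpeech2 is instantiated and _build() is called inside
    # ctc_trainer.strategy.scope() in the script above.
    model = tf.keras.Sequential([tf.keras.layers.Dense(8, input_shape=(4,))])

model.summary(line_length=120)
```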

examples/jasper/README.md
20 additions, 0 deletions
@@ -0,0 +1,20 @@
+# Jasper
+
+References: [https://arxiv.org/abs/1904.03288](https://arxiv.org/abs/1904.03288)
+
+## Model YAML Config Structure
+
+```yaml
+model_config:
+
+```
+
+## Architecture
+
+<img src="./figs/jasper_arch.png" alt="jasper_arch" width="500px" />
+
+## Training and Testing
+
+See `python examples/jasper/train_jasper.py --help`
+
+See `python examples/jasper/test_jasper.py --help`

examples/sadeepspeech2/config.yml renamed to examples/jasper/config.yml
29 additions, 22 deletions
@@ -24,7 +24,7 @@ speech_config:
   normalize_per_feature: False
 
 decoder_config:
-  vocabulary: /mnt/Projects/asrk16/TiramisuASR/vocabularies/vietnamese.txt
+  vocabulary: ./vocabularies/vietnamese.characters
   blank_at_zero: False
   beam_width: 500
   lm_config:
@@ -33,20 +33,27 @@ decoder_config:
     beta: 1.0
 
 model_config:
-  subsampling:
-    filters: 144
-    kernel_size: 32
-    strides: 2
-  att:
-    layers: 16
-    head_size: 36
-    num_heads: 4
-    ffn_size: 1024
-    dropout: 0
-  rnn:
-    layers: 1
-    units: 320
-    dropout: 0
+  name: jasper
+  dense: True
+  first_additional_block_channels: 256
+  first_additional_block_kernels: 11
+  first_additional_block_strides: 2
+  first_additional_block_dilation: 1
+  first_additional_block_dropout: 0.2
+  nsubblocks: 3
+  block_channels: [256, 384, 512, 640, 768]
+  block_kernels: [11, 13, 17, 21, 25]
+  block_dropout: [0.2, 0.2, 0.2, 0.3, 0.3]
+  second_additional_block_channels: 896
+  second_additional_block_kernels: 1
+  second_additional_block_strides: 1
+  second_additional_block_dilation: 2
+  second_additional_block_dropout: 0.4
+  third_additional_block_channels: 1024
+  third_additional_block_kernels: 1
+  third_additional_block_strides: 1
+  third_additional_block_dilation: 1
+  third_additional_block_dropout: 0.4
 
 learning_config:
   augmentations: null
@@ -61,14 +68,14 @@ learning_config:
     tfrecords_dir: /mnt/Data/ML/ASR/Preprocessed/Vivos/TFRecords
 
   optimizer_config:
-    name: transformer_adam
+    class_name: adam
     config:
-      warmup_steps: 10000
+      learning_rate: 0.0001
 
   running_config:
-    batch_size: 2
+    batch_size: 8
     num_epochs: 20
-    outdir: /mnt/Projects/asrk16/trained/local/vivos_self_att_ds2
-    log_interval_steps: 500
-    save_interval_steps: 500
-    eval_interval_steps: 700
+    outdir: /mnt/Projects/asrk16/trained/local/jasper
+    log_interval_steps: 400
+    save_interval_steps: 400
+    eval_interval_steps: 800
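
A sketch (not part of this commit) of how this renamed `config.yml` feeds the new Jasper model; the calls mirror `test_jasper.py` below, the path is a placeholder, and reusing the same file for both `UserConfig` arguments is an assumption.

```python
from tensorflow_asr.configs.user_config import UserConfig
from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer
from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer
from tensorflow_asr.models.jasper import Jasper

CONFIG_PATH = "examples/jasper/config.yml"  # placeholder; the scripts resolve it from __file__
config = UserConfig(CONFIG_PATH, CONFIG_PATH, learning=True)

speech_featurizer = TFSpeechFeaturizer(config["speech_config"])
text_featurizer = CharFeaturizer(config["decoder_config"])

# The dense/block_*/additional_block_* keys above become keyword arguments of Jasper.
jasper = Jasper(**config["model_config"], vocabulary_size=text_featurizer.num_classes)
jasper._build(speech_featurizer.shape)
jasper.summary(line_length=120)
```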
(binary file, 279 KB; filename and rich diff not captured in this view)

examples/sadeepspeech2/test_sadeepspeech2.py renamed to examples/jasper/test_jasper.py
29 additions, 22 deletions
@@ -1,3 +1,17 @@
+# Copyright 2020 Huy Le Nguyen (@usimarit)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import os
 import argparse
 from tensorflow_asr.utils import setup_environment, setup_devices
@@ -7,16 +21,18 @@
 
 DEFAULT_YAML = os.path.join(os.path.abspath(os.path.dirname(__file__)), "config.yml")
 
-parser = argparse.ArgumentParser(prog="Self Attention DS2")
+tf.keras.backend.clear_session()
+
+parser = argparse.ArgumentParser(prog="Jasper Testing")
 
 parser.add_argument("--config", "-c", type=str, default=DEFAULT_YAML,
                     help="The file path of model configuration file")
 
 parser.add_argument("--saved", type=str, default=None,
-                    help="Path to saved model")
+                    help="Path to the model file to be exported")
 
 parser.add_argument("--tfrecords", default=False, action="store_true",
-                    help="Whether to use tfrecords")
+                    help="Whether to use tfrecords dataset")
 
 parser.add_argument("--mxp", default=False, action="store_true",
                     help="Enable mixed precision")
@@ -33,34 +49,25 @@
 
 setup_devices([args.device])
 
-from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer
-from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer
 from tensorflow_asr.configs.user_config import UserConfig
 from tensorflow_asr.datasets.asr_dataset import ASRTFRecordDataset, ASRSliceDataset
-from model import SelfAttentionDS2
+from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer
+from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer
 from tensorflow_asr.runners.base_runners import BaseTester
-from ctc_decoders import Scorer
+from tensorflow_asr.models.jasper import Jasper
 
 tf.random.set_seed(0)
-assert args.saved
+assert args.export
 
 config = UserConfig(DEFAULT_YAML, args.config, learning=True)
 speech_featurizer = TFSpeechFeaturizer(config["speech_config"])
 text_featurizer = CharFeaturizer(config["decoder_config"])
-
-text_featurizer.add_scorer(Scorer(**text_featurizer.decoder_config["lm_config"],
-                                  vocabulary=text_featurizer.vocab_array))
-
 # Build DS2 model
-satt_ds2_model = SelfAttentionDS2(
-    input_shape=speech_featurizer.shape,
-    arch_config=config["model_config"],
-    num_classes=text_featurizer.num_classes
-)
-satt_ds2_model._build(speech_featurizer.shape)
-satt_ds2_model.load_weights(args.saved, by_name=True)
-satt_ds2_model.summary(line_length=150)
-satt_ds2_model.add_featurizers(speech_featurizer, text_featurizer)
+jasper = Jasper(**config["model_config"], vocabulary_size=text_featurizer.num_classes)
+jasper._build(speech_featurizer.shape)
+jasper.load_weights(args.saved, by_name=True)
+jasper.summary(line_length=120)
+jasper.add_featurizers(speech_featurizer, text_featurizer)
 
 if args.tfrecords:
     test_dataset = ASRTFRecordDataset(
@@ -82,5 +89,5 @@
     config=config["learning_config"]["running_config"],
     output_name=args.output_name
 )
-ctc_tester.compile(satt_ds2_model)
+ctc_tester.compile(jasper)
 ctc_tester.run(test_dataset)