
Commit ffa7c9f

committed
⚡ Supported Streaming RNN Transducer
1 parent 1372623 commit ffa7c9f

18 files changed: +1243 -82 lines changed

README.md

Lines changed: 3 additions & 0 deletions
@@ -19,6 +19,7 @@ TensorFlowASR implements some automatic speech recognition architectures such as

 ## What's New?

+- (10/18/2020) Supported Streaming Transducer [https://arxiv.org/abs/1811.06621](https://arxiv.org/abs/1811.06621)
 - (10/15/2020) Add gradients accumulation and Refactor to TensorflowASR
 - (10/10/2020) Update documents and upload package to pypi
 - (10/6/2020) Change `nlpaug` version to `>=1.0.1`

@@ -32,6 +33,8 @@ TensorFlowASR implements some automatic speech recognition architectures such as

 - **Transducer Models** (End2end models using RNNT Loss for training)
   - **Conformer Transducer** (Reference: [https://arxiv.org/abs/2005.08100](https://arxiv.org/abs/2005.08100))
     See [examples/conformer](./examples/conformer)
+  - **Streaming Transducer** (Reference: [https://arxiv.org/abs/1811.06621](https://arxiv.org/abs/1811.06621))
+    See [examples/streaming_transducer](./examples/streaming_transducer)

 ## Setup Environment and Datasets
Lines changed: 77 additions & 0 deletions
# Streaming End-to-end Speech Recognition For Mobile Devices

Reference: [https://arxiv.org/abs/1811.06621](https://arxiv.org/abs/1811.06621)

## Example Model YAML Config

```yaml
speech_config:
  sample_rate: 16000
  frame_ms: 25
  stride_ms: 10
  feature_type: log_mel_spectrogram
  num_feature_bins: 80
  preemphasis: 0.97
  normalize_signal: True
  normalize_feature: True
  normalize_per_feature: False

decoder_config:
  vocabulary: null
  target_vocab_size: 1024
  max_subword_length: 4
  blank_at_zero: True
  beam_width: 5
  norm_score: True

model_config:
  name: streaming_transducer
  subsampling:
    type: time_reduction
    factor: 3
  encoder_dim: 320
  encoder_units: 1024
  encoder_layers: 7
  encoder_layer_norm: True
  encoder_type: lstm
  embed_dim: 320
  embed_dropout: 0.1
  num_rnns: 1
  rnn_units: 320
  rnn_type: lstm
  layer_norm: True
  joint_dim: 320

learning_config:
  augmentations:
    after:
      time_masking:
        num_masks: 10
        mask_factor: 100
        p_upperbound: 0.2
      freq_masking:
        num_masks: 1
        mask_factor: 27

  dataset_config:
    train_paths: ...
    eval_paths: ...
    test_paths: ...
    tfrecords_dir: ...

  running_config:
    batch_size: 4
    num_epochs: 22
    outdir: ...
    log_interval_steps: 400
    save_interval_steps: 400
    eval_interval_steps: 1000
```
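A minimal sketch of consuming a config like the one above, following the loading pattern used by this commit's test scripts (`UserConfig`, `TFSpeechFeaturizer`, `CharFeaturizer`, `StreamingTransducer`); the config path here is a placeholder, not a path guaranteed by the repository:

```python
# Sketch only: build the model from a YAML config like the example above.
# Mirrors the pattern in this commit's test scripts; the path is a placeholder.
from tensorflow_asr.configs.user_config import UserConfig
from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer
from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer
from tensorflow_asr.models.streaming_transducer import StreamingTransducer

CONFIG_PATH = "config.yml"  # placeholder path to a config such as the one above

config = UserConfig(CONFIG_PATH, CONFIG_PATH, learning=True)
speech_featurizer = TFSpeechFeaturizer(config["speech_config"])
text_featurizer = CharFeaturizer(config["decoder_config"])

model = StreamingTransducer(
    vocabulary_size=text_featurizer.num_classes,
    **config["model_config"]
)
model._build(speech_featurizer.shape)  # build layers for the featurizer's input shape
model.summary(line_length=150)
```

From here the test scripts further down in this commit hand the built model, together with both featurizers, to `BaseTester`.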
## Usage
Training, see `python examples/streamingTransducer/train_streaming_transducer.py --help`

Testing, see `python examples/streamingTransducer/test_streaming_transducer.py --help`

TFLite Conversion, see `python examples/streamingTransducer/tflite_streaming_transducer.py --help`
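The TFLite conversion script itself is not included in this excerpt, so purely as a hedged illustration, the sketch below drives an already-converted model file with TensorFlow's generic `tf.lite.Interpreter`. The file name `streaming_transducer.tflite` and the single audio-input / single-output signature are assumptions for illustration; the real exported signature (including any state tensors needed for streaming) may differ.

```python
# Hedged sketch: invoke an already-converted TFLite model with the generic
# tf.lite.Interpreter API. The model path and single-input/single-output
# signature are assumptions; the actual export from this repo may differ.
import numpy as np
import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="streaming_transducer.tflite")  # hypothetical file
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# One second of 16 kHz audio as a placeholder input signal.
signal = np.zeros([16000], dtype=np.float32)
interpreter.resize_tensor_input(input_details[0]["index"], signal.shape)
interpreter.allocate_tensors()
interpreter.set_tensor(input_details[0]["index"], signal)
interpreter.invoke()

print(interpreter.get_tensor(output_details[0]["index"]))
```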
Lines changed: 79 additions & 0 deletions
# Copyright 2020 Huy Le Nguyen (@usimarit)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

speech_config:
  sample_rate: 16000
  frame_ms: 25
  stride_ms: 10
  num_feature_bins: 80
  feature_type: log_mel_spectrogram
  preemphasis: 0.97
  normalize_signal: True
  normalize_feature: True
  normalize_per_feature: False

decoder_config:
  vocabulary: null
  target_vocab_size: 1024
  max_subword_length: 4
  blank_at_zero: True
  beam_width: 5
  norm_score: True

model_config:
  name: streaming_transducer
  reduction_factor: 2
  reduction_positions: [1]
  encoder_dim: 320
  encoder_units: 1024
  encoder_layers: 8
  encoder_layer_norm: True
  encoder_type: lstm
  embed_dim: 320
  embed_dropout: 0.1
  num_rnns: 1
  rnn_units: 320
  rnn_type: lstm
  layer_norm: True
  joint_dim: 320

learning_config:
  augmentations:
    after:
      time_masking:
        num_masks: 10
        mask_factor: 100
        p_upperbound: 0.05
      freq_masking:
        num_masks: 1
        mask_factor: 27

  dataset_config:
    train_paths:
      - /mnt/Data/ML/ASR/Raw/LibriSpeech/train-clean-100/transcripts.tsv
    eval_paths:
      - /mnt/Data/ML/ASR/Raw/LibriSpeech/dev-clean/transcripts.tsv
      - /mnt/Data/ML/ASR/Raw/LibriSpeech/dev-other/transcripts.tsv
    test_paths:
      - /mnt/Data/ML/ASR/Raw/LibriSpeech/test-clean/transcripts.tsv
    tfrecords_dir: null

  running_config:
    batch_size: 2
    accumulation_steps: 1
    num_epochs: 20
    outdir: /mnt/Projects/asrk16/trained/local/librispeech/streaming_transducer
    log_interval_steps: 300
    eval_interval_steps: 500
    save_interval_steps: 1000
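In the `model_config` above, `reduction_factor: 2` applied at `reduction_positions: [1]` subsamples the encoder's time axis, in the spirit of the time-reduction layer from the referenced paper (arXiv:1811.06621). The sketch below illustrates that idea by stacking adjacent frames along the feature axis; it is written for intuition only and is not taken from this commit's implementation (the class name `TimeReduction` here is hypothetical).

```python
# Illustrative sketch of a time-reduction step: concatenate `factor` adjacent
# frames along the feature axis so the time dimension shrinks by `factor`.
# Not the implementation from this commit; for intuition only.
import tensorflow as tf


class TimeReduction(tf.keras.layers.Layer):
    def __init__(self, factor: int, **kwargs):
        super().__init__(**kwargs)
        self.factor = factor

    def call(self, inputs):
        # inputs: [batch, time, feature]
        batch = tf.shape(inputs)[0]
        time = tf.shape(inputs)[1]
        feature = inputs.shape[-1]
        pad = (-time) % self.factor  # pad time so it divides evenly
        inputs = tf.pad(inputs, [[0, 0], [0, pad], [0, 0]])
        return tf.reshape(
            inputs, [batch, (time + pad) // self.factor, feature * self.factor])


# reduction_factor: 2 halves the number of encoder time steps:
x = tf.random.normal([1, 10, 320])
print(TimeReduction(2)(x).shape)  # (1, 5, 640)
```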
Lines changed: 100 additions & 0 deletions
# Copyright 2020 Huy Le Nguyen (@usimarit)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import argparse
from tensorflow_asr.utils import setup_environment, setup_devices

setup_environment()
import tensorflow as tf

DEFAULT_YAML = os.path.join(os.path.abspath(os.path.dirname(__file__)), "config.yml")

tf.keras.backend.clear_session()

parser = argparse.ArgumentParser(prog="Streaming Transducer Testing")

parser.add_argument("--config", type=str, default=DEFAULT_YAML,
                    help="The file path of model configuration file")

parser.add_argument("--saved", type=str, default=None,
                    help="Path to saved model")

parser.add_argument("--tfrecords", default=False, action="store_true",
                    help="Whether to use tfrecords as dataset")

parser.add_argument("--mxp", default=False, action="store_true",
                    help="Enable mixed precision")

parser.add_argument("--device", type=int, default=0,
                    help="Device's id to run test on")

parser.add_argument("--cpu", default=False, action="store_true",
                    help="Whether to only use cpu")

parser.add_argument("--output_name", type=str, default="test",
                    help="Result filename name prefix")

args = parser.parse_args()

# enable mixed precision and select the device before importing the rest
tf.config.optimizer.set_experimental_options({"auto_mixed_precision": args.mxp})

setup_devices([args.device], cpu=args.cpu)

from tensorflow_asr.configs.user_config import UserConfig
from tensorflow_asr.datasets.asr_dataset import ASRTFRecordDataset, ASRSliceDataset
from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer
from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer
from tensorflow_asr.runners.base_runners import BaseTester
from tensorflow_asr.models.streaming_transducer import StreamingTransducer

# featurizers are built from the YAML config (character-level text featurizer)
config = UserConfig(DEFAULT_YAML, args.config, learning=True)
speech_featurizer = TFSpeechFeaturizer(config["speech_config"])
text_featurizer = CharFeaturizer(config["decoder_config"])

tf.random.set_seed(0)
assert args.saved

# test dataset: TFRecords if requested, otherwise slices built from transcript files
if args.tfrecords:
    test_dataset = ASRTFRecordDataset(
        data_paths=config["learning_config"]["dataset_config"]["test_paths"],
        tfrecords_dir=config["learning_config"]["dataset_config"]["tfrecords_dir"],
        speech_featurizer=speech_featurizer,
        text_featurizer=text_featurizer,
        stage="test", shuffle=False
    )
else:
    test_dataset = ASRSliceDataset(
        data_paths=config["learning_config"]["dataset_config"]["test_paths"],
        speech_featurizer=speech_featurizer,
        text_featurizer=text_featurizer,
        stage="test", shuffle=False
    )

# build model and restore the saved weights
streaming_transducer = StreamingTransducer(
    vocabulary_size=text_featurizer.num_classes,
    **config["model_config"]
)
streaming_transducer._build(speech_featurizer.shape)
streaming_transducer.load_weights(args.saved, by_name=True)
streaming_transducer.summary(line_length=150)
streaming_transducer.add_featurizers(speech_featurizer, text_featurizer)

# run evaluation over the test dataset
streaming_transducer_tester = BaseTester(
    config=config["learning_config"]["running_config"],
    output_name=args.output_name
)
streaming_transducer_tester.compile(streaming_transducer)
streaming_transducer_tester.run(test_dataset)
Lines changed: 108 additions & 0 deletions
# Copyright 2020 Huy Le Nguyen (@usimarit)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import argparse
from tensorflow_asr.utils import setup_environment, setup_devices

setup_environment()
import tensorflow as tf

DEFAULT_YAML = os.path.join(os.path.abspath(os.path.dirname(__file__)), "config.yml")

tf.keras.backend.clear_session()

parser = argparse.ArgumentParser(prog="Streaming Transducer Testing")

parser.add_argument("--config", type=str, default=DEFAULT_YAML,
                    help="The file path of model configuration file")

parser.add_argument("--saved", type=str, default=None,
                    help="Path to saved model")

parser.add_argument("--tfrecords", default=False, action="store_true",
                    help="Whether to use tfrecords as dataset")

parser.add_argument("--mxp", default=False, action="store_true",
                    help="Enable mixed precision")

parser.add_argument("--device", type=int, default=0,
                    help="Device's id to run test on")

parser.add_argument("--cpu", default=False, action="store_true",
                    help="Whether to only use cpu")

parser.add_argument("--subwords", type=str, default=None,
                    help="Path to file that stores generated subwords")

parser.add_argument("--output_name", type=str, default="test",
                    help="Result filename name prefix")

args = parser.parse_args()

# enable mixed precision and select the device before importing the rest
tf.config.optimizer.set_experimental_options({"auto_mixed_precision": args.mxp})

setup_devices([args.device], cpu=args.cpu)

from tensorflow_asr.configs.user_config import UserConfig
from tensorflow_asr.datasets.asr_dataset import ASRTFRecordDataset, ASRSliceDataset
from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer
from tensorflow_asr.featurizers.text_featurizers import SubwordFeaturizer
from tensorflow_asr.runners.base_runners import BaseTester
from tensorflow_asr.models.streaming_transducer import StreamingTransducer

config = UserConfig(DEFAULT_YAML, args.config, learning=True)
speech_featurizer = TFSpeechFeaturizer(config["speech_config"])

# this variant requires a pre-generated subword vocabulary file
if args.subwords and os.path.exists(args.subwords):
    print("Loading subwords ...")
    text_featurizer = SubwordFeaturizer.load_from_file(config["decoder_config"], args.subwords)
else:
    raise ValueError("subwords must be set")

tf.random.set_seed(0)
assert args.saved

# test dataset: TFRecords if requested, otherwise slices built from transcript files
if args.tfrecords:
    test_dataset = ASRTFRecordDataset(
        data_paths=config["learning_config"]["dataset_config"]["test_paths"],
        tfrecords_dir=config["learning_config"]["dataset_config"]["tfrecords_dir"],
        speech_featurizer=speech_featurizer,
        text_featurizer=text_featurizer,
        stage="test", shuffle=False
    )
else:
    test_dataset = ASRSliceDataset(
        data_paths=config["learning_config"]["dataset_config"]["test_paths"],
        speech_featurizer=speech_featurizer,
        text_featurizer=text_featurizer,
        stage="test", shuffle=False
    )

# build model and restore the saved weights
streaming_transducer = StreamingTransducer(
    vocabulary_size=text_featurizer.num_classes,
    **config["model_config"]
)
streaming_transducer._build(speech_featurizer.shape)
streaming_transducer.load_weights(args.saved, by_name=True)
streaming_transducer.summary(line_length=150)
streaming_transducer.add_featurizers(speech_featurizer, text_featurizer)

# run evaluation over the test dataset
streaming_transducer_tester = BaseTester(
    config=config["learning_config"]["running_config"],
    output_name=args.output_name
)
streaming_transducer_tester.compile(streaming_transducer)
streaming_transducer_tester.run(test_dataset)
