Skip to content

Commit e7d5112

Browse files
committed
⚡ add keras contextnet, streaming transducer and ctc models
1 parent fc338c5 commit e7d5112

File tree

16 files changed

+1270
-5
lines changed

16 files changed

+1270
-5
lines changed

examples/contextnet/config.yml

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ learning_config:
213213
- /mnt/Miscellanea/Datasets/Speech/LibriSpeech/dev-other/transcripts.tsv
214214
test_paths:
215215
- /mnt/Miscellanea/Datasets/Speech/LibriSpeech/test-clean/transcripts.tsv
216-
tfrecords_dir: null
216+
tfrecords_dir: /mnt/Miscellanea/Datasets/Speech/LibriSpeech/tfrecords
217217

218218
optimizer_config:
219219
warmup_steps: 40000
@@ -229,3 +229,16 @@ learning_config:
229229
log_interval_steps: 300
230230
eval_interval_steps: 500
231231
save_interval_steps: 1000
232+
checkpoint:
233+
filepath: /mnt/Miscellanea/Models/local/contextnet/checkpoints/{epoch:02d}.h5
234+
save_best_only: True
235+
save_weights_only: False
236+
save_freq: epoch
237+
states_dir: /mnt/Miscellanea/Models/local/contextnet/states
238+
tensorboard:
239+
log_dir: /mnt/Miscellanea/Models/local/contextnet/tensorboard
240+
histogram_freq: 1
241+
write_graph: True
242+
write_images: True
243+
update_freq: 'epoch'
244+
profile_batch: 2
Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
# Copyright 2020 Huy Le Nguyen (@usimarit)
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
import math
17+
import argparse
18+
from tensorflow_asr.utils import setup_environment, setup_strategy
19+
20+
setup_environment()
21+
import tensorflow as tf
22+
23+
DEFAULT_YAML = os.path.join(os.path.abspath(os.path.dirname(__file__)), "config.yml")
24+
25+
tf.keras.backend.clear_session()
26+
27+
parser = argparse.ArgumentParser(prog="ContextNet Training")
28+
29+
parser.add_argument("--config", type=str, default=DEFAULT_YAML, help="The file path of model configuration file")
30+
31+
parser.add_argument("--max_ckpts", type=int, default=10, help="Max number of checkpoints to keep")
32+
33+
parser.add_argument("--tfrecords", default=False, action="store_true", help="Whether to use tfrecords")
34+
35+
parser.add_argument("--tfrecords_shards", type=int, default=16, help="Number of tfrecords shards")
36+
37+
parser.add_argument("--tbs", type=int, default=None, help="Train batch size per replica")
38+
39+
parser.add_argument("--ebs", type=int, default=None, help="Evaluation batch size per replica")
40+
41+
parser.add_argument("--devices", type=int, nargs="*", default=[0], help="Devices' ids to apply distributed training")
42+
43+
parser.add_argument("--mxp", default=False, action="store_true", help="Enable mixed precision")
44+
45+
parser.add_argument("--cache", default=False, action="store_true", help="Enable caching for dataset")
46+
47+
parser.add_argument("--subwords", type=str, default=None, help="Path to file that stores generated subwords")
48+
49+
parser.add_argument("--subwords_corpus", nargs="*", type=str, default=[], help="Transcript files for generating subwords")
50+
51+
parser.add_argument("--bfs", type=int, default=100, help="Buffer size for shuffling")
52+
53+
args = parser.parse_args()
54+
55+
tf.config.optimizer.set_experimental_options({"auto_mixed_precision": args.mxp})
56+
57+
strategy = setup_strategy(args.devices)
58+
59+
from tensorflow_asr.configs.config import Config
60+
from tensorflow_asr.datasets.keras import ASRTFRecordDatasetKeras, ASRSliceDatasetKeras
61+
from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer
62+
from tensorflow_asr.featurizers.text_featurizers import SubwordFeaturizer
63+
from tensorflow_asr.models.keras.contextnet import ContextNet
64+
from tensorflow_asr.optimizers.schedules import TransformerSchedule
65+
66+
config = Config(args.config)
67+
speech_featurizer = TFSpeechFeaturizer(config.speech_config)
68+
69+
if args.subwords and os.path.exists(args.subwords):
70+
print("Loading subwords ...")
71+
text_featurizer = SubwordFeaturizer.load_from_file(config.decoder_config, args.subwords)
72+
else:
73+
print("Generating subwords ...")
74+
text_featurizer = SubwordFeaturizer.build_from_corpus(
75+
config.decoder_config,
76+
corpus_files=args.subwords_corpus
77+
)
78+
text_featurizer.save_to_file(args.subwords)
79+
80+
if args.tfrecords:
81+
train_dataset = ASRTFRecordDatasetKeras(
82+
data_paths=config.learning_config.dataset_config.train_paths,
83+
tfrecords_dir=config.learning_config.dataset_config.tfrecords_dir,
84+
speech_featurizer=speech_featurizer,
85+
text_featurizer=text_featurizer,
86+
augmentations=config.learning_config.augmentations,
87+
tfrecords_shards=args.tfrecords_shards,
88+
stage="train", cache=args.cache,
89+
shuffle=True, buffer_size=args.bfs,
90+
)
91+
eval_dataset = ASRTFRecordDatasetKeras(
92+
data_paths=config.learning_config.dataset_config.eval_paths,
93+
tfrecords_dir=config.learning_config.dataset_config.tfrecords_dir,
94+
tfrecords_shards=args.tfrecords_shards,
95+
speech_featurizer=speech_featurizer,
96+
text_featurizer=text_featurizer,
97+
stage="eval", cache=args.cache,
98+
shuffle=True, buffer_size=args.bfs,
99+
)
100+
else:
101+
train_dataset = ASRSliceDatasetKeras(
102+
data_paths=config.learning_config.dataset_config.train_paths,
103+
speech_featurizer=speech_featurizer,
104+
text_featurizer=text_featurizer,
105+
augmentations=config.learning_config.augmentations,
106+
stage="train", cache=args.cache,
107+
shuffle=True, buffer_size=args.bfs,
108+
)
109+
eval_dataset = ASRSliceDatasetKeras(
110+
data_paths=config.learning_config.dataset_config.eval_paths,
111+
speech_featurizer=speech_featurizer,
112+
text_featurizer=text_featurizer,
113+
stage="eval", cache=args.cache,
114+
shuffle=True, buffer_size=args.bfs,
115+
)
116+
117+
with strategy.scope():
118+
global_batch_size = config.learning_config.running_config.batch_size
119+
global_batch_size *= strategy.num_replicas_in_sync
120+
# build model
121+
contextnet = ContextNet(**config.model_config, vocabulary_size=text_featurizer.num_classes)
122+
contextnet._build(speech_featurizer.shape)
123+
contextnet.summary(line_length=120)
124+
125+
optimizer = tf.keras.optimizers.Adam(
126+
TransformerSchedule(
127+
d_model=contextnet.dmodel,
128+
warmup_steps=config.learning_config.optimizer_config["warmup_steps"],
129+
max_lr=(0.05 / math.sqrt(contextnet.dmodel))
130+
),
131+
beta_1=config.learning_config.optimizer_config["beta1"],
132+
beta_2=config.learning_config.optimizer_config["beta2"],
133+
epsilon=config.learning_config.optimizer_config["epsilon"]
134+
)
135+
136+
contextnet.compile(optimizer=optimizer, global_batch_size=global_batch_size, blank=text_featurizer.blank)
137+
138+
train_data_loader = train_dataset.create(global_batch_size)
139+
eval_data_loader = eval_dataset.create(global_batch_size)
140+
141+
callbacks = [
142+
tf.keras.callbacks.ModelCheckpoint(**config.learning_config.running_config.checkpoint),
143+
tf.keras.callbacks.experimental.BackupAndRestore(config.learning_config.running_config.states_dir),
144+
tf.keras.callbacks.TensorBoard(**config.learning_config.running_config.tensorboard)
145+
]
146+
147+
contextnet.fit(
148+
train_data_loader, epochs=config.learning_config.running_config.num_epochs,
149+
validation_data=eval_data_loader, callbacks=callbacks,
150+
steps_per_epoch=train_dataset.total_steps
151+
)

examples/deepspeech2/config.yml

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ learning_config:
5959
- /mnt/Miscellanea/Datasets/Speech/LibriSpeech/dev-other/transcripts.tsv
6060
test_paths:
6161
- /mnt/Miscellanea/Datasets/Speech/LibriSpeech/test-clean/transcripts.tsv
62-
tfrecords_dir: null
62+
tfrecords_dir: /mnt/Miscellanea/Datasets/Speech/LibriSpeech/tfrecords
6363

6464
optimizer_config:
6565
class_name: adam
@@ -74,3 +74,16 @@ learning_config:
7474
log_interval_steps: 400
7575
save_interval_steps: 400
7676
eval_interval_steps: 800
77+
checkpoint:
78+
filepath: /mnt/Miscellanea/Models/local/deepspeech2/checkpoints/{epoch:02d}.h5
79+
save_best_only: True
80+
save_weights_only: False
81+
save_freq: epoch
82+
states_dir: /mnt/Miscellanea/Models/local/deepspeech2/states
83+
tensorboard:
84+
log_dir: /mnt/Miscellanea/Models/local/deepspeech2/tensorboard
85+
histogram_freq: 1
86+
write_graph: True
87+
write_images: True
88+
update_freq: 'epoch'
89+
profile_batch: 2
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
# Copyright 2020 Huy Le Nguyen (@usimarit)
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
import argparse
17+
from tensorflow_asr.utils import setup_environment, setup_strategy
18+
19+
setup_environment()
20+
import tensorflow as tf
21+
22+
DEFAULT_YAML = os.path.join(os.path.abspath(os.path.dirname(__file__)), "config.yml")
23+
24+
tf.keras.backend.clear_session()
25+
26+
parser = argparse.ArgumentParser(prog="Deep Speech 2 Training")
27+
28+
parser.add_argument("--config", "-c", type=str, default=DEFAULT_YAML,
29+
help="The file path of model configuration file")
30+
31+
parser.add_argument("--max_ckpts", type=int, default=10,
32+
help="Max number of checkpoints to keep")
33+
34+
parser.add_argument("--tbs", type=int, default=None,
35+
help="Train batch size per replicas")
36+
37+
parser.add_argument("--ebs", type=int, default=None,
38+
help="Evaluation batch size per replicas")
39+
40+
parser.add_argument("--tfrecords", default=False, action="store_true",
41+
help="Whether to use tfrecords dataset")
42+
43+
parser.add_argument("--devices", type=int, nargs="*", default=[0],
44+
help="Devices' ids to apply distributed training")
45+
46+
parser.add_argument("--mxp", default=False, action="store_true",
47+
help="Enable mixed precision")
48+
49+
parser.add_argument("--cache", default=False, action="store_true",
50+
help="Enable caching for dataset")
51+
52+
args = parser.parse_args()
53+
54+
tf.config.optimizer.set_experimental_options({"auto_mixed_precision": args.mxp})
55+
56+
strategy = setup_strategy(args.devices)
57+
58+
from tensorflow_asr.configs.config import Config
59+
from tensorflow_asr.datasets.keras import ASRTFRecordDatasetKeras, ASRSliceDatasetKeras
60+
from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer
61+
from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer
62+
from tensorflow_asr.models.keras.deepspeech2 import DeepSpeech2
63+
64+
config = Config(args.config)
65+
speech_featurizer = TFSpeechFeaturizer(config.speech_config)
66+
text_featurizer = CharFeaturizer(config.decoder_config)
67+
68+
if args.tfrecords:
69+
train_dataset = ASRTFRecordDatasetKeras(
70+
data_paths=config.learning_config.dataset_config.train_paths,
71+
tfrecords_dir=config.learning_config.dataset_config.tfrecords_dir,
72+
speech_featurizer=speech_featurizer,
73+
text_featurizer=text_featurizer,
74+
augmentations=config.learning_config.augmentations,
75+
stage="train", cache=args.cache, shuffle=True
76+
)
77+
eval_dataset = ASRTFRecordDatasetKeras(
78+
data_paths=config.learning_config.dataset_config.eval_paths,
79+
tfrecords_dir=config.learning_config.dataset_config.tfrecords_dir,
80+
speech_featurizer=speech_featurizer,
81+
text_featurizer=text_featurizer,
82+
stage="eval", cache=args.cache, shuffle=True
83+
)
84+
else:
85+
train_dataset = ASRSliceDatasetKeras(
86+
speech_featurizer=speech_featurizer,
87+
text_featurizer=text_featurizer,
88+
data_paths=config.learning_config.dataset_config.train_paths,
89+
augmentations=config.learning_config.augmentations,
90+
stage="train", cache=args.cache, shuffle=True
91+
)
92+
eval_dataset = ASRSliceDatasetKeras(
93+
speech_featurizer=speech_featurizer,
94+
text_featurizer=text_featurizer,
95+
data_paths=config.learning_config.dataset_config.eval_paths,
96+
stage="eval", cache=args.cache, shuffle=True
97+
)
98+
99+
# Build DS2 model
100+
with strategy.scope():
101+
global_batch_size = config.learning_config.running_config.batch_size
102+
global_batch_size *= strategy.num_replicas_in_sync
103+
104+
ds2_model = DeepSpeech2(**config.model_config, vocabulary_size=text_featurizer.num_classes)
105+
ds2_model._build(speech_featurizer.shape)
106+
ds2_model.summary(line_length=120)
107+
108+
ds2_model.compile(optimizer=config.learning_config.optimizer_config,
109+
global_batch_size=global_batch_size, blank=text_featurizer.blank)
110+
111+
train_data_loader = train_dataset.create(global_batch_size)
112+
eval_data_loader = eval_dataset.create(global_batch_size)
113+
114+
callbacks = [
115+
tf.keras.callbacks.ModelCheckpoint(**config.learning_config.running_config.checkpoint),
116+
tf.keras.callbacks.experimental.BackupAndRestore(config.learning_config.running_config.states_dir),
117+
tf.keras.callbacks.TensorBoard(**config.learning_config.running_config.tensorboard)
118+
]
119+
120+
ds2_model.fit(
121+
train_data_loader, epochs=config.learning_config.running_config.num_epochs,
122+
validation_data=eval_data_loader, callbacks=callbacks,
123+
steps_per_epoch=train_dataset.total_steps
124+
)

examples/jasper/config.yml

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,3 +81,16 @@ learning_config:
8181
log_interval_steps: 400
8282
save_interval_steps: 400
8383
eval_interval_steps: 800
84+
checkpoint:
85+
filepath: /mnt/Miscellanea/Models/local/jasper/checkpoints/{epoch:02d}.h5
86+
save_best_only: True
87+
save_weights_only: False
88+
save_freq: epoch
89+
states_dir: /mnt/Miscellanea/Models/local/jasper/states
90+
tensorboard:
91+
log_dir: /mnt/Miscellanea/Models/local/jasper/tensorboard
92+
histogram_freq: 1
93+
write_graph: True
94+
write_images: True
95+
update_freq: 'epoch'
96+
profile_batch: 2

0 commit comments

Comments
 (0)