Skip to content

Commit bda1fdf

Browse files
authored
Merge pull request #146 from TensorSpeech/dev/tpu
Support TPU and static shape training
2 parents d8a130a + 4414d81 commit bda1fdf

26 files changed

+511
-87
lines changed

examples/conformer/config.yml

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,7 @@ decoder_config:
3131
beam_width: 5
3232
norm_score: True
3333
corpus_files:
34-
- /media/nlhuy/Data/ML/ASR/Raw/LibriSpeech/LibriSpeech/train-clean-100/transcripts.tsv
35-
- /media/nlhuy/Data/ML/ASR/Raw/LibriSpeech/LibriSpeech/train-clean-360/transcripts.tsv
36-
- /media/nlhuy/Data/ML/ASR/Raw/LibriSpeech/LibriSpeech/train-other-500/transcripts.tsv
34+
- /mnt/Miscellanea/Datasets/Speech/LibriSpeech/train-clean-100/transcripts.tsv
3735

3836
model_config:
3937
name: conformer
@@ -77,32 +75,35 @@ learning_config:
7775
mask_factor: 27
7876
data_paths:
7977
- /mnt/Miscellanea/Datasets/Speech/LibriSpeech/train-clean-100/transcripts.tsv
80-
tfrecords_dir: /mnt/Miscellanea/Datasets/Speech/LibriSpeech/tfrecords-test
78+
tfrecords_dir: /mnt/Miscellanea/Datasets/Speech/LibriSpeech/tfrecords
8179
shuffle: True
8280
cache: True
8381
buffer_size: 100
8482
drop_remainder: True
83+
stage: train
8584

8685
eval_dataset_config:
8786
use_tf: True
8887
data_paths:
8988
- /mnt/Miscellanea/Datasets/Speech/LibriSpeech/dev-clean/transcripts.tsv
9089
- /mnt/Miscellanea/Datasets/Speech/LibriSpeech/dev-other/transcripts.tsv
91-
tfrecords_dir: /mnt/Miscellanea/Datasets/Speech/LibriSpeech/tfrecords-test
90+
tfrecords_dir: /mnt/Miscellanea/Datasets/Speech/LibriSpeech/tfrecords
9291
shuffle: False
9392
cache: True
9493
buffer_size: 100
9594
drop_remainder: True
95+
stage: eval
9696

9797
test_dataset_config:
9898
use_tf: True
9999
data_paths:
100100
- /mnt/Miscellanea/Datasets/Speech/LibriSpeech/test-clean/transcripts.tsv
101-
tfrecords_dir: /mnt/Miscellanea/Datasets/Speech/LibriSpeech/tfrecords-test
101+
tfrecords_dir: /mnt/Miscellanea/Datasets/Speech/LibriSpeech/tfrecords
102102
shuffle: False
103103
cache: True
104104
buffer_size: 100
105105
drop_remainder: True
106+
stage: test
106107

107108
optimizer_config:
108109
warmup_steps: 40000
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
# Copyright 2020 Huy Le Nguyen (@usimarit)
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
import math
17+
import argparse
18+
from tensorflow_asr.utils import setup_environment, setup_tpu
19+
20+
setup_environment()
21+
import tensorflow as tf
22+
23+
DEFAULT_YAML = os.path.join(os.path.abspath(os.path.dirname(__file__)), "config.yml")
24+
25+
tf.keras.backend.clear_session()
26+
27+
parser = argparse.ArgumentParser(prog="Conformer Training")
28+
29+
parser.add_argument("--config", type=str, default=DEFAULT_YAML, help="The file path of model configuration file")
30+
31+
parser.add_argument("--sentence_piece", default=False, action="store_true", help="Whether to use `SentencePiece` model")
32+
33+
parser.add_argument("--bs", type=int, default=None, help="Batch size per replica")
34+
35+
parser.add_argument("--tpu_address", type=str, default=None, help="TPU address. Leave None on Colab")
36+
37+
parser.add_argument("--max_lengths_prefix", type=str, default=None, help="Path to file containing max lengths")
38+
39+
parser.add_argument("--compute_lengths", default=False, action="store_true", help="Whether to compute lengths")
40+
41+
parser.add_argument("--mxp", default=False, action="store_true", help="Enable mixed precision")
42+
43+
parser.add_argument("--subwords", type=str, default=None, help="Path to file that stores generated subwords")
44+
45+
parser.add_argument("--subwords_corpus", nargs="*", type=str, default=[], help="Transcript files for generating subwords")
46+
47+
args = parser.parse_args()
48+
49+
tf.config.optimizer.set_experimental_options({"auto_mixed_precision": args.mxp})
50+
51+
strategy = setup_tpu(args.tpu_address)
52+
53+
from tensorflow_asr.configs.config import Config
54+
from tensorflow_asr.datasets.keras import ASRTFRecordDatasetKeras
55+
from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer
56+
from tensorflow_asr.featurizers.text_featurizers import SubwordFeaturizer, SentencePieceFeaturizer
57+
from tensorflow_asr.models.keras.conformer import Conformer
58+
from tensorflow_asr.optimizers.schedules import TransformerSchedule
59+
60+
config = Config(args.config)
61+
speech_featurizer = TFSpeechFeaturizer(config.speech_config)
62+
63+
if args.sentence_piece:
64+
print("Loading SentencePiece model ...")
65+
text_featurizer = SentencePieceFeaturizer.load_from_file(config.decoder_config, args.subwords)
66+
elif args.subwords and os.path.exists(args.subwords):
67+
print("Loading subwords ...")
68+
text_featurizer = SubwordFeaturizer.load_from_file(config.decoder_config, args.subwords)
69+
else:
70+
print("Generating subwords ...")
71+
text_featurizer = SubwordFeaturizer.build_from_corpus(
72+
config.decoder_config,
73+
corpus_files=args.subwords_corpus
74+
)
75+
text_featurizer.save_to_file(args.subwords)
76+
77+
train_dataset = ASRTFRecordDatasetKeras(
78+
speech_featurizer=speech_featurizer, text_featurizer=text_featurizer,
79+
**vars(config.learning_config.train_dataset_config)
80+
)
81+
eval_dataset = ASRTFRecordDatasetKeras(
82+
speech_featurizer=speech_featurizer, text_featurizer=text_featurizer,
83+
**vars(config.learning_config.eval_dataset_config)
84+
)
85+
86+
if args.compute_lengths:
87+
train_dataset.update_lengths(args.max_lengths_prefix)
88+
eval_dataset.update_lengths(args.max_lengths_prefix)
89+
90+
# Load the max lengths (calculated from both train and eval datasets) into each dataset
91+
train_dataset.load_max_lengths(args.max_lengths_prefix)
92+
eval_dataset.load_max_lengths(args.max_lengths_prefix)
93+
94+
with strategy.scope():
95+
batch_size = args.bs if args.bs is not None else config.learning_config.running_config.batch_size
96+
global_batch_size = batch_size
97+
global_batch_size *= strategy.num_replicas_in_sync
98+
# build model
99+
conformer = Conformer(**config.model_config, vocabulary_size=text_featurizer.num_classes)
100+
conformer._build(speech_featurizer.shape, prediction_shape=text_featurizer.prepand_shape, batch_size=global_batch_size)
101+
conformer.summary(line_length=120)
102+
103+
optimizer = tf.keras.optimizers.Adam(
104+
TransformerSchedule(
105+
d_model=conformer.dmodel,
106+
warmup_steps=config.learning_config.optimizer_config["warmup_steps"],
107+
max_lr=(0.05 / math.sqrt(conformer.dmodel))
108+
),
109+
beta_1=config.learning_config.optimizer_config["beta1"],
110+
beta_2=config.learning_config.optimizer_config["beta2"],
111+
epsilon=config.learning_config.optimizer_config["epsilon"]
112+
)
113+
114+
conformer.compile(optimizer=optimizer, global_batch_size=global_batch_size, blank=text_featurizer.blank)
115+
116+
train_data_loader = train_dataset.create(global_batch_size)
117+
eval_data_loader = eval_dataset.create(global_batch_size)
118+
119+
callbacks = [
120+
tf.keras.callbacks.ModelCheckpoint(**config.learning_config.running_config.checkpoint),
121+
tf.keras.callbacks.experimental.BackupAndRestore(config.learning_config.running_config.states_dir),
122+
tf.keras.callbacks.TensorBoard(**config.learning_config.running_config.tensorboard)
123+
]
124+
125+
conformer.fit(
126+
train_data_loader, epochs=config.learning_config.running_config.num_epochs,
127+
validation_data=eval_data_loader, callbacks=callbacks,
128+
)
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
# Copyright 2021 M. Yusuf Sarıgöz (@monatis) and Huy Le Nguyen (@usimarit)
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
import math
17+
import argparse
18+
from tensorflow_asr.utils import setup_environment, setup_tpu
19+
20+
setup_environment()
21+
import tensorflow as tf
22+
23+
DEFAULT_YAML = os.path.join(os.path.abspath(os.path.dirname(__file__)), "config.yml")
24+
25+
tf.keras.backend.clear_session()
26+
27+
parser = argparse.ArgumentParser(prog="Conformer Training")
28+
29+
parser.add_argument("--config", type=str, default=DEFAULT_YAML, help="The file path of model configuration file")
30+
31+
parser.add_argument("--max_ckpts", type=int, default=10, help="Max number of checkpoints to keep")
32+
33+
parser.add_argument("--sentence_piece", default=False, action="store_true", help="Whether to use `SentencePiece` model")
34+
35+
parser.add_argument("--bs", type=int, default=None, help="Common training and evaluation batch size per TPU core")
36+
37+
parser.add_argument("--tpu_address", type=str, default=None, help="TPU address. Leave None on Colab")
38+
39+
parser.add_argument("--max_lengths_prefix", type=str, default=None, help="Path to file containing max lengths")
40+
41+
parser.add_argument("--compute_lengths", default=False, action="store_true", help="Whether to compute lengths")
42+
43+
parser.add_argument("--mxp", default=False, action="store_true", help="Enable mixed precision")
44+
45+
parser.add_argument("--subwords", type=str, default=None, help="Path to file that stores generated subwords")
46+
47+
parser.add_argument("--subwords_corpus", nargs="*", type=str, default=[], help="Transcript files for generating subwords")
48+
49+
args = parser.parse_args()
50+
51+
tf.config.optimizer.set_experimental_options({"auto_mixed_precision": args.mxp})
52+
53+
from tensorflow_asr.configs.config import Config
54+
from tensorflow_asr.datasets.asr_dataset import ASRTFRecordDataset
55+
from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer
56+
from tensorflow_asr.featurizers.text_featurizers import SubwordFeaturizer, SentencePieceFeaturizer
57+
from tensorflow_asr.runners.transducer_runners import TransducerTrainer
58+
from tensorflow_asr.models.conformer import Conformer
59+
from tensorflow_asr.optimizers.schedules import TransformerSchedule
60+
61+
config = Config(args.config, learning=True)
62+
speech_featurizer = TFSpeechFeaturizer(config.speech_config)
63+
64+
if args.sentence_piece:
65+
print("Loading SentencePiece model ...")
66+
text_featurizer = SentencePieceFeaturizer.load_from_file(config.decoder_config, args.subwords)
67+
elif args.subwords and os.path.exists(args.subwords):
68+
print("Loading subwords ...")
69+
text_featurizer = SubwordFeaturizer.load_from_file(config.decoder_config, args.subwords)
70+
else:
71+
print("Generating subwords ...")
72+
text_featurizer = SubwordFeaturizer.build_from_corpus(
73+
config.decoder_config,
74+
corpus_files=args.subwords_corpus
75+
)
76+
text_featurizer.save_to_file(args.subwords)
77+
78+
train_dataset = ASRTFRecordDataset(
79+
speech_featurizer=speech_featurizer, text_featurizer=text_featurizer,
80+
**vars(config.learning_config.train_dataset_config)
81+
)
82+
83+
eval_dataset = ASRTFRecordDataset(
84+
speech_featurizer=speech_featurizer, text_featurizer=text_featurizer,
85+
**vars(config.learning_config.eval_dataset_config)
86+
)
87+
88+
if args.compute_lengths:
89+
train_dataset.update_lengths(args.max_lengths_prefix)
90+
eval_dataset.update_lengths(args.max_lengths_prefix)
91+
92+
# Load the max lengths (calculated from both train and eval datasets) into each dataset
93+
train_dataset.load_max_lengths(args.max_lengths_prefix)
94+
eval_dataset.load_max_lengths(args.max_lengths_prefix)
95+
96+
strategy = setup_tpu(args.tpu_address)
97+
98+
conformer_trainer = TransducerTrainer(
99+
config=config.learning_config.running_config,
100+
text_featurizer=text_featurizer, strategy=strategy
101+
)
102+
103+
with conformer_trainer.strategy.scope():
104+
# build model
105+
conformer = Conformer(**config.model_config, vocabulary_size=text_featurizer.num_classes)
106+
conformer._build(speech_featurizer.shape, prediction_shape=text_featurizer.prepand_shape,
107+
batch_size=args.bs if args.bs is not None else config.learning_config.running_config.batch_size)
108+
conformer.summary(line_length=120)
109+
110+
optimizer = tf.keras.optimizers.Adam(
111+
TransformerSchedule(
112+
d_model=conformer.dmodel,
113+
warmup_steps=config.learning_config.optimizer_config["warmup_steps"],
114+
max_lr=(0.05 / math.sqrt(conformer.dmodel))
115+
),
116+
beta_1=config.learning_config.optimizer_config["beta1"],
117+
beta_2=config.learning_config.optimizer_config["beta2"],
118+
epsilon=config.learning_config.optimizer_config["epsilon"]
119+
)
120+
121+
conformer_trainer.compile(model=conformer, optimizer=optimizer, max_to_keep=args.max_ckpts)
122+
123+
conformer_trainer.fit(train_dataset, eval_dataset, train_bs=args.bs, eval_bs=args.bs)

examples/contextnet/config.yml

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -208,32 +208,35 @@ learning_config:
208208
mask_factor: 27
209209
data_paths:
210210
- /mnt/Miscellanea/Datasets/Speech/LibriSpeech/train-clean-100/transcripts.tsv
211-
tfrecords_dir: /mnt/Miscellanea/Datasets/Speech/LibriSpeech/tfrecords-test
211+
tfrecords_dir: /mnt/Miscellanea/Datasets/Speech/LibriSpeech/tfrecords
212212
shuffle: True
213213
cache: True
214214
buffer_size: 100
215215
drop_remainder: True
216+
stage: train
216217

217218
eval_dataset_config:
218219
use_tf: True
219220
data_paths:
220221
- /mnt/Miscellanea/Datasets/Speech/LibriSpeech/dev-clean/transcripts.tsv
221222
- /mnt/Miscellanea/Datasets/Speech/LibriSpeech/dev-other/transcripts.tsv
222-
tfrecords_dir: /mnt/Miscellanea/Datasets/Speech/LibriSpeech/tfrecords-test
223+
tfrecords_dir: /mnt/Miscellanea/Datasets/Speech/LibriSpeech/tfrecords
223224
shuffle: False
224225
cache: True
225226
buffer_size: 100
226227
drop_remainder: True
228+
stage: eval
227229

228230
test_dataset_config:
229231
use_tf: True
230232
data_paths:
231233
- /mnt/Miscellanea/Datasets/Speech/LibriSpeech/test-clean/transcripts.tsv
232-
tfrecords_dir: /mnt/Miscellanea/Datasets/Speech/LibriSpeech/tfrecords-test
234+
tfrecords_dir: /mnt/Miscellanea/Datasets/Speech/LibriSpeech/tfrecords
233235
shuffle: False
234236
cache: True
235237
buffer_size: 100
236238
drop_remainder: True
239+
stage: test
237240

238241
optimizer_config:
239242
warmup_steps: 40000

examples/deepspeech2/config.yml

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,32 +53,35 @@ learning_config:
5353
use_tf: True
5454
data_paths:
5555
- /mnt/Miscellanea/Datasets/Speech/LibriSpeech/train-clean-100/transcripts.tsv
56-
tfrecords_dir: /mnt/Miscellanea/Datasets/Speech/LibriSpeech/tfrecords-test
56+
tfrecords_dir: /mnt/Miscellanea/Datasets/Speech/LibriSpeech/tfrecords
5757
shuffle: True
5858
cache: True
5959
buffer_size: 100
6060
drop_remainder: True
61+
stage: train
6162

6263
eval_dataset_config:
6364
use_tf: True
6465
data_paths:
6566
- /mnt/Miscellanea/Datasets/Speech/LibriSpeech/dev-clean/transcripts.tsv
6667
- /mnt/Miscellanea/Datasets/Speech/LibriSpeech/dev-other/transcripts.tsv
67-
tfrecords_dir: /mnt/Miscellanea/Datasets/Speech/LibriSpeech/tfrecords-test
68+
tfrecords_dir: /mnt/Miscellanea/Datasets/Speech/LibriSpeech/tfrecords
6869
shuffle: False
6970
cache: True
7071
buffer_size: 100
7172
drop_remainder: True
73+
stage: eval
7274

7375
test_dataset_config:
7476
use_tf: True
7577
data_paths:
7678
- /mnt/Miscellanea/Datasets/Speech/LibriSpeech/test-clean/transcripts.tsv
77-
tfrecords_dir: /mnt/Miscellanea/Datasets/Speech/LibriSpeech/tfrecords-test
79+
tfrecords_dir: /mnt/Miscellanea/Datasets/Speech/LibriSpeech/tfrecords
7880
shuffle: False
7981
cache: True
8082
buffer_size: 100
8183
drop_remainder: True
84+
stage: test
8285

8386
optimizer_config:
8487
class_name: adam

0 commit comments

Comments
 (0)