Commit d86d621

✍️ update conformer training script
1 parent ccfd924 commit d86d621

File tree

17 files changed, +161 -347 lines changed

examples/conformer/config.yml

Lines changed: 13 additions & 19 deletions
@@ -24,14 +24,14 @@ speech_config:
   normalize_per_feature: False

 decoder_config:
-  vocabulary: null
+  vocabulary: ./vocabularies/librispeech/librispeech_train_10_1008.subwords
   target_vocab_size: 1000
   max_subword_length: 10
   blank_at_zero: True
-  beam_width: 5
+  beam_width: 0
   norm_score: True
   corpus_files:
-    - /media/nlhuy/Data/ML/Datasets/ASR/Raw/LibriSpeech/train-clean-100/transcripts.tsv
+    - /mnt/h/ML/Datasets/ASR/Raw/LibriSpeech/train-clean-100/transcripts.tsv

 model_config:
   name: conformer

@@ -40,7 +40,7 @@ model_config:
   filters: 144
   kernel_size: 3
   strides: 2
-  encoder_positional_encoding: sinusoid_concat_v2
+  encoder_positional_encoding: sinusoid_concat
   encoder_dmodel: 144
   encoder_num_blocks: 16
   encoder_head_size: 36

@@ -75,19 +75,18 @@ learning_config:
       num_masks: 1
       mask_factor: 27
     data_paths:
-      - /mnt/Data/ML/Datasets/ASR/Raw/LibriSpeech/train-clean-100/transcripts.tsv
-    tfrecords_dir: /mnt/Miscellanea/Datasets/Speech/LibriSpeech/tfrecords
+      - /mnt/h/ML/Datasets/ASR/Raw/LibriSpeech/train-clean-100/transcripts.tsv
+    tfrecords_dir: null
     shuffle: True
     cache: True
-    cache_percent: 0.2
     buffer_size: 100
     drop_remainder: True
     stage: train

   eval_dataset_config:
     use_tf: True
     data_paths: null
-    tfrecords_dir: /mnt/Miscellanea/Datasets/Speech/LibriSpeech/tfrecords
+    tfrecords_dir: null
     shuffle: False
     cache: True
     buffer_size: 100

@@ -97,7 +96,7 @@ learning_config:
   test_dataset_config:
     use_tf: True
     data_paths: null
-    tfrecords_dir: /mnt/Miscellanea/Datasets/Speech/LibriSpeech/tfrecords
+    tfrecords_dir: null
     shuffle: False
     cache: True
     buffer_size: 100

@@ -106,26 +105,21 @@ learning_config:

   optimizer_config:
     warmup_steps: 40000
-    beta1: 0.9
-    beta2: 0.98
+    beta_1: 0.9
+    beta_2: 0.98
     epsilon: 1e-9

   running_config:
     batch_size: 2
-    accumulation_steps: 4
     num_epochs: 50
-    outdir: /mnt/Miscellanea/Models/local/conformer
-    log_interval_steps: 300
-    eval_interval_steps: 500
-    save_interval_steps: 1000
     checkpoint:
-      filepath: /mnt/Miscellanea/Models/local/conformer/checkpoints/{epoch:02d}.h5
+      filepath: /mnt/e/Models/local/conformer/checkpoints/{epoch:02d}.h5
       save_best_only: True
       save_weights_only: False
       save_freq: epoch
-    states_dir: /mnt/Miscellanea/Models/local/conformer/states
+    states_dir: /mnt/e/Models/local/conformer/states
     tensorboard:
-      log_dir: /mnt/Miscellanea/Models/local/conformer/tensorboard
+      log_dir: /mnt/e/Models/local/conformer/tensorboard
       histogram_freq: 1
       write_graph: True
       write_images: True
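
Note: the `beta1`/`beta2` to `beta_1`/`beta_2` rename lines the config keys up with the exact keyword names that `tf.keras.optimizers.Adam` accepts, which is what lets the updated train.py below splat the remaining `optimizer_config` entries straight into the optimizer. A minimal sketch of that pattern; the literal dict simply mirrors the values above, and the plain `learning_rate` is a stand-in for the repo's `TransformerSchedule`:

import tensorflow as tf

# Same keys as the updated optimizer_config above.
optimizer_config = {"warmup_steps": 40000, "beta_1": 0.9, "beta_2": 0.98, "epsilon": 1e-9}

# Pop the one key Adam does not accept, then splat the rest. This only works
# because the config now uses Adam's exact keyword names (beta_1, beta_2).
warmup_steps = optimizer_config.pop("warmup_steps", 10000)
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3, **optimizer_config)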

examples/conformer/tflite.py

Lines changed: 12 additions & 18 deletions
@@ -14,14 +14,14 @@

 import os
 import argparse
-from tensorflow_asr.utils import setup_environment
+from tensorflow_asr.utils import env_util, file_util

-setup_environment()
+env_util.setup_environment()
 import tensorflow as tf

 from tensorflow_asr.configs.config import Config
 from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer
-from tensorflow_asr.featurizers.text_featurizers import SubwordFeaturizer
+from tensorflow_asr.featurizers.text_featurizers import SubwordFeaturizer, CharFeaturizer
 from tensorflow_asr.models.conformer import Conformer

 DEFAULT_YAML = os.path.join(os.path.abspath(os.path.dirname(__file__)), "config.yml")

@@ -30,17 +30,13 @@

 parser = argparse.ArgumentParser(prog="Conformer Testing")

-parser.add_argument("--config", type=str, default=DEFAULT_YAML,
-                    help="The file path of model configuration file")
+parser.add_argument("--config", type=str, default=DEFAULT_YAML, help="The file path of model configuration file")

-parser.add_argument("--saved", type=str, default=None,
-                    help="Path to saved model")
+parser.add_argument("--saved", type=str, default=None, help="Path to saved model")

-parser.add_argument("--subwords", type=str, default=None,
-                    help="Path to file that stores generated subwords")
+parser.add_argument("--subwords", type=str, default=None, help="Use subwords")

-parser.add_argument("output", type=str, default=None,
-                    help="TFLite file path to be exported")
+parser.add_argument("output", type=str, default=None, help="TFLite file path to be exported")

 args = parser.parse_args()

@@ -49,17 +45,16 @@

 config = Config(args.config)
 speech_featurizer = TFSpeechFeaturizer(config.speech_config)

-if args.subwords and os.path.exists(args.subwords):
-    print("Loading subwords ...")
-    text_featurizer = SubwordFeaturizer.load_from_file(config.decoder_config, args.subwords)
+if args.subwords:
+    text_featurizer = SubwordFeaturizer(config.decoder_config)
 else:
-    raise ValueError("subwords must be set")
+    text_featurizer = CharFeaturizer(config.decoder_config)

 # build model
 conformer = Conformer(**config.model_config, vocabulary_size=text_featurizer.num_classes)
 conformer._build(speech_featurizer.shape)
 conformer.load_weights(args.saved)
-conformer.summary(line_length=150)
+conformer.summary(line_length=100)
 conformer.add_featurizers(speech_featurizer, text_featurizer)

 concrete_func = conformer.make_tflite_function().get_concrete_function()

@@ -69,7 +64,6 @@

 converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
 tflite_model = converter.convert()

-if not os.path.exists(os.path.dirname(args.output)):
-    os.makedirs(os.path.dirname(args.output))
+args.output = file_util.preprocess_paths(args.output)
 with open(args.output, "wb") as tflite_out:
     tflite_out.write(tflite_model)
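
For a quick sanity check of the exported file, the standard TFLite interpreter can load it and report its I/O signatures. A minimal sketch, assuming the script above was run with an output path of `conformer.tflite`; since the converter enables `SELECT_TF_OPS`, this needs the full TensorFlow package rather than the standalone tflite runtime:

import tensorflow as tf

# Load the exported model (the path is an assumption, see note above)
# and print its input/output tensor details.
interpreter = tf.lite.Interpreter(model_path="conformer.tflite")
interpreter.allocate_tensors()
print(interpreter.get_input_details())
print(interpreter.get_output_details())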

examples/conformer/train.py

Lines changed: 47 additions & 41 deletions
@@ -15,9 +15,9 @@

 import os
 import math
 import argparse
-from tensorflow_asr.utils import setup_environment, setup_strategy
+from tensorflow_asr.utils import env_util

-setup_environment()
+env_util.setup_environment()
 import tensorflow as tf

 DEFAULT_YAML = os.path.join(os.path.abspath(os.path.dirname(__file__)), "config.yml")

@@ -28,81 +28,86 @@

 parser.add_argument("--config", type=str, default=DEFAULT_YAML, help="The file path of model configuration file")

-parser.add_argument("--max_ckpts", type=int, default=10, help="Max number of checkpoints to keep")
-
 parser.add_argument("--tfrecords", default=False, action="store_true", help="Whether to use tfrecords")

 parser.add_argument("--sentence_piece", default=False, action="store_true", help="Whether to use `SentencePiece` model")

+parser.add_argument("--subwords", default=False, action="store_true", help="Use subwords")
+
 parser.add_argument("--tbs", type=int, default=None, help="Train batch size per replica")

 parser.add_argument("--ebs", type=int, default=None, help="Evaluation batch size per replica")

 parser.add_argument("--spx", type=int, default=1, help="Steps per execution for maximizing performance")

-parser.add_argument("--metadata_prefix", type=str, default=None, help="Path to file containing metadata")
+parser.add_argument("--metadata", type=str, default=None, help="Path to file containing metadata")
+
+parser.add_argument("--static_length", default=False, action="store_true", help="Use static lengths")

 parser.add_argument("--devices", type=int, nargs="*", default=[0], help="Devices' ids to apply distributed training")

 parser.add_argument("--mxp", default=False, action="store_true", help="Enable mixed precision")

-parser.add_argument("--subwords", default=False, action="store_true", help="Use subwords")
-
 args = parser.parse_args()

 tf.config.optimizer.set_experimental_options({"auto_mixed_precision": args.mxp})

-strategy = setup_strategy(args.devices)
+strategy = env_util.setup_strategy(args.devices)

 from tensorflow_asr.configs.config import Config
-from tensorflow_asr.datasets.keras import ASRTFRecordDatasetKeras, ASRSliceDatasetKeras
-from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer
-from tensorflow_asr.featurizers.text_featurizers import SubwordFeaturizer, SentencePieceFeaturizer, CharFeaturizer
-from tensorflow_asr.models.keras.conformer import Conformer
+from tensorflow_asr.datasets import asr_dataset
+from tensorflow_asr.featurizers import speech_featurizers, text_featurizers
+from tensorflow_asr.models.transducer.conformer import Conformer
 from tensorflow_asr.optimizers.schedules import TransformerSchedule

 config = Config(args.config)
-speech_featurizer = TFSpeechFeaturizer(config.speech_config)
+speech_featurizer = speech_featurizers.TFSpeechFeaturizer(config.speech_config)

 if args.sentence_piece:
     print("Loading SentencePiece model ...")
-    text_featurizer = SentencePieceFeaturizer(config.decoder_config)
+    text_featurizer = text_featurizers.SentencePieceFeaturizer(config.decoder_config)
 elif args.subwords:
     print("Loading subwords ...")
-    text_featurizer = SubwordFeaturizer(config.decoder_config)
+    text_featurizer = text_featurizers.SubwordFeaturizer(config.decoder_config)
 else:
     print("Use characters ...")
-    text_featurizer = CharFeaturizer(config.decoder_config)
+    text_featurizer = text_featurizers.CharFeaturizer(config.decoder_config)

 if args.tfrecords:
-    train_dataset = ASRTFRecordDatasetKeras(
-        speech_featurizer=speech_featurizer, text_featurizer=text_featurizer,
+    train_dataset = asr_dataset.ASRTFRecordDataset(
+        speech_featurizer=speech_featurizer,
+        text_featurizer=text_featurizer,
         **vars(config.learning_config.train_dataset_config),
         indefinite=True
     )
-    eval_dataset = ASRTFRecordDatasetKeras(
-        speech_featurizer=speech_featurizer, text_featurizer=text_featurizer,
-        **vars(config.learning_config.eval_dataset_config)
+    eval_dataset = asr_dataset.ASRTFRecordDataset(
+        speech_featurizer=speech_featurizer,
+        text_featurizer=text_featurizer,
+        **vars(config.learning_config.eval_dataset_config),
+        indefinite=True
     )
-    # Update metadata calculated from both train and eval datasets
-    train_dataset.load_metadata(args.metadata_prefix)
-    eval_dataset.load_metadata(args.metadata_prefix)
-    # Use dynamic length
-    speech_featurizer.reset_length()
-    text_featurizer.reset_length()
 else:
-    train_dataset = ASRSliceDatasetKeras(
-        speech_featurizer=speech_featurizer, text_featurizer=text_featurizer,
+    train_dataset = asr_dataset.ASRSliceDataset(
+        speech_featurizer=speech_featurizer,
+        text_featurizer=text_featurizer,
         **vars(config.learning_config.train_dataset_config),
         indefinite=True
     )
-    eval_dataset = ASRSliceDatasetKeras(
-        speech_featurizer=speech_featurizer, text_featurizer=text_featurizer,
-        **vars(config.learning_config.train_dataset_config),
+    eval_dataset = asr_dataset.ASRSliceDataset(
+        speech_featurizer=speech_featurizer,
+        text_featurizer=text_featurizer,
+        **vars(config.learning_config.eval_dataset_config),
         indefinite=True
     )

-global_batch_size = config.learning_config.running_config.batch_size
+train_dataset.load_metadata(args.metadata)
+eval_dataset.load_metadata(args.metadata)
+
+if not args.static_length:
+    speech_featurizer.reset_length()
+    text_featurizer.reset_length()
+
+global_batch_size = args.tbs or config.learning_config.running_config.batch_size
 global_batch_size *= strategy.num_replicas_in_sync

 train_data_loader = train_dataset.create(global_batch_size)

@@ -112,17 +117,15 @@

 # build model
 conformer = Conformer(**config.model_config, vocabulary_size=text_featurizer.num_classes)
 conformer._build(speech_featurizer.shape)
-conformer.summary(line_length=120)
+conformer.summary(line_length=100)

 optimizer = tf.keras.optimizers.Adam(
     TransformerSchedule(
         d_model=conformer.dmodel,
-        warmup_steps=config.learning_config.optimizer_config["warmup_steps"],
+        warmup_steps=config.learning_config.optimizer_config.pop("warmup_steps", 10000),
         max_lr=(0.05 / math.sqrt(conformer.dmodel))
     ),
-    beta_1=config.learning_config.optimizer_config["beta1"],
-    beta_2=config.learning_config.optimizer_config["beta2"],
-    epsilon=config.learning_config.optimizer_config["epsilon"]
+    **config.learning_config.optimizer_config
 )

 conformer.compile(

@@ -139,7 +142,10 @@

 ]

 conformer.fit(
-    train_data_loader, epochs=config.learning_config.running_config.num_epochs,
-    validation_data=eval_data_loader, callbacks=callbacks,
-    steps_per_epoch=train_dataset.total_steps, validation_steps=eval_dataset.total_steps
+    train_data_loader,
+    epochs=config.learning_config.running_config.num_epochs,
+    validation_data=eval_data_loader,
+    callbacks=callbacks,
+    steps_per_epoch=train_dataset.total_steps,
+    validation_steps=eval_dataset.total_steps
 )
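
Note: both branches now build the eval dataset with `indefinite=True` (the old else-branch also accidentally reused `train_dataset_config` for the eval dataset, fixed here). An indefinite dataset repeats forever, so `fit` must be told where epochs end, which is why `steps_per_epoch` and `validation_steps` are passed from `total_steps`. A toy stand-alone sketch of that Keras contract; the model and data below are placeholders, not the repo's:

import tensorflow as tf

# A repeating dataset never signals "end of epoch" on its own ...
features, labels = tf.random.normal([64, 4]), tf.zeros([64, 1])
train_ds = tf.data.Dataset.from_tensor_slices((features, labels)).shuffle(64).batch(8).repeat()

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")

# ... so fit() requires an explicit steps_per_epoch, mirroring
# steps_per_epoch=train_dataset.total_steps in the script above.
model.fit(train_ds, epochs=2, steps_per_epoch=64 // 8)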
