
Commit 112a0c3

Merge pull request #151 from TensorSpeech/dev/gcs
Support for save model to cloud and infinite dataset
2 parents f16fee4 + bddcf81 · commit 112a0c3

18 files changed: +353 −320 lines

README.md

Lines changed: 12 additions & 29 deletions

````diff
@@ -44,7 +44,7 @@ TensorFlowASR implements some automatic speech recognition architectures such as
 - [TFLite Convertion](#tflite-convertion)
 - [Features Extraction](#features-extraction)
 - [Augmentations](#augmentations)
-- [Training & Testing](#training--testing)
+- [Training & Testing Tutorial](#training--testing-tutorial)
 - [Corpus Sources and Pretrained Models](#corpus-sources-and-pretrained-models)
 - [English](#english)
 - [Vietnamese](#vietnamese)
@@ -164,34 +164,17 @@ See [features_extraction](./tensorflow_asr/featurizers/README.md)
 
 See [augmentations](./tensorflow_asr/augmentations/README.md)
 
-## Training & Testing
-
-**Example YAML Config Structure**
-
-```yaml
-speech_config: ...
-model_config: ...
-decoder_config: ...
-learning_config:
-  train_dataset_config:
-    augmentation_config: ...
-    data_paths: ...
-    tfrecords_dir: ...
-  eval_dataset_config:
-    augmentation_config: ...
-    data_paths: ...
-    tfrecords_dir: ...
-  test_dataset_config:
-    augmentation_config: ...
-    data_paths: ...
-    tfrecords_dir: ...
-  optimizer_config: ...
-  running_config:
-    batch_size: 8
-    num_epochs: 20
-    outdir: ...
-    log_interval_steps: 500
-```
+## Training & Testing Tutorial
+
+1. Define a config YAML file; see the `config.yml` files in the [example folder](./examples) for reference (you can copy one and modify values such as parameters and paths to match your local machine).
+2. Download your corpus (a.k.a. datasets) and create a script that generates `transcripts.tsv` files from it (this is the general format used in this project, since every dataset ships in a different layout). For more detail, see [datasets](./tensorflow_asr/datasets/README.md). **Note:** make sure your data contains only characters of your language; for example, English uses `a` to `z` and `'`. **Do not use `cache` if your dataset does not fit in RAM.**
+3. [Optional] Generate TFRecords to use `tf.data.TFRecordDataset` for better performance, using the script [create_tfrecords.py](./scripts/create_tfrecords.py).
+4. Create a vocabulary file (characters or subwords/wordpieces) by defining `language.characters`, or by using the scripts [generate_vocab_subwords.py](./scripts/generate_vocab_subwords.py) or [generate_vocab_sentencepiece.py](./scripts/generate_vocab_sentencepiece.py). Predefined ones are available in [vocabularies](./vocabularies).
+5. [Optional] Generate a metadata file for your dataset using the script [generate_metadata.py](./scripts/generate_metadata.py). This file stores the maximum lengths calculated with your `config.yml` and the total number of elements in each dataset, enabling static-shape training and precalculated steps per epoch.
+6. For training, see the `train_*.py` files in the [example folder](./examples) for the available options.
+7. For testing, see the `test_*.py` files in the [example folder](./examples) for the available options. **Note:** testing is currently not supported on TPUs. It prints nothing but a progress bar to the console, while storing the predicted transcripts in the file `output_name.tsv` inside the `outdir` defined in the config YAML file. After testing finishes, the metrics (WER and CER) are calculated from `output_name.tsv`. **If you reuse the same `output_name`, testing resumes from the last tested batch; if testing already finished, only the metrics are recalculated. To run a fresh test, choose a new `output_name` whose `.tsv` file does not exist yet or contains only the header.**
+
+**Recommendation:** for better performance, use the **Keras built-in training functions** as in the `train_keras_*.py` files and/or TFRecords. Keras built-in training uses an **infinite dataset**, which avoids a potential last partial batch.
 
 
 See [examples](./examples/) for some predefined ASR models and results
````
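Step 2 of the new tutorial asks for a corpus-specific script that emits `transcripts.tsv`. Below is a minimal hypothetical sketch, not a script from this repo: it assumes the tab-separated `PATH`, `DURATION`, `TRANSCRIPT` layout described in the datasets README, and `read_corpus_entries` is a stand-in you would replace with a parser for your dataset's own layout.

```python
import os

import librosa  # assumed here only for measuring audio duration


def read_corpus_entries(corpus_dir):
    """Stand-in generator: yield (audio_path, transcript) pairs.

    Replace with logic for your corpus's own layout; here we assume one
    .txt transcript file sitting next to each .wav file.
    """
    for root, _, files in os.walk(corpus_dir):
        for name in files:
            if name.endswith(".txt"):
                audio_path = os.path.join(root, name[:-4] + ".wav")
                with open(os.path.join(root, name), encoding="utf-8") as f:
                    yield audio_path, f.read().strip().lower()


with open("transcripts.tsv", "w", encoding="utf-8") as out:
    out.write("PATH\tDURATION\tTRANSCRIPT\n")
    for path, transcript in read_corpus_entries("/path/to/corpus"):
        duration = librosa.get_duration(filename=path)  # seconds
        out.write(f"{path}\t{duration}\t{transcript}\n")
```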

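The infinite-dataset recommendation rests on a plain `tf.data` idea: repeat the data forever and bound each epoch with a precomputed step count, so no epoch ever ends on a partial batch. A minimal sketch with stock `tf.data` (not the repo's `ASRDatasetKeras` classes):

```python
import tensorflow as tf

num_examples, batch_size = 1000, 8

dataset = (
    tf.data.Dataset.range(num_examples)
    .shuffle(512)
    .repeat()                                # never exhausts: "indefinite"
    .batch(batch_size, drop_remainder=True)  # every batch is full
)

# Epochs are bounded by a precomputed step count (the role the metadata file
# plays for dataset.total_steps), not by running out of data.
steps_per_epoch = num_examples // batch_size
# model.fit(dataset, epochs=20, steps_per_epoch=steps_per_epoch)
```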
examples/conformer/train_keras_subword_conformer.py

Lines changed: 39 additions & 20 deletions

```diff
@@ -38,6 +38,10 @@
 
 parser.add_argument("--ebs", type=int, default=None, help="Evaluation batch size per replica")
 
+parser.add_argument("--spx", type=int, default=1, help="Steps per execution for maximizing performance")
+
+parser.add_argument("--metadata_prefix", type=str, default=None, help="Path to file containing metadata")
+
 parser.add_argument("--devices", type=int, nargs="*", default=[0], help="Devices' ids to apply distributed training")
 
 parser.add_argument("--mxp", default=False, action="store_true", help="Enable mixed precision")
@@ -79,25 +83,38 @@
 if args.tfrecords:
     train_dataset = ASRTFRecordDatasetKeras(
         speech_featurizer=speech_featurizer, text_featurizer=text_featurizer,
-        **vars(config.learning_config.train_dataset_config)
+        **vars(config.learning_config.train_dataset_config),
+        indefinite=True
     )
     eval_dataset = ASRTFRecordDatasetKeras(
         speech_featurizer=speech_featurizer, text_featurizer=text_featurizer,
         **vars(config.learning_config.eval_dataset_config)
     )
+    # Update metadata calculated from both train and eval datasets
+    train_dataset.load_metadata(args.metadata_prefix)
+    eval_dataset.load_metadata(args.metadata_prefix)
+    # Use dynamic length
+    speech_featurizer.reset_length()
+    text_featurizer.reset_length()
 else:
     train_dataset = ASRSliceDatasetKeras(
         speech_featurizer=speech_featurizer, text_featurizer=text_featurizer,
-        **vars(config.learning_config.train_dataset_config)
+        **vars(config.learning_config.train_dataset_config),
+        indefinite=True
     )
     eval_dataset = ASRSliceDatasetKeras(
         speech_featurizer=speech_featurizer, text_featurizer=text_featurizer,
-        **vars(config.learning_config.train_dataset_config)
+        **vars(config.learning_config.train_dataset_config),
+        indefinite=True
     )
 
+global_batch_size = config.learning_config.running_config.batch_size
+global_batch_size *= strategy.num_replicas_in_sync
+
+train_data_loader = train_dataset.create(global_batch_size)
+eval_data_loader = eval_dataset.create(global_batch_size)
+
 with strategy.scope():
-    global_batch_size = config.learning_config.running_config.batch_size
-    global_batch_size *= strategy.num_replicas_in_sync
     # build model
     conformer = Conformer(**config.model_config, vocabulary_size=text_featurizer.num_classes)
     conformer._build(speech_featurizer.shape)
@@ -114,19 +131,21 @@
         epsilon=config.learning_config.optimizer_config["epsilon"]
     )
 
-    conformer.compile(optimizer=optimizer, global_batch_size=global_batch_size, blank=text_featurizer.blank)
-
-    train_data_loader = train_dataset.create(global_batch_size)
-    eval_data_loader = eval_dataset.create(global_batch_size)
-
-    callbacks = [
-        tf.keras.callbacks.ModelCheckpoint(**config.learning_config.running_config.checkpoint),
-        tf.keras.callbacks.experimental.BackupAndRestore(config.learning_config.running_config.states_dir),
-        tf.keras.callbacks.TensorBoard(**config.learning_config.running_config.tensorboard)
-    ]
-
-    conformer.fit(
-        train_data_loader, epochs=config.learning_config.running_config.num_epochs,
-        validation_data=eval_data_loader, callbacks=callbacks,
-        steps_per_epoch=train_dataset.total_steps
+    conformer.compile(
+        optimizer=optimizer,
+        experimental_steps_per_execution=args.spx,
+        global_batch_size=global_batch_size,
+        blank=text_featurizer.blank
     )
+
+callbacks = [
+    tf.keras.callbacks.ModelCheckpoint(**config.learning_config.running_config.checkpoint),
+    tf.keras.callbacks.experimental.BackupAndRestore(config.learning_config.running_config.states_dir),
+    tf.keras.callbacks.TensorBoard(**config.learning_config.running_config.tensorboard)
+]
+
+conformer.fit(
+    train_data_loader, epochs=config.learning_config.running_config.num_epochs,
+    validation_data=eval_data_loader, callbacks=callbacks,
+    steps_per_epoch=train_dataset.total_steps, validation_steps=eval_dataset.total_steps
+)
```
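Besides switching the datasets to `indefinite=True`, this diff moves the batch-size arithmetic and the data loaders out of `strategy.scope()`: the input pipeline consumes the global batch size (per-replica size times `strategy.num_replicas_in_sync`), while only model building and compiling need the scope. A self-contained sketch of that pattern with a toy model (none of these names come from the repo):

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

# Computed before entering the scope: the dataset, not the model,
# consumes the global batch size.
per_replica_bs = 8
global_batch_size = per_replica_bs * strategy.num_replicas_in_sync

num_examples = 1024
x = tf.random.normal([num_examples, 4])
y = tf.random.normal([num_examples, 1])
dataset = (
    tf.data.Dataset.from_tensor_slices((x, y))
    .repeat()  # indefinite, mirroring indefinite=True above
    .batch(global_batch_size, drop_remainder=True)
)

with strategy.scope():  # only variable creation/compilation needs the scope
    model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
    model.compile(optimizer="adam", loss="mse")

model.fit(dataset, epochs=2, steps_per_epoch=num_examples // global_batch_size)
```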

examples/conformer/train_tpu_keras_subword_conformer.py

Lines changed: 22 additions & 19 deletions

```diff
@@ -32,7 +32,7 @@
 
 parser.add_argument("--bs", type=int, default=None, help="Batch size per replica")
 
-parser.add_argument("--spx", type=int, default=1, help="Steps per execution for maximizing TPU performance")
+parser.add_argument("--spx", type=int, default=50, help="Steps per execution for maximizing TPU performance")
 
 parser.add_argument("--tpu_address", type=str, default=None, help="TPU address. Leave None on Colab")
 
@@ -78,11 +78,13 @@
 
 train_dataset = ASRTFRecordDatasetKeras(
     speech_featurizer=speech_featurizer, text_featurizer=text_featurizer,
-    **vars(config.learning_config.train_dataset_config)
+    **vars(config.learning_config.train_dataset_config),
+    indefinite=True
 )
 eval_dataset = ASRTFRecordDatasetKeras(
     speech_featurizer=speech_featurizer, text_featurizer=text_featurizer,
-    **vars(config.learning_config.eval_dataset_config)
+    **vars(config.learning_config.eval_dataset_config),
+    indefinite=True
 )
 
 if args.compute_lengths:
@@ -93,10 +95,14 @@
 train_dataset.load_metadata(args.metadata_prefix)
 eval_dataset.load_metadata(args.metadata_prefix)
 
+batch_size = args.bs if args.bs is not None else config.learning_config.running_config.batch_size
+global_batch_size = batch_size
+global_batch_size *= strategy.num_replicas_in_sync
+
+train_data_loader = train_dataset.create(global_batch_size)
+eval_data_loader = eval_dataset.create(global_batch_size)
+
 with strategy.scope():
-    batch_size = args.bs if args.bs is not None else config.learning_config.running_config.batch_size
-    global_batch_size = batch_size
-    global_batch_size *= strategy.num_replicas_in_sync
     # build model
     conformer = Conformer(**config.model_config, vocabulary_size=text_featurizer.num_classes)
     conformer._build(speech_featurizer.shape, prediction_shape=text_featurizer.prepand_shape, batch_size=global_batch_size)
@@ -120,17 +126,14 @@
         blank=text_featurizer.blank
     )
 
-    train_data_loader = train_dataset.create(global_batch_size)
-    eval_data_loader = eval_dataset.create(global_batch_size)
-
-    callbacks = [
-        tf.keras.callbacks.ModelCheckpoint(**config.learning_config.running_config.checkpoint),
-        tf.keras.callbacks.experimental.BackupAndRestore(config.learning_config.running_config.states_dir),
-        tf.keras.callbacks.TensorBoard(**config.learning_config.running_config.tensorboard)
-    ]
+callbacks = [
+    tf.keras.callbacks.ModelCheckpoint(**config.learning_config.running_config.checkpoint),
+    tf.keras.callbacks.experimental.BackupAndRestore(config.learning_config.running_config.states_dir),
+    tf.keras.callbacks.TensorBoard(**config.learning_config.running_config.tensorboard)
+]
 
-    conformer.fit(
-        train_data_loader, epochs=config.learning_config.running_config.num_epochs,
-        validation_data=eval_data_loader, callbacks=callbacks,
-        steps_per_epoch=train_dataset.total_steps, validation_steps=eval_dataset.total_steps
-    )
+conformer.fit(
+    train_data_loader, epochs=config.learning_config.running_config.num_epochs,
+    validation_data=eval_data_loader, callbacks=callbacks,
+    steps_per_epoch=train_dataset.total_steps, validation_steps=eval_dataset.total_steps
+)
```
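The `--spx` default here jumps from 1 to 50 because `experimental_steps_per_execution` folds that many training steps into a single `tf.function` call, cutting per-step host-to-TPU dispatch overhead. A minimal illustration with stock Keras (TF 2.3 named the compile argument `experimental_steps_per_execution`; TF 2.4+ renamed it to `steps_per_execution`):

```python
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(
    optimizer="adam",
    loss="mse",
    steps_per_execution=50,  # TF 2.4+ spelling; mirrors the new --spx default
)

x = tf.random.normal([1024, 4])
y = tf.random.normal([1024, 1])
# Each device dispatch now runs 50 optimizer steps back to back, so callbacks
# and progress-bar updates fire once per 50 steps rather than every step.
model.fit(x, y, batch_size=8, epochs=1)
```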

examples/conformer/train_tpu_subword_conformer.py

Lines changed: 0 additions & 123 deletions
This file was deleted.
