Skip to content

Commit 1785335

Browse files
authored
Merge pull request #43 from TensorSpeech/dev/gradacc
Supported Gradients Accumulation
2 parents 9a8e16b + 263855e commit 1785335

File tree

18 files changed

+231
-94
lines changed

18 files changed

+231
-94
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ TensorFlowASR implements some automatic speech recognition architectures such as
1919

2020
## What's New?
2121

22+
- (11/14/2020) Supported Gradient Accumulation for Training with Larger Batch Sizes
2223
- (11/3/2020) Reduce differences between `librosa.stft` and `tf.signal.stft`
2324
- (10/31/2020) Update DeepSpeech2 and Supported Jasper [https://arxiv.org/abs/1904.03288](https://arxiv.org/abs/1904.03288)
2425
- (10/18/2020) Supported Streaming Transducer [https://arxiv.org/abs/1811.06621](https://arxiv.org/abs/1811.06621)

examples/conformer/config.yml

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -68,12 +68,12 @@ learning_config:
6868

6969
dataset_config:
7070
train_paths:
71-
- /mnt/Data/ML/ASR/Raw/LibriSpeech/train-clean-100/transcripts.tsv
71+
- /mnt/d/SpeechProcessing/Datasets/LibriSpeech/train-clean-100/transcripts.tsv
7272
eval_paths:
73-
- /mnt/Data/ML/ASR/Raw/LibriSpeech/dev-clean/transcripts.tsv
74-
- /mnt/Data/ML/ASR/Raw/LibriSpeech/dev-other/transcripts.tsv
73+
- /mnt/d/SpeechProcessing/Datasets/LibriSpeech/dev-clean/transcripts.tsv
74+
- /mnt/d/SpeechProcessing/Datasets/LibriSpeech/dev-other/transcripts.tsv
7575
test_paths:
76-
- /mnt/Data/ML/ASR/Raw/LibriSpeech/test-clean/transcripts.tsv
76+
- /mnt/d/SpeechProcessing/Datasets/LibriSpeech/test-clean/transcripts.tsv
7777
tfrecords_dir: null
7878

7979
optimizer_config:
@@ -83,10 +83,10 @@ learning_config:
8383
epsilon: 1e-9
8484

8585
running_config:
86-
batch_size: 2
87-
accumulation_steps: 1
86+
batch_size: 4
87+
accumulation_steps: 4
8888
num_epochs: 20
89-
outdir: /mnt/Projects/asrk16/trained/local/librispeech/conformer
89+
outdir: /mnt/d/SpeechProcessing/Trained/local/conformer
9090
log_interval_steps: 300
9191
eval_interval_steps: 500
9292
save_interval_steps: 1000

examples/conformer/tflite_subword_conformer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@
5858
# build model
5959
conformer = Conformer(**config.model_config, vocabulary_size=text_featurizer.num_classes)
6060
conformer._build(speech_featurizer.shape)
61-
conformer.load_weights(args.saved)
61+
conformer.load_weights(args.saved, by_name=True)
6262
conformer.summary(line_length=150)
6363
conformer.add_featurizers(speech_featurizer, text_featurizer)
6464

examples/conformer/train_ga_conformer.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@
4141
parser.add_argument("--ebs", type=int, default=None,
4242
help="Evaluation batch size per replica")
4343

44+
parser.add_argument("--acs", type=int, default=None,
45+
help="Train accumulation steps")
46+
4447
parser.add_argument("--devices", type=int, nargs="*", default=[0],
4548
help="Devices' ids to apply distributed training")
4649

@@ -125,4 +128,5 @@
125128
conformer_trainer.compile(model=conformer, optimizer=optimizer,
126129
max_to_keep=args.max_ckpts)
127130

128-
conformer_trainer.fit(train_dataset, eval_dataset, train_bs=args.tbs, eval_bs=args.ebs)
131+
conformer_trainer.fit(train_dataset, eval_dataset,
132+
train_bs=args.tbs, eval_bs=args.ebs, train_acs=args.acs)

examples/conformer/train_ga_subword_conformer.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@
4141
parser.add_argument("--ebs", type=int, default=None,
4242
help="Evaluation batch size per replica")
4343

44+
parser.add_argument("--acs", type=int, default=None,
45+
help="Train accumulation steps")
46+
4447
parser.add_argument("--devices", type=int, nargs="*", default=[0],
4548
help="Devices' ids to apply distributed training")
4649

@@ -141,4 +144,5 @@
141144
conformer_trainer.compile(model=conformer, optimizer=optimizer,
142145
max_to_keep=args.max_ckpts)
143146

144-
conformer_trainer.fit(train_dataset, eval_dataset, train_bs=args.tbs, eval_bs=args.ebs)
147+
conformer_trainer.fit(train_dataset, eval_dataset,
148+
train_bs=args.tbs, eval_bs=args.ebs, train_acs=args.acs)

examples/deepspeech2/config.yml

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,11 @@ speech_config:
2424
normalize_per_feature: False
2525

2626
decoder_config:
27-
vocabulary: ./vocabularies/vietnamese.characters
27+
vocabulary: null
2828
blank_at_zero: False
2929
beam_width: 500
3030
lm_config:
31-
model_path: /mnt/Data/ML/NLP/vntc_asr_5gram_trie.binary
31+
model_path: null
3232
alpha: 2.0
3333
beta: 1.0
3434

@@ -53,12 +53,13 @@ learning_config:
5353

5454
dataset_config:
5555
train_paths:
56-
- /mnt/Data/ML/ASR/Preprocessed/Vivos/train/train_transcripts.tsv
56+
- /mnt/d/SpeechProcessing/Datasets/LibriSpeech/train-clean-100/transcripts.tsv
5757
eval_paths:
58-
- /mnt/Data/ML/ASR/Preprocessed/Vivos/train/eval_transcripts.tsv
58+
- /mnt/d/SpeechProcessing/Datasets/LibriSpeech/dev-clean/transcripts.tsv
59+
- /mnt/d/SpeechProcessing/Datasets/LibriSpeech/dev-other/transcripts.tsv
5960
test_paths:
60-
- /mnt/Data/ML/ASR/Preprocessed/Vivos/test/transcripts.tsv
61-
tfrecords_dir: /mnt/Data/ML/ASR/Preprocessed/Vivos/TFRecords
61+
- /mnt/d/SpeechProcessing/Datasets/LibriSpeech/test-clean/transcripts.tsv
62+
tfrecords_dir: null
6263

6364
optimizer_config:
6465
class_name: adam
@@ -68,7 +69,7 @@ learning_config:
6869
running_config:
6970
batch_size: 8
7071
num_epochs: 20
71-
outdir: /mnt/Projects/asrk16/trained/local/vivos
72+
outdir: /mnt/d/SpeechProcessing/Trained/local/deepspeech2
7273
log_interval_steps: 400
7374
save_interval_steps: 400
7475
eval_interval_steps: 800

examples/jasper/config.yml

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,11 @@ speech_config:
2424
normalize_per_feature: False
2525

2626
decoder_config:
27-
vocabulary: ./vocabularies/vietnamese.characters
27+
vocabulary: null
2828
blank_at_zero: False
2929
beam_width: 500
3030
lm_config:
31-
model_path: /mnt/Data/ML/NLP/vntc_asr_5gram_trie.binary
31+
model_path: null
3232
alpha: 2.0
3333
beta: 1.0
3434

@@ -60,12 +60,13 @@ learning_config:
6060

6161
dataset_config:
6262
train_paths:
63-
- /mnt/Data/ML/ASR/Preprocessed/Vivos/train/train_transcripts.tsv
63+
- /mnt/d/SpeechProcessing/Datasets/LibriSpeech/train-clean-100/transcripts.tsv
6464
eval_paths:
65-
- /mnt/Data/ML/ASR/Preprocessed/Vivos/train/eval_transcripts.tsv
65+
- /mnt/d/SpeechProcessing/Datasets/LibriSpeech/dev-clean/transcripts.tsv
66+
- /mnt/d/SpeechProcessing/Datasets/LibriSpeech/dev-other/transcripts.tsv
6667
test_paths:
67-
- /mnt/Data/ML/ASR/Preprocessed/Vivos/test/transcripts.tsv
68-
tfrecords_dir: /mnt/Data/ML/ASR/Preprocessed/Vivos/TFRecords
68+
- /mnt/d/SpeechProcessing/Datasets/LibriSpeech/test-clean/transcripts.tsv
69+
tfrecords_dir: null
6970

7071
optimizer_config:
7172
class_name: adam
@@ -75,7 +76,7 @@ learning_config:
7576
running_config:
7677
batch_size: 8
7778
num_epochs: 20
78-
outdir: /mnt/Projects/asrk16/trained/local/jasper
79+
outdir: /mnt/d/SpeechProcessing/Trained/local/jasper
7980
log_interval_steps: 400
8081
save_interval_steps: 400
8182
eval_interval_steps: 800

examples/streaming_transducer/config.yml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -63,12 +63,12 @@ learning_config:
6363

6464
dataset_config:
6565
train_paths:
66-
- /mnt/Data/ML/ASR/Raw/LibriSpeech/train-clean-100/transcripts.tsv
66+
- /mnt/d/SpeechProcessing/Datasets/LibriSpeech/train-clean-100/transcripts.tsv
6767
eval_paths:
68-
- /mnt/Data/ML/ASR/Raw/LibriSpeech/dev-clean/transcripts.tsv
69-
- /mnt/Data/ML/ASR/Raw/LibriSpeech/dev-other/transcripts.tsv
68+
- /mnt/d/SpeechProcessing/Datasets/LibriSpeech/dev-clean/transcripts.tsv
69+
- /mnt/d/SpeechProcessing/Datasets/LibriSpeech/dev-other/transcripts.tsv
7070
test_paths:
71-
- /mnt/Data/ML/ASR/Raw/LibriSpeech/test-clean/transcripts.tsv
71+
- /mnt/d/SpeechProcessing/Datasets/LibriSpeech/test-clean/transcripts.tsv
7272
tfrecords_dir: null
7373

7474
optimizer_config:
@@ -80,7 +80,7 @@ learning_config:
8080
batch_size: 2
8181
accumulation_steps: 1
8282
num_epochs: 20
83-
outdir: /mnt/Projects/asrk16/trained/local/librispeech/streaming_transducer
83+
outdir: /mnt/d/SpeechProcessing/Trained/local/streaming_transducer
8484
log_interval_steps: 300
8585
eval_interval_steps: 500
8686
save_interval_steps: 1000

examples/streaming_transducer/train_ga_streaming_transducer.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,9 @@
4040
parser.add_argument("--ebs", type=int, default=None,
4141
help="Evaluation batch size per replica")
4242

43+
parser.add_argument("--acs", type=int, default=None,
44+
help="Train accumulation steps")
45+
4346
parser.add_argument("--devices", type=int, nargs="*", default=[0],
4447
help="Devices' ids to apply distributed training")
4548

@@ -116,4 +119,5 @@
116119
streaming_transducer_trainer.compile(model=streaming_transducer, optimizer=optimizer,
117120
max_to_keep=args.max_ckpts)
118121

119-
streaming_transducer_trainer.fit(train_dataset, eval_dataset, train_bs=args.tbs, eval_bs=args.ebs)
122+
streaming_transducer_trainer.fit(train_dataset, eval_dataset,
123+
train_bs=args.tbs, eval_bs=args.ebs, train_acs=args.acs)

examples/streaming_transducer/train_ga_subword_streaming_transducer.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,9 @@
4040
parser.add_argument("--ebs", type=int, default=None,
4141
help="Evaluation batch size per replica")
4242

43+
parser.add_argument("--acs", type=int, default=None,
44+
help="Train accumulation steps")
45+
4346
parser.add_argument("--devices", type=int, nargs="*", default=[0],
4447
help="Devices' ids to apply distributed training")
4548

@@ -132,4 +135,5 @@
132135
streaming_transducer_trainer.compile(model=streaming_transducer, optimizer=optimizer,
133136
max_to_keep=args.max_ckpts)
134137

135-
streaming_transducer_trainer.fit(train_dataset, eval_dataset, train_bs=args.tbs, eval_bs=args.ebs)
138+
streaming_transducer_trainer.fit(train_dataset, eval_dataset,
139+
train_bs=args.tbs, eval_bs=args.ebs, train_acs=args.acs)

0 commit comments

Comments
 (0)