Skip to content

Commit 4be4a6e

Browse files
committed
✍️ test and update train script
1 parent d028489 commit 4be4a6e

File tree

9 files changed

+43
-71
lines changed

9 files changed

+43
-71
lines changed

examples/conformer/config.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ learning_config:
115115
checkpoint:
116116
filepath: /mnt/e/Models/local/conformer/checkpoints/{epoch:02d}.h5
117117
save_best_only: True
118-
save_weights_only: False
118+
save_weights_only: True
119119
save_freq: epoch
120120
states_dir: /mnt/e/Models/local/conformer/states
121121
tensorboard:

examples/contextnet/config.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ learning_config:
228228
test_dataset_config:
229229
use_tf: True
230230
data_paths:
231-
- /mnt/e/Datasets/Speech/LibriSpeech/test-clean/transcripts.tsv
231+
- /mnt/h/ML/Datasets/ASR/Raw/LibriSpeech/test-clean/transcripts.tsv
232232
tfrecords_dir: null
233233
shuffle: False
234234
cache: True
@@ -248,7 +248,7 @@ learning_config:
248248
checkpoint:
249249
filepath: /mnt/e/Models/local/contextnet/checkpoints/{epoch:02d}.h5
250250
save_best_only: True
251-
save_weights_only: False
251+
save_weights_only: True
252252
save_freq: epoch
253253
states_dir: /mnt/e/Models/local/contextnet/states
254254
tensorboard:

examples/deepspeech2/config.yml

Lines changed: 10 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,8 @@ learning_config:
5252
train_dataset_config:
5353
use_tf: True
5454
data_paths:
55-
- /mnt/Miscellanea/Datasets/Speech/LibriSpeech/train-clean-100/transcripts.tsv
56-
tfrecords_dir: /mnt/Miscellanea/Datasets/Speech/LibriSpeech/tfrecords
55+
- /mnt/h/ML/Datasets/ASR/Raw/LibriSpeech/train-clean-100/transcripts.tsv
56+
tfrecords_dir: null
5757
shuffle: True
5858
cache: True
5959
buffer_size: 100
@@ -62,10 +62,8 @@ learning_config:
6262

6363
eval_dataset_config:
6464
use_tf: True
65-
data_paths:
66-
- /mnt/Miscellanea/Datasets/Speech/LibriSpeech/dev-clean/transcripts.tsv
67-
- /mnt/Miscellanea/Datasets/Speech/LibriSpeech/dev-other/transcripts.tsv
68-
tfrecords_dir: /mnt/Miscellanea/Datasets/Speech/LibriSpeech/tfrecords
65+
data_paths: null
66+
tfrecords_dir: null
6967
shuffle: False
7068
cache: True
7169
buffer_size: 100
@@ -75,8 +73,8 @@ learning_config:
7573
test_dataset_config:
7674
use_tf: True
7775
data_paths:
78-
- /mnt/Miscellanea/Datasets/Speech/LibriSpeech/test-clean/transcripts.tsv
79-
tfrecords_dir: /mnt/Miscellanea/Datasets/Speech/LibriSpeech/tfrecords
76+
- /mnt/h/ML/Datasets/ASR/Raw/LibriSpeech/test-clean/transcripts.tsv
77+
tfrecords_dir: null
8078
shuffle: False
8179
cache: True
8280
buffer_size: 100
@@ -91,19 +89,14 @@ learning_config:
9189
running_config:
9290
batch_size: 4
9391
num_epochs: 20
94-
accumulation_steps: 8
95-
outdir: /mnt/Miscellanea/Models/local/deepspeech2
96-
log_interval_steps: 400
97-
save_interval_steps: 400
98-
eval_interval_steps: 800
9992
checkpoint:
100-
filepath: /mnt/Miscellanea/Models/local/deepspeech2/checkpoints/{epoch:02d}.h5
93+
filepath: /mnt/e/Models/local/deepspeech2/checkpoints/{epoch:02d}.h5
10194
save_best_only: True
102-
save_weights_only: False
95+
save_weights_only: True
10396
save_freq: epoch
104-
states_dir: /mnt/Miscellanea/Models/local/deepspeech2/states
97+
states_dir: /mnt/e/Models/local/deepspeech2/states
10598
tensorboard:
106-
log_dir: /mnt/Miscellanea/Models/local/deepspeech2/tensorboard
99+
log_dir: /mnt/e/Models/local/deepspeech2/tensorboard
107100
histogram_freq: 1
108101
write_graph: True
109102
write_images: True

examples/jasper/config.yml

Lines changed: 10 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,8 @@ learning_config:
5959
train_dataset_config:
6060
use_tf: True
6161
data_paths:
62-
- /mnt/Miscellanea/Datasets/Speech/LibriSpeech/train-clean-100/transcripts.tsv
63-
tfrecords_dir: /mnt/Miscellanea/Datasets/Speech/LibriSpeech/tfrecords
62+
- /mnt/h/ML/Datasets/ASR/Raw/LibriSpeech/train-clean-100/transcripts.tsv
63+
tfrecords_dir: null
6464
shuffle: True
6565
cache: True
6666
buffer_size: 100
@@ -69,10 +69,8 @@ learning_config:
6969

7070
eval_dataset_config:
7171
use_tf: True
72-
data_paths:
73-
- /mnt/Miscellanea/Datasets/Speech/LibriSpeech/dev-clean/transcripts.tsv
74-
- /mnt/Miscellanea/Datasets/Speech/LibriSpeech/dev-other/transcripts.tsv
75-
tfrecords_dir: /mnt/Miscellanea/Datasets/Speech/LibriSpeech/tfrecords
72+
data_paths: null
73+
tfrecords_dir: null
7674
shuffle: False
7775
cache: True
7876
buffer_size: 100
@@ -82,8 +80,8 @@ learning_config:
8280
test_dataset_config:
8381
use_tf: True
8482
data_paths:
85-
- /mnt/Miscellanea/Datasets/Speech/LibriSpeech/test-clean/transcripts.tsv
86-
tfrecords_dir: /mnt/Miscellanea/Datasets/Speech/LibriSpeech/tfrecords
83+
- /mnt/h/ML/Datasets/ASR/Raw/LibriSpeech/test-clean/transcripts.tsv
84+
tfrecords_dir: null
8785
shuffle: False
8886
cache: True
8987
buffer_size: 100
@@ -98,19 +96,14 @@ learning_config:
9896
running_config:
9997
batch_size: 4
10098
num_epochs: 20
101-
accumulation_steps: 8
102-
outdir: /mnt/Miscellanea/Models/local/jasper
103-
log_interval_steps: 400
104-
save_interval_steps: 400
105-
eval_interval_steps: 800
10699
checkpoint:
107-
filepath: /mnt/Miscellanea/Models/local/jasper/checkpoints/{epoch:02d}.h5
100+
filepath: /mnt/e/Models/local/jasper/checkpoints/{epoch:02d}.h5
108101
save_best_only: True
109-
save_weights_only: False
102+
save_weights_only: True
110103
save_freq: epoch
111-
states_dir: /mnt/Miscellanea/Models/local/jasper/states
104+
states_dir: /mnt/e/Models/local/jasper/states
112105
tensorboard:
113-
log_dir: /mnt/Miscellanea/Models/local/jasper/tensorboard
106+
log_dir: /mnt/e/Models/local/jasper/tensorboard
114107
histogram_freq: 1
115108
write_graph: True
116109
write_images: True

examples/rnn_transducer/config.yml

Lines changed: 10 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,8 @@ learning_config:
6464
num_masks: 1
6565
mask_factor: 27
6666
data_paths:
67-
- /mnt/Miscellanea/Datasets/Speech/LibriSpeech/train-clean-100/transcripts.tsv
68-
tfrecords_dir: /mnt/Miscellanea/Datasets/Speech/LibriSpeech/tfrecords
67+
- /mnt/h/ML/Datasets/ASR/Raw/LibriSpeech/train-clean-100/transcripts.tsv
68+
tfrecords_dir: null
6969
shuffle: True
7070
cache: True
7171
buffer_size: 100
@@ -74,10 +74,8 @@ learning_config:
7474

7575
eval_dataset_config:
7676
use_tf: True
77-
data_paths:
78-
- /mnt/Miscellanea/Datasets/Speech/LibriSpeech/dev-clean/transcripts.tsv
79-
- /mnt/Miscellanea/Datasets/Speech/LibriSpeech/dev-other/transcripts.tsv
80-
tfrecords_dir: /mnt/Miscellanea/Datasets/Speech/LibriSpeech/tfrecords
77+
data_paths: null
78+
tfrecords_dir: null
8179
shuffle: False
8280
cache: True
8381
buffer_size: 100
@@ -87,8 +85,8 @@ learning_config:
8785
test_dataset_config:
8886
use_tf: True
8987
data_paths:
90-
- /mnt/Miscellanea/Datasets/Speech/LibriSpeech/test-clean/transcripts.tsv
91-
tfrecords_dir: /mnt/Miscellanea/Datasets/Speech/LibriSpeech/tfrecords
88+
- /mnt/h/ML/Datasets/ASR/Raw/LibriSpeech/test-clean/transcripts.tsv
89+
tfrecords_dir: null
9290
shuffle: False
9391
cache: True
9492
buffer_size: 100
@@ -102,20 +100,15 @@ learning_config:
102100

103101
running_config:
104102
batch_size: 2
105-
accumulation_steps: 1
106103
num_epochs: 20
107-
outdir: /mnt/Miscellanea/Models/local/streaming_transducer
108-
log_interval_steps: 300
109-
eval_interval_steps: 500
110-
save_interval_steps: 1000
111104
checkpoint:
112-
filepath: /mnt/Miscellanea/Models/local/streaming_transducer/checkpoints/{epoch:02d}.h5
105+
filepath: /mnt/e/Models/local/rnn_transducer/checkpoints/{epoch:02d}.h5
113106
save_best_only: True
114-
save_weights_only: False
107+
save_weights_only: True
115108
save_freq: epoch
116-
states_dir: /mnt/Miscellanea/Models/local/streaming_transducer/states
109+
states_dir: /mnt/e/Models/local/rnn_transducer/states
117110
tensorboard:
118-
log_dir: /mnt/Miscellanea/Models/local/streaming_transducer/tensorboard
111+
log_dir: /mnt/e/Models/local/rnn_transducer/tensorboard
119112
histogram_freq: 1
120113
write_graph: True
121114
write_images: True

examples/rnn_transducer/train.py

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
# limitations under the License.
1414

1515
import os
16-
import math
1716
import argparse
1817
from tensorflow_asr.utils import env_util
1918

@@ -58,7 +57,6 @@
5857
from tensorflow_asr.datasets import asr_dataset
5958
from tensorflow_asr.featurizers import speech_featurizers, text_featurizers
6059
from tensorflow_asr.models.transducer.rnn_transducer import RnnTransducer
61-
from tensorflow_asr.optimizers.schedules import TransformerSchedule
6260

6361
config = Config(args.config)
6462
speech_featurizer = speech_featurizers.TFSpeechFeaturizer(config.speech_config)
@@ -118,18 +116,8 @@
118116
rnn_transducer = RnnTransducer(**config.model_config, vocabulary_size=text_featurizer.num_classes)
119117
rnn_transducer._build(speech_featurizer.shape)
120118
rnn_transducer.summary(line_length=100)
121-
122-
optimizer = tf.keras.optimizers.Adam(
123-
TransformerSchedule(
124-
d_model=rnn_transducer.dmodel,
125-
warmup_steps=config.learning_config.optimizer_config.pop("warmup_steps", 10000),
126-
max_lr=(0.05 / math.sqrt(rnn_transducer.dmodel))
127-
),
128-
**config.learning_config.optimizer_config
129-
)
130-
131119
rnn_transducer.compile(
132-
optimizer=optimizer,
120+
optimizer=config.learning_config.optimizer_config,
133121
experimental_steps_per_execution=args.spx,
134122
global_batch_size=global_batch_size,
135123
blank=text_featurizer.blank

tensorflow_asr/models/ctc/jasper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -357,7 +357,7 @@ def __init__(self,
357357
strides=1, padding="same",
358358
kernel_regularizer=kernel_regularizer,
359359
bias_regularizer=bias_regularizer,
360-
name=f"{self.name}_logits"
360+
name=f"{name}_logits"
361361
),
362362
vocabulary_size=vocabulary_size,
363363
name=name,

tensorflow_asr/utils/data_util.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,15 @@ def create_inputs(inputs: tf.Tensor,
2121
inputs_length: tf.Tensor,
2222
predictions: tf.Tensor = None,
2323
predictions_length: tf.Tensor = None) -> dict:
24-
return {
24+
data = {
2525
"inputs": inputs,
2626
"inputs_length": inputs_length,
27-
"predictions": predictions,
28-
"predictions_length": predictions_length
2927
}
28+
if predictions is not None:
29+
data["predictions"] = predictions
30+
if predictions_length is not None:
31+
data["predictions_length"] = predictions_length
32+
return data
3033

3134

3235
def create_logits(logits: tf.Tensor, logits_length: tf.Tensor) -> dict:

tensorflow_asr/utils/env_util.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@ def setup_strategy(devices):
4949
tf.distribute.Strategy: MirroredStrategy for training one or multiple gpus
5050
"""
5151
setup_devices(devices)
52+
if has_tpu():
53+
return setup_tpu()
5254
return tf.distribute.MirroredStrategy()
5355

5456

0 commit comments

Comments
 (0)