Skip to content

Commit bddcf81

Browse files
committed
✍️ update dataset and scripts
1 parent bedbd4c commit bddcf81

File tree

6 files changed

+116
-105
lines changed

6 files changed

+116
-105
lines changed

examples/conformer/train_keras_subword_conformer.py

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -104,12 +104,17 @@
104104
)
105105
eval_dataset = ASRSliceDatasetKeras(
106106
speech_featurizer=speech_featurizer, text_featurizer=text_featurizer,
107-
**vars(config.learning_config.train_dataset_config)
107+
**vars(config.learning_config.eval_dataset_config),
108+
indefinite=True
108109
)
109110

111+
global_batch_size = config.learning_config.running_config.batch_size
112+
global_batch_size *= strategy.num_replicas_in_sync
113+
114+
train_data_loader = train_dataset.create(global_batch_size)
115+
eval_data_loader = eval_dataset.create(global_batch_size)
116+
110117
with strategy.scope():
111-
global_batch_size = config.learning_config.running_config.batch_size
112-
global_batch_size *= strategy.num_replicas_in_sync
113118
# build model
114119
conformer = Conformer(**config.model_config, vocabulary_size=text_featurizer.num_classes)
115120
conformer._build(speech_featurizer.shape)
@@ -133,17 +138,14 @@
133138
blank=text_featurizer.blank
134139
)
135140

136-
train_data_loader = train_dataset.create(global_batch_size)
137-
eval_data_loader = eval_dataset.create(global_batch_size)
138-
139-
callbacks = [
140-
tf.keras.callbacks.ModelCheckpoint(**config.learning_config.running_config.checkpoint),
141-
tf.keras.callbacks.experimental.BackupAndRestore(config.learning_config.running_config.states_dir),
142-
tf.keras.callbacks.TensorBoard(**config.learning_config.running_config.tensorboard)
143-
]
144-
145-
conformer.fit(
146-
train_data_loader, epochs=config.learning_config.running_config.num_epochs,
147-
validation_data=eval_data_loader, callbacks=callbacks,
148-
steps_per_epoch=train_dataset.total_steps, validation_steps=eval_dataset.total_steps
149-
)
141+
callbacks = [
142+
tf.keras.callbacks.ModelCheckpoint(**config.learning_config.running_config.checkpoint),
143+
tf.keras.callbacks.experimental.BackupAndRestore(config.learning_config.running_config.states_dir),
144+
tf.keras.callbacks.TensorBoard(**config.learning_config.running_config.tensorboard)
145+
]
146+
147+
conformer.fit(
148+
train_data_loader, epochs=config.learning_config.running_config.num_epochs,
149+
validation_data=eval_data_loader, callbacks=callbacks,
150+
steps_per_epoch=train_dataset.total_steps, validation_steps=eval_dataset.total_steps
151+
)

examples/conformer/train_tpu_keras_subword_conformer.py

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,8 @@
8383
)
8484
eval_dataset = ASRTFRecordDatasetKeras(
8585
speech_featurizer=speech_featurizer, text_featurizer=text_featurizer,
86-
**vars(config.learning_config.eval_dataset_config)
86+
**vars(config.learning_config.eval_dataset_config),
87+
indefinite=True
8788
)
8889

8990
if args.compute_lengths:
@@ -94,10 +95,14 @@
9495
train_dataset.load_metadata(args.metadata_prefix)
9596
eval_dataset.load_metadata(args.metadata_prefix)
9697

98+
batch_size = args.bs if args.bs is not None else config.learning_config.running_config.batch_size
99+
global_batch_size = batch_size
100+
global_batch_size *= strategy.num_replicas_in_sync
101+
102+
train_data_loader = train_dataset.create(global_batch_size)
103+
eval_data_loader = eval_dataset.create(global_batch_size)
104+
97105
with strategy.scope():
98-
batch_size = args.bs if args.bs is not None else config.learning_config.running_config.batch_size
99-
global_batch_size = batch_size
100-
global_batch_size *= strategy.num_replicas_in_sync
101106
# build model
102107
conformer = Conformer(**config.model_config, vocabulary_size=text_featurizer.num_classes)
103108
conformer._build(speech_featurizer.shape, prediction_shape=text_featurizer.prepand_shape, batch_size=global_batch_size)
@@ -121,17 +126,14 @@
121126
blank=text_featurizer.blank
122127
)
123128

124-
train_data_loader = train_dataset.create(global_batch_size)
125-
eval_data_loader = eval_dataset.create(global_batch_size)
129+
callbacks = [
130+
tf.keras.callbacks.ModelCheckpoint(**config.learning_config.running_config.checkpoint),
131+
tf.keras.callbacks.experimental.BackupAndRestore(config.learning_config.running_config.states_dir),
132+
tf.keras.callbacks.TensorBoard(**config.learning_config.running_config.tensorboard)
133+
]
126134

127-
callbacks = [
128-
tf.keras.callbacks.ModelCheckpoint(**config.learning_config.running_config.checkpoint),
129-
tf.keras.callbacks.experimental.BackupAndRestore(config.learning_config.running_config.states_dir),
130-
tf.keras.callbacks.TensorBoard(**config.learning_config.running_config.tensorboard)
131-
]
132-
133-
conformer.fit(
134-
train_data_loader, epochs=config.learning_config.running_config.num_epochs,
135-
validation_data=eval_data_loader, callbacks=callbacks,
136-
steps_per_epoch=train_dataset.total_steps, validation_steps=eval_dataset.total_steps
137-
)
135+
conformer.fit(
136+
train_data_loader, epochs=config.learning_config.running_config.num_epochs,
137+
validation_data=eval_data_loader, callbacks=callbacks,
138+
steps_per_epoch=train_dataset.total_steps, validation_steps=eval_dataset.total_steps
139+
)

examples/contextnet/train_keras_subword_contextnet.py

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,8 @@
8383
)
8484
eval_dataset = ASRTFRecordDatasetKeras(
8585
speech_featurizer=speech_featurizer, text_featurizer=text_featurizer,
86-
**vars(config.learning_config.eval_dataset_config)
86+
**vars(config.learning_config.eval_dataset_config),
87+
indefinite=True
8788
)
8889
# Update metadata calculated from both train and eval datasets
8990
train_dataset.load_metadata(args.metadata_prefix)
@@ -99,12 +100,17 @@
99100
)
100101
eval_dataset = ASRSliceDatasetKeras(
101102
speech_featurizer=speech_featurizer, text_featurizer=text_featurizer,
102-
**vars(config.learning_config.eval_dataset_config)
103+
**vars(config.learning_config.eval_dataset_config),
104+
indefinite=True
103105
)
104106

107+
global_batch_size = config.learning_config.running_config.batch_size
108+
global_batch_size *= strategy.num_replicas_in_sync
109+
110+
train_data_loader = train_dataset.create(global_batch_size)
111+
eval_data_loader = eval_dataset.create(global_batch_size)
112+
105113
with strategy.scope():
106-
global_batch_size = config.learning_config.running_config.batch_size
107-
global_batch_size *= strategy.num_replicas_in_sync
108114
# build model
109115
contextnet = ContextNet(**config.model_config, vocabulary_size=text_featurizer.num_classes)
110116
contextnet._build(speech_featurizer.shape)
@@ -128,17 +134,14 @@
128134
blank=text_featurizer.blank
129135
)
130136

131-
train_data_loader = train_dataset.create(global_batch_size)
132-
eval_data_loader = eval_dataset.create(global_batch_size)
133-
134-
callbacks = [
135-
tf.keras.callbacks.ModelCheckpoint(**config.learning_config.running_config.checkpoint),
136-
tf.keras.callbacks.experimental.BackupAndRestore(config.learning_config.running_config.states_dir),
137-
tf.keras.callbacks.TensorBoard(**config.learning_config.running_config.tensorboard)
138-
]
139-
140-
contextnet.fit(
141-
train_data_loader, epochs=config.learning_config.running_config.num_epochs,
142-
validation_data=eval_data_loader, callbacks=callbacks,
143-
steps_per_epoch=train_dataset.total_steps, validation_steps=eval_dataset.total_steps
144-
)
137+
callbacks = [
138+
tf.keras.callbacks.ModelCheckpoint(**config.learning_config.running_config.checkpoint),
139+
tf.keras.callbacks.experimental.BackupAndRestore(config.learning_config.running_config.states_dir),
140+
tf.keras.callbacks.TensorBoard(**config.learning_config.running_config.tensorboard)
141+
]
142+
143+
contextnet.fit(
144+
train_data_loader, epochs=config.learning_config.running_config.num_epochs,
145+
validation_data=eval_data_loader, callbacks=callbacks,
146+
steps_per_epoch=train_dataset.total_steps, validation_steps=eval_dataset.total_steps
147+
)

examples/deepspeech2/train_keras_ds2.py

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -81,14 +81,18 @@
8181
)
8282
eval_dataset = ASRSliceDatasetKeras(
8383
speech_featurizer=speech_featurizer, text_featurizer=text_featurizer,
84-
**vars(config.learning_config.eval_dataset_config)
84+
**vars(config.learning_config.eval_dataset_config),
85+
indefinite=True
8586
)
8687

88+
global_batch_size = config.learning_config.running_config.batch_size
89+
global_batch_size *= strategy.num_replicas_in_sync
90+
91+
train_data_loader = train_dataset.create(global_batch_size)
92+
eval_data_loader = eval_dataset.create(global_batch_size)
93+
8794
# Build DS2 model
8895
with strategy.scope():
89-
global_batch_size = config.learning_config.running_config.batch_size
90-
global_batch_size *= strategy.num_replicas_in_sync
91-
9296
ds2_model = DeepSpeech2(**config.model_config, vocabulary_size=text_featurizer.num_classes)
9397
ds2_model._build(speech_featurizer.shape)
9498
ds2_model.summary(line_length=120)
@@ -100,17 +104,14 @@
100104
blank=text_featurizer.blank
101105
)
102106

103-
train_data_loader = train_dataset.create(global_batch_size)
104-
eval_data_loader = eval_dataset.create(global_batch_size)
105-
106-
callbacks = [
107-
tf.keras.callbacks.ModelCheckpoint(**config.learning_config.running_config.checkpoint),
108-
tf.keras.callbacks.experimental.BackupAndRestore(config.learning_config.running_config.states_dir),
109-
tf.keras.callbacks.TensorBoard(**config.learning_config.running_config.tensorboard)
110-
]
111-
112-
ds2_model.fit(
113-
train_data_loader, epochs=config.learning_config.running_config.num_epochs,
114-
validation_data=eval_data_loader, callbacks=callbacks,
115-
steps_per_epoch=train_dataset.total_steps, validation_steps=eval_dataset.total_steps
116-
)
107+
callbacks = [
108+
tf.keras.callbacks.ModelCheckpoint(**config.learning_config.running_config.checkpoint),
109+
tf.keras.callbacks.experimental.BackupAndRestore(config.learning_config.running_config.states_dir),
110+
tf.keras.callbacks.TensorBoard(**config.learning_config.running_config.tensorboard)
111+
]
112+
113+
ds2_model.fit(
114+
train_data_loader, epochs=config.learning_config.running_config.num_epochs,
115+
validation_data=eval_data_loader, callbacks=callbacks,
116+
steps_per_epoch=train_dataset.total_steps, validation_steps=eval_dataset.total_steps
117+
)

examples/jasper/train_keras_jasper.py

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -83,13 +83,17 @@
8383
)
8484
eval_dataset = ASRSliceDatasetKeras(
8585
speech_featurizer=speech_featurizer, text_featurizer=text_featurizer,
86-
**vars(config.learning_config.eval_dataset_config)
86+
**vars(config.learning_config.eval_dataset_config),
87+
indefinite=True
8788
)
8889

89-
with strategy.scope():
90-
global_batch_size = config.learning_config.running_config.batch_size
91-
global_batch_size *= strategy.num_replicas_in_sync
90+
global_batch_size = config.learning_config.running_config.batch_size
91+
global_batch_size *= strategy.num_replicas_in_sync
92+
93+
train_data_loader = train_dataset.create(global_batch_size)
94+
eval_data_loader = eval_dataset.create(global_batch_size)
9295

96+
with strategy.scope():
9397
jasper = Jasper(**config.model_config, vocabulary_size=text_featurizer.num_classes)
9498
jasper._build(speech_featurizer.shape)
9599
jasper.summary(line_length=120)
@@ -101,17 +105,14 @@
101105
blank=text_featurizer.blank
102106
)
103107

104-
train_data_loader = train_dataset.create(global_batch_size)
105-
eval_data_loader = eval_dataset.create(global_batch_size)
106-
107-
callbacks = [
108-
tf.keras.callbacks.ModelCheckpoint(**config.learning_config.running_config.checkpoint),
109-
tf.keras.callbacks.experimental.BackupAndRestore(config.learning_config.running_config.states_dir),
110-
tf.keras.callbacks.TensorBoard(**config.learning_config.running_config.tensorboard)
111-
]
112-
113-
jasper.fit(
114-
train_data_loader, epochs=config.learning_config.running_config.num_epochs,
115-
validation_data=eval_data_loader, callbacks=callbacks,
116-
steps_per_epoch=train_dataset.total_steps, validation_steps=eval_dataset.total_steps
117-
)
108+
callbacks = [
109+
tf.keras.callbacks.ModelCheckpoint(**config.learning_config.running_config.checkpoint),
110+
tf.keras.callbacks.experimental.BackupAndRestore(config.learning_config.running_config.states_dir),
111+
tf.keras.callbacks.TensorBoard(**config.learning_config.running_config.tensorboard)
112+
]
113+
114+
jasper.fit(
115+
train_data_loader, epochs=config.learning_config.running_config.num_epochs,
116+
validation_data=eval_data_loader, callbacks=callbacks,
117+
steps_per_epoch=train_dataset.total_steps, validation_steps=eval_dataset.total_steps
118+
)

examples/streaming_transducer/train_keras_subword_streaming_transducer.py

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -97,12 +97,17 @@
9797
)
9898
eval_dataset = ASRSliceDatasetKeras(
9999
speech_featurizer=speech_featurizer, text_featurizer=text_featurizer,
100-
**vars(config.learning_config.eval_dataset_config)
100+
**vars(config.learning_config.eval_dataset_config),
101+
indefinite=True
101102
)
102103

104+
global_batch_size = config.learning_config.running_config.batch_size
105+
global_batch_size *= strategy.num_replicas_in_sync
106+
107+
train_data_loader = train_dataset.create(global_batch_size)
108+
eval_data_loader = eval_dataset.create(global_batch_size)
109+
103110
with strategy.scope():
104-
global_batch_size = config.learning_config.running_config.batch_size
105-
global_batch_size *= strategy.num_replicas_in_sync
106111
# build model
107112
streaming_transducer = StreamingTransducer(
108113
**config.model_config,
@@ -120,17 +125,14 @@
120125
blank=text_featurizer.blank
121126
)
122127

123-
train_data_loader = train_dataset.create(global_batch_size)
124-
eval_data_loader = eval_dataset.create(global_batch_size)
125-
126-
callbacks = [
127-
tf.keras.callbacks.ModelCheckpoint(**config.learning_config.running_config.checkpoint),
128-
tf.keras.callbacks.experimental.BackupAndRestore(config.learning_config.running_config.states_dir),
129-
tf.keras.callbacks.TensorBoard(**config.learning_config.running_config.tensorboard)
130-
]
131-
132-
streaming_transducer.fit(
133-
train_data_loader, epochs=config.learning_config.running_config.num_epochs,
134-
validation_data=eval_data_loader, callbacks=callbacks,
135-
steps_per_epoch=train_dataset.total_steps, validation_steps=eval_dataset.total_steps
136-
)
128+
callbacks = [
129+
tf.keras.callbacks.ModelCheckpoint(**config.learning_config.running_config.checkpoint),
130+
tf.keras.callbacks.experimental.BackupAndRestore(config.learning_config.running_config.states_dir),
131+
tf.keras.callbacks.TensorBoard(**config.learning_config.running_config.tensorboard)
132+
]
133+
134+
streaming_transducer.fit(
135+
train_data_loader, epochs=config.learning_config.running_config.num_epochs,
136+
validation_data=eval_data_loader, callbacks=callbacks,
137+
steps_per_epoch=train_dataset.total_steps, validation_steps=eval_dataset.total_steps
138+
)

0 commit comments

Comments
 (0)