Skip to content

Commit b0cc951

Browse files
committed
✍️ update tpu keras training script
1 parent 4238f59 commit b0cc951

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

examples/conformer/train_tpu_keras_subword_conformer.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,11 +92,12 @@
9292
eval_dataset.load_max_lengths(args.max_lengths_prefix)
9393

9494
with strategy.scope():
95-
global_batch_size = config.learning_config.running_config.batch_size
95+
batch_size = args.bs if args.bs is not None else config.learning_config.running_config.batch_size
96+
global_batch_size = batch_size
9697
global_batch_size *= strategy.num_replicas_in_sync
9798
# build model
9899
conformer = Conformer(**config.model_config, vocabulary_size=text_featurizer.num_classes)
99-
conformer._build(speech_featurizer.shape)
100+
conformer._build(speech_featurizer.shape, prediction_shape=text_featurizer.prepand_shape, batch_size=batch_size)
100101
conformer.summary(line_length=120)
101102

102103
optimizer = tf.keras.optimizers.Adam(

0 commit comments

Comments
 (0)