|
38 | 38 |
|
39 | 39 | parser.add_argument("--ebs", type=int, default=None, help="Evaluation batch size per replica") |
40 | 40 |
|
| 41 | +parser.add_argument("--spx", type=int, default=1, help="Steps per execution for maximizing performance") |
| 42 | + |
| 43 | +parser.add_argument("--metadata_prefix", type=str, default=None, help="Path to file containing metadata") |
| 44 | + |
41 | 45 | parser.add_argument("--devices", type=int, nargs="*", default=[0], help="Devices' ids to apply distributed training") |
42 | 46 |
|
43 | 47 | parser.add_argument("--mxp", default=False, action="store_true", help="Enable mixed precision") |
|
79 | 83 | if args.tfrecords: |
80 | 84 | train_dataset = ASRTFRecordDatasetKeras( |
81 | 85 | speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, |
82 | | - **vars(config.learning_config.train_dataset_config) |
| 86 | + **vars(config.learning_config.train_dataset_config), |
| 87 | + indefinite=True |
83 | 88 | ) |
84 | 89 | eval_dataset = ASRTFRecordDatasetKeras( |
85 | 90 | speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, |
86 | 91 | **vars(config.learning_config.eval_dataset_config) |
87 | 92 | ) |
| 93 | + # Update metadata calculated from both train and eval datasets |
| 94 | + train_dataset.load_metadata(args.metadata_prefix) |
| 95 | + eval_dataset.load_metadata(args.metadata_prefix) |
| 96 | + # Use dynamic length |
| 97 | + speech_featurizer.reset_length() |
| 98 | + text_featurizer.reset_length() |
88 | 99 | else: |
89 | 100 | train_dataset = ASRSliceDatasetKeras( |
90 | 101 | speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, |
91 | | - **vars(config.learning_config.train_dataset_config) |
| 102 | + **vars(config.learning_config.train_dataset_config), |
| 103 | + indefinite=True |
92 | 104 | ) |
93 | 105 | eval_dataset = ASRSliceDatasetKeras( |
94 | 106 | speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, |
95 | | - **vars(config.learning_config.train_dataset_config) |
| 107 | +**vars(config.learning_config.eval_dataset_config),
| 108 | + indefinite=True |
96 | 109 | ) |
97 | 110 |
|
| 111 | +global_batch_size = config.learning_config.running_config.batch_size |
| 112 | +global_batch_size *= strategy.num_replicas_in_sync |
| 113 | + |
| 114 | +train_data_loader = train_dataset.create(global_batch_size) |
| 115 | +eval_data_loader = eval_dataset.create(global_batch_size) |
| 116 | + |
98 | 117 | with strategy.scope(): |
99 | | - global_batch_size = config.learning_config.running_config.batch_size |
100 | | - global_batch_size *= strategy.num_replicas_in_sync |
101 | 118 | # build model |
102 | 119 | conformer = Conformer(**config.model_config, vocabulary_size=text_featurizer.num_classes) |
103 | 120 | conformer._build(speech_featurizer.shape) |
|
114 | 131 | epsilon=config.learning_config.optimizer_config["epsilon"] |
115 | 132 | ) |
116 | 133 |
|
117 | | - conformer.compile(optimizer=optimizer, global_batch_size=global_batch_size, blank=text_featurizer.blank) |
118 | | - |
119 | | - train_data_loader = train_dataset.create(global_batch_size) |
120 | | - eval_data_loader = eval_dataset.create(global_batch_size) |
121 | | - |
122 | | - callbacks = [ |
123 | | - tf.keras.callbacks.ModelCheckpoint(**config.learning_config.running_config.checkpoint), |
124 | | - tf.keras.callbacks.experimental.BackupAndRestore(config.learning_config.running_config.states_dir), |
125 | | - tf.keras.callbacks.TensorBoard(**config.learning_config.running_config.tensorboard) |
126 | | - ] |
127 | | - |
128 | | - conformer.fit( |
129 | | - train_data_loader, epochs=config.learning_config.running_config.num_epochs, |
130 | | - validation_data=eval_data_loader, callbacks=callbacks, |
131 | | - steps_per_epoch=train_dataset.total_steps |
| 134 | + conformer.compile( |
| 135 | + optimizer=optimizer, |
| 136 | + experimental_steps_per_execution=args.spx, |
| 137 | + global_batch_size=global_batch_size, |
| 138 | + blank=text_featurizer.blank |
132 | 139 | ) |
| 140 | + |
| 141 | +callbacks = [ |
| 142 | + tf.keras.callbacks.ModelCheckpoint(**config.learning_config.running_config.checkpoint), |
| 143 | + tf.keras.callbacks.experimental.BackupAndRestore(config.learning_config.running_config.states_dir), |
| 144 | + tf.keras.callbacks.TensorBoard(**config.learning_config.running_config.tensorboard) |
| 145 | +] |
| 146 | + |
| 147 | +conformer.fit( |
| 148 | + train_data_loader, epochs=config.learning_config.running_config.num_epochs, |
| 149 | + validation_data=eval_data_loader, callbacks=callbacks, |
| 150 | + steps_per_epoch=train_dataset.total_steps, validation_steps=eval_dataset.total_steps |
| 151 | +) |
0 commit comments