|
83 | 83 | ) |
84 | 84 | eval_dataset = ASRTFRecordDatasetKeras( |
85 | 85 | speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, |
86 | | - **vars(config.learning_config.eval_dataset_config) |
| 86 | + **vars(config.learning_config.eval_dataset_config), |
| 87 | + indefinite=True |
87 | 88 | ) |
88 | 89 | # Update metadata calculated from both train and eval datasets |
89 | 90 | train_dataset.load_metadata(args.metadata_prefix) |
|
99 | 100 | ) |
100 | 101 | eval_dataset = ASRSliceDatasetKeras( |
101 | 102 | speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, |
102 | | - **vars(config.learning_config.eval_dataset_config) |
| 103 | + **vars(config.learning_config.eval_dataset_config), |
| 104 | + indefinite=True |
103 | 105 | ) |
104 | 106 |
|
| 107 | +global_batch_size = config.learning_config.running_config.batch_size |
| 108 | +global_batch_size *= strategy.num_replicas_in_sync |
| 109 | + |
| 110 | +train_data_loader = train_dataset.create(global_batch_size) |
| 111 | +eval_data_loader = eval_dataset.create(global_batch_size) |
| 112 | + |
105 | 113 | with strategy.scope(): |
106 | | - global_batch_size = config.learning_config.running_config.batch_size |
107 | | - global_batch_size *= strategy.num_replicas_in_sync |
108 | 114 | # build model |
109 | 115 | contextnet = ContextNet(**config.model_config, vocabulary_size=text_featurizer.num_classes) |
110 | 116 | contextnet._build(speech_featurizer.shape) |
|
128 | 134 | blank=text_featurizer.blank |
129 | 135 | ) |
130 | 136 |
|
131 | | - train_data_loader = train_dataset.create(global_batch_size) |
132 | | - eval_data_loader = eval_dataset.create(global_batch_size) |
133 | | - |
134 | | - callbacks = [ |
135 | | - tf.keras.callbacks.ModelCheckpoint(**config.learning_config.running_config.checkpoint), |
136 | | - tf.keras.callbacks.experimental.BackupAndRestore(config.learning_config.running_config.states_dir), |
137 | | - tf.keras.callbacks.TensorBoard(**config.learning_config.running_config.tensorboard) |
138 | | - ] |
139 | | - |
140 | | - contextnet.fit( |
141 | | - train_data_loader, epochs=config.learning_config.running_config.num_epochs, |
142 | | - validation_data=eval_data_loader, callbacks=callbacks, |
143 | | - steps_per_epoch=train_dataset.total_steps, validation_steps=eval_dataset.total_steps |
144 | | - ) |
| 137 | +callbacks = [ |
| 138 | + tf.keras.callbacks.ModelCheckpoint(**config.learning_config.running_config.checkpoint), |
| 139 | + tf.keras.callbacks.experimental.BackupAndRestore(config.learning_config.running_config.states_dir), |
| 140 | + tf.keras.callbacks.TensorBoard(**config.learning_config.running_config.tensorboard) |
| 141 | +] |
| 142 | + |
| 143 | +contextnet.fit( |
| 144 | + train_data_loader, epochs=config.learning_config.running_config.num_epochs, |
| 145 | + validation_data=eval_data_loader, callbacks=callbacks, |
| 146 | + steps_per_epoch=train_dataset.total_steps, validation_steps=eval_dataset.total_steps |
| 147 | +) |
0 commit comments