Description
In the `BasicWorkflow` class, the `_fit` method constructs the validation set via `validation_data = OfflineDataset(data=validation_data, batch_size=dataset.batch_size, adapter=self.adapter)`. This treats the validation data as a standard `OfflineDataset`, which reshuffles at the end of every epoch. The shuffling changes the composition of the evaluated batches, introducing stochasticity that can cause substantial differences between successive validation-loss computations during training, even for frozen networks.
E.g., adapting Linear_Regression_Starter.ipynb to use offline training with a learning rate of 0 (and no standardization in the adapter),

```python
validation_data = workflow.simulate(1000)
history = workflow.fit_offline(
    data=training_data,
    epochs=5,
    batch_size=32,
    validation_data=validation_data,
)
```

leads to highly variable validation losses:
```
[19.465839385986328,
 10.662302017211914,
 16.612285614013672,
 7.247644901275635,
 17.282154083251953]
```

Setting the number of validation samples equal to the batch size, either via `validation_data = workflow.simulate(32)` or via `batch_size=1000`, enforces identical batches despite shuffling, resulting in stable validation losses:
```
[12.364879608154297,
 12.364879608154297,
 12.364879608154297,
 12.364879608154297,
 12.364879608154297]
```

A pull request with a simple fix will follow.