Skip to content

Commit 0fa09b6

Browse files
committed
Disable shuffling for validation sets (#481)
1 parent 56ddd99 commit 0fa09b6

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

bayesflow/datasets/offline_dataset.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,8 @@ def __init__(
4040
num_samples : int, optional
4141
Number of samples in the dataset. If None, it will be inferred from the data.
4242
stage : str, default="training"
43-
Current stage (e.g., "training", "validation", etc.) used by the adapter.
43+
Current stage (e.g., "training", "validation", etc.) used by the adapter and to disable shuffling during
44+
validation.
4445
augmentations : dict of str to Callable or Callable, optional
4546
Dictionary of augmentation functions to apply to each corresponding key in the batch
4647
or a function to apply to the entire batch (possibly adding new keys).
@@ -122,7 +123,8 @@ def num_batches(self) -> int | None:
122123
return int(np.ceil(self.num_samples / self.batch_size))
123124

124125
def on_epoch_end(self) -> None:
125-
self.shuffle()
126+
if self.stage != "validation":
127+
self.shuffle()
126128

127129
def shuffle(self) -> None:
128130
"""Shuffle the dataset in-place."""

0 commit comments

Comments
 (0)