Skip to content

Commit cfbd6c9

Browse files
committed
Add shuffle_dataset argument
1 parent cccf75a commit cfbd6c9

File tree

1 file changed

+9
-3
lines changed

1 file changed

+9
-3
lines changed

bayesflow/datasets/disk_dataset.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ def __init__(
3737
adapter: Adapter | None,
3838
stage: str = "training",
3939
augmentations: Mapping[str, Callable] | Callable = None,
40+
shuffle_dataset: bool = True,
4041
**kwargs,
4142
):
4243
"""
@@ -67,6 +68,9 @@ def __init__(
6768
6869
Note - augmentations are applied before the adapter is called and are generally
6970
transforms that you only want to apply during training.
71+
shuffle_dataset : bool, default=True
72+
Whether to shuffle the dataset at initialization and at the end of each epoch. Should be set to `False`
73+
for validation and test datasets to ensure consistent ordering of data.
7074
**kwargs
7175
Additional keyword arguments passed to the base `PyDataset`.
7276
"""
@@ -79,8 +83,9 @@ def __init__(
7983
self.stage = stage
8084

8185
self.augmentations = augmentations
82-
83-
self.shuffle()
86+
self.shuffle_dataset = shuffle_dataset
87+
if self.shuffle_dataset:
88+
self.shuffle()
8489

8590
def __getitem__(self, item) -> dict[str, np.ndarray]:
8691
if not 0 <= item < self.num_batches:
@@ -108,7 +113,8 @@ def __getitem__(self, item) -> dict[str, np.ndarray]:
108113
return batch
109114

110115
def on_epoch_end(self):
111-
self.shuffle()
116+
if self.shuffle_dataset:
117+
self.shuffle()
112118

113119
@property
114120
def num_batches(self):

0 commit comments

Comments
 (0)