Skip to content

Commit 677bacb

Browse files
elseml, LarsKue, and vpratz
authored
Add shuffle parameter to datasets
Adds the option to disable data shuffling.

---------

Co-authored-by: Lars <[email protected]>
Co-authored-by: Valentin Pratz <[email protected]>
1 parent 996a700 commit 677bacb

File tree

2 files changed

+16
-6
lines changed

2 files changed

+16
-6
lines changed

bayesflow/datasets/disk_dataset.py

Lines changed: 8 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -37,6 +37,7 @@ def __init__(
3737
adapter: Adapter | None,
3838
stage: str = "training",
3939
augmentations: Mapping[str, Callable] | Callable = None,
40+
shuffle: bool = True,
4041
**kwargs,
4142
):
4243
"""
@@ -67,6 +68,8 @@ def __init__(
6768
6869
Note - augmentations are applied before the adapter is called and are generally
6970
transforms that you only want to apply during training.
71+
shuffle : bool, optional
72+
Whether to shuffle the dataset at initialization and at the end of each epoch. Default is True.
7073
**kwargs
7174
Additional keyword arguments passed to the base `PyDataset`.
7275
"""
@@ -79,8 +82,9 @@ def __init__(
7982
self.stage = stage
8083

8184
self.augmentations = augmentations
82-
83-
self.shuffle()
85+
self._shuffle = shuffle
86+
if self._shuffle:
87+
self.shuffle()
8488

8589
def __getitem__(self, item) -> dict[str, np.ndarray]:
8690
if not 0 <= item < self.num_batches:
@@ -108,7 +112,8 @@ def __getitem__(self, item) -> dict[str, np.ndarray]:
108112
return batch
109113

110114
def on_epoch_end(self):
111-
self.shuffle()
115+
if self._shuffle:
116+
self.shuffle()
112117

113118
@property
114119
def num_batches(self):

bayesflow/datasets/offline_dataset.py

Lines changed: 8 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -24,6 +24,7 @@ def __init__(
2424
*,
2525
stage: str = "training",
2626
augmentations: Mapping[str, Callable] | Callable = None,
27+
shuffle: bool = True,
2728
**kwargs,
2829
):
2930
"""
@@ -51,6 +52,8 @@ def __init__(
5152
5253
Note - augmentations are applied before the adapter is called and are generally
5354
transforms that you only want to apply during training.
55+
shuffle : bool, optional
56+
Whether to shuffle the dataset at initialization and at the end of each epoch. Default is True.
5457
**kwargs
5558
Additional keyword arguments passed to the base `PyDataset`.
5659
"""
@@ -69,8 +72,9 @@ def __init__(
6972
self.indices = np.arange(self.num_samples, dtype="int64")
7073

7174
self.augmentations = augmentations
72-
73-
self.shuffle()
75+
self._shuffle = shuffle
76+
if self._shuffle:
77+
self.shuffle()
7478

7579
def __getitem__(self, item: int) -> dict[str, np.ndarray]:
7680
"""
@@ -122,7 +126,8 @@ def num_batches(self) -> int | None:
122126
return int(np.ceil(self.num_samples / self.batch_size))
123127

124128
def on_epoch_end(self) -> None:
125-
self.shuffle()
129+
if self._shuffle:
130+
self.shuffle()
126131

127132
def shuffle(self) -> None:
128133
"""Shuffle the dataset in-place."""

0 commit comments

Comments
 (0)