diff --git a/bayesflow/datasets/disk_dataset.py b/bayesflow/datasets/disk_dataset.py index f94200dc8..d776bfe82 100644 --- a/bayesflow/datasets/disk_dataset.py +++ b/bayesflow/datasets/disk_dataset.py @@ -37,6 +37,7 @@ def __init__( adapter: Adapter | None, stage: str = "training", augmentations: Mapping[str, Callable] | Callable = None, + shuffle: bool = True, **kwargs, ): """ @@ -67,6 +68,8 @@ def __init__( Note - augmentations are applied before the adapter is called and are generally transforms that you only want to apply during training. + shuffle : bool, optional + Whether to shuffle the dataset at initialization and at the end of each epoch. Default is True. **kwargs Additional keyword arguments passed to the base `PyDataset`. """ @@ -79,8 +82,9 @@ def __init__( self.stage = stage self.augmentations = augmentations - - self.shuffle() + self._shuffle = shuffle + if self._shuffle: + self.shuffle() def __getitem__(self, item) -> dict[str, np.ndarray]: if not 0 <= item < self.num_batches: @@ -108,7 +112,8 @@ def __getitem__(self, item) -> dict[str, np.ndarray]: return batch def on_epoch_end(self): - self.shuffle() + if self._shuffle: + self.shuffle() @property def num_batches(self): diff --git a/bayesflow/datasets/offline_dataset.py b/bayesflow/datasets/offline_dataset.py index 075e5135b..3b91c5f22 100644 --- a/bayesflow/datasets/offline_dataset.py +++ b/bayesflow/datasets/offline_dataset.py @@ -24,6 +24,7 @@ def __init__( *, stage: str = "training", augmentations: Mapping[str, Callable] | Callable = None, + shuffle: bool = True, **kwargs, ): """ @@ -51,6 +52,8 @@ def __init__( Note - augmentations are applied before the adapter is called and are generally transforms that you only want to apply during training. + shuffle : bool, optional + Whether to shuffle the dataset at initialization and at the end of each epoch. Default is True. **kwargs Additional keyword arguments passed to the base `PyDataset`. """ @@ -69,8 +72,9 @@ def __init__( self.indices = np.arange(self.num_samples, dtype="int64") self.augmentations = augmentations - - self.shuffle() + self._shuffle = shuffle + if self._shuffle: + self.shuffle() def __getitem__(self, item: int) -> dict[str, np.ndarray]: """ @@ -122,7 +126,8 @@ def num_batches(self) -> int | None: return int(np.ceil(self.num_samples / self.batch_size)) def on_epoch_end(self) -> None: - self.shuffle() + if self._shuffle: + self.shuffle() def shuffle(self) -> None: """Shuffle the dataset in-place."""