diff --git a/bayesflow/datasets/__init__.py b/bayesflow/datasets/__init__.py
index 2f4d21a69..71e8f44a1 100644
--- a/bayesflow/datasets/__init__.py
+++ b/bayesflow/datasets/__init__.py
@@ -7,7 +7,6 @@
 from .offline_dataset import OfflineDataset
 from .online_dataset import OnlineDataset
 from .disk_dataset import DiskDataset
-from .rounds_dataset import RoundsDataset
 
 from ..utils._docs import _add_imports_to_all
diff --git a/bayesflow/datasets/disk_dataset.py b/bayesflow/datasets/disk_dataset.py
index 8753e3480..f94200dc8 100644
--- a/bayesflow/datasets/disk_dataset.py
+++ b/bayesflow/datasets/disk_dataset.py
@@ -1,8 +1,12 @@
-import keras
-import numpy as np
+from collections.abc import Mapping, Callable
+
 import os
 import pathlib as pl
 
+import numpy as np
+
+import keras
+
 from bayesflow.adapters import Adapter
 from bayesflow.utils import tree_stack, pickle_load
@@ -29,11 +33,43 @@ def __init__(
         *,
         pattern: str = "*.pkl",
         batch_size: int,
-        load_fn: callable = None,
+        load_fn: Callable = None,
         adapter: Adapter | None,
         stage: str = "training",
+        augmentations: Mapping[str, Callable] | Callable = None,
         **kwargs,
     ):
+        """
+        Initialize a DiskDataset instance for offline training using a set of simulations that
+        do not fit in memory.
+
+        Parameters
+        ----------
+        root : os.PathLike
+            Root directory containing the sample files.
+        pattern : str, default="*.pkl"
+            Glob pattern to match sample files.
+        batch_size : int
+            Number of samples per batch.
+        load_fn : Callable, optional
+            Function to load a single file into a sample. Defaults to `pickle_load`.
+        adapter : Adapter or None
+            Optional adapter to transform the loaded batch.
+        stage : str, default="training"
+            Current stage (e.g., "training", "validation", etc.) used by the adapter.
+        augmentations : dict of str to Callable or Callable, optional
+            Dictionary of augmentation functions to apply to each corresponding key in the batch
+            or a function to apply to the entire batch (possibly adding new keys).
+
+            If you provide a dictionary of functions, each function should accept one element
+            of your output batch and return the corresponding transformed element. Otherwise,
+            your function should accept the entire dictionary output and return a dictionary.
+
+            Note - augmentations are applied before the adapter is called and are generally
+            transforms that you only want to apply during training.
+        **kwargs
+            Additional keyword arguments passed to the base `PyDataset`.
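+
+        Examples
+        --------
+        A minimal sketch; the directory, key name, and noise scale are illustrative::
+
+            import numpy as np
+
+            dataset = DiskDataset(
+                root="simulations/",
+                batch_size=32,
+                adapter=None,
+                augmentations={"x": lambda x: x + np.random.normal(0.0, 0.1, size=x.shape)},
+            )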
+ """ super().__init__(**kwargs) self.batch_size = batch_size self.root = pl.Path(root) @@ -42,6 +78,8 @@ def __init__( self.files = list(map(str, self.root.glob(pattern))) self.stage = stage + self.augmentations = augmentations + self.shuffle() def __getitem__(self, item) -> dict[str, np.ndarray]: @@ -50,12 +88,20 @@ def __getitem__(self, item) -> dict[str, np.ndarray]: files = self.files[item * self.batch_size : (item + 1) * self.batch_size] - batch = [] - for file in files: - batch.append(self.load_fn(file)) + batch = [self.load_fn(file) for file in files] batch = tree_stack(batch) + if self.augmentations is None: + pass + elif isinstance(self.augmentations, Mapping): + for key, fn in self.augmentations.items(): + batch[key] = fn(batch[key]) + elif isinstance(self.augmentations, Callable): + batch = self.augmentations(batch) + else: + raise RuntimeError(f"Could not apply augmentations of type {type(self.augmentations)}.") + if self.adapter is not None: batch = self.adapter(batch, stage=self.stage) diff --git a/bayesflow/datasets/offline_dataset.py b/bayesflow/datasets/offline_dataset.py index 51f2b51f7..075e5135b 100644 --- a/bayesflow/datasets/offline_dataset.py +++ b/bayesflow/datasets/offline_dataset.py @@ -1,4 +1,4 @@ -from collections.abc import Mapping +from collections.abc import Mapping, Callable import numpy as np @@ -23,8 +23,37 @@ def __init__( num_samples: int = None, *, stage: str = "training", + augmentations: Mapping[str, Callable] | Callable = None, **kwargs, ): + """ + Initialize an OfflineDataset instance for offline training with optional data augmentations. + + Parameters + ---------- + data : Mapping[str, np.ndarray] + Pre-simulated data stored in a dictionary, where each key maps to a NumPy array. + batch_size : int + Number of samples per batch. + adapter : Adapter or None + Optional adapter to transform the batch. + num_samples : int, optional + Number of samples in the dataset. If None, it will be inferred from the data. + stage : str, default="training" + Current stage (e.g., "training", "validation", etc.) used by the adapter. + augmentations : dict of str to Callable or Callable, optional + Dictionary of augmentation functions to apply to each corresponding key in the batch + or a function to apply to the entire batch (possibly adding new keys). + + If you provide a dictionary of functions, each function should accept one element + of your output batch and return the corresponding transformed element. Otherwise, + your function should accept the entire dictionary output and return a dictionary. + + Note - augmentations are applied before the adapter is called and are generally + transforms that you only want to apply during training. + **kwargs + Additional keyword arguments passed to the base `PyDataset`. + """ super().__init__(**kwargs) self.batch_size = batch_size self.data = data @@ -39,10 +68,29 @@ def __init__( self.indices = np.arange(self.num_samples, dtype="int64") + self.augmentations = augmentations + self.shuffle() def __getitem__(self, item: int) -> dict[str, np.ndarray]: - """Get a batch of pre-simulated data""" + """ + Load a batch of data from disk. + + Parameters + ---------- + item : int + Index of the batch to retrieve. + + Returns + ------- + dict of str to np.ndarray + A batch of loaded (and optionally augmented/adapted) data. + + Raises + ------ + IndexError + If the requested batch index is out of range. 
+ """ if not 0 <= item < self.num_batches: raise IndexError(f"Index {item} is out of bounds for dataset with {self.num_batches} batches.") @@ -54,6 +102,16 @@ def __getitem__(self, item: int) -> dict[str, np.ndarray]: for key, value in self.data.items() } + if self.augmentations is None: + pass + elif isinstance(self.augmentations, Mapping): + for key, fn in self.augmentations.items(): + batch[key] = fn(batch[key]) + elif isinstance(self.augmentations, Callable): + batch = self.augmentations(batch) + else: + raise RuntimeError(f"Could not apply augmentations of type {type(self.augmentations)}.") + if self.adapter is not None: batch = self.adapter(batch, stage=self.stage) diff --git a/bayesflow/datasets/online_dataset.py b/bayesflow/datasets/online_dataset.py index 18701f70e..8cb0777a0 100644 --- a/bayesflow/datasets/online_dataset.py +++ b/bayesflow/datasets/online_dataset.py @@ -1,3 +1,5 @@ +from collections.abc import Mapping, Callable + import keras import numpy as np @@ -7,7 +9,7 @@ class OnlineDataset(keras.utils.PyDataset): """ - A dataset that is generated on-the-fly. + A dataset that generates simulations on-the-fly. """ def __init__( @@ -18,8 +20,37 @@ def __init__( adapter: Adapter | None, *, stage: str = "training", + augmentations: Mapping[str, Callable] | Callable = None, **kwargs, ): + """ + Initialize an OnlineDataset instance for infinite stream training. + + Parameters + ---------- + simulator : Simulator + A simulator object with a `.sample(batch_shape)` method to generate data. + batch_size : int + Number of samples per batch. + num_batches : int + Total number of batches in the dataset. + adapter : Adapter or None + Optional adapter to transform the simulated batch. + stage : str, default="training" + Current stage (e.g., "training", "validation", etc.) used by the adapter. + augmentations : dict of str to Callable or Callable, optional + Dictionary of augmentation functions to apply to each corresponding key in the batch + or a function to apply to the entire batch (possibly adding new keys). + + If you provide a dictionary of functions, each function should accept one element + of your output batch and return the corresponding transformed element. Otherwise, + your function should accept the entire dictionary output and return a dictionary. + + Note - augmentations are applied before the adapter is called and are generally + transforms that you only want to apply during training. + **kwargs + Additional keyword arguments passed to the base `PyDataset`. + """ super().__init__(**kwargs) self.batch_size = batch_size @@ -27,10 +58,34 @@ def __init__( self.adapter = adapter self.simulator = simulator self.stage = stage + self.augmentations = augmentations def __getitem__(self, item: int) -> dict[str, np.ndarray]: + """ + Generate one batch of data. + + Parameters + ---------- + item : int + Index of the batch. Required by signature, but not used. + + Returns + ------- + dict of str to np.ndarray + A batch of simulated (and optionally augmented/adapted) data. 
+ """ batch = self.simulator.sample((self.batch_size,)) + if self.augmentations is None: + pass + elif isinstance(self.augmentations, Mapping): + for key, fn in self.augmentations.items(): + batch[key] = fn(batch[key]) + elif isinstance(self.augmentations, Callable): + batch = self.augmentations(batch) + else: + raise RuntimeError(f"Could not apply augmentations of type {type(self.augmentations)}.") + if self.adapter is not None: batch = self.adapter(batch, stage=self.stage) diff --git a/bayesflow/datasets/rounds_dataset.py b/bayesflow/datasets/rounds_dataset.py deleted file mode 100644 index b6c59336c..000000000 --- a/bayesflow/datasets/rounds_dataset.py +++ /dev/null @@ -1,66 +0,0 @@ -import keras -import numpy as np - -from bayesflow.adapters import Adapter -from bayesflow.simulators.simulator import Simulator -from bayesflow.utils import logging - - -class RoundsDataset(keras.utils.PyDataset): - """ - A dataset that is generated on-the-fly at the beginning of every n-th epoch. - """ - - def __init__( - self, - simulator: Simulator, - batch_size: int, - num_batches: int, - epochs_per_round: int, - adapter: Adapter | None, - *, - stage: str = "training", - **kwargs, - ): - super().__init__(**kwargs) - - self.batches = None - self._num_batches = num_batches - self.batch_size = batch_size - self.adapter = adapter - self.epoch = 0 - self.stage = stage - - if epochs_per_round == 1: - logging.warning( - "Using `RoundsDataset` with `epochs_per_round=1` is equivalent to fully online training. " - "Use an `OnlineDataset` instead for best performance." - ) - - self.epochs_per_round = epochs_per_round - - self.simulator = simulator - - self.regenerate() - - def __getitem__(self, item: int) -> dict[str, np.ndarray]: - """Get a batch of pre-simulated data""" - batch = self.batches[item] - - if self.adapter is not None: - batch = self.adapter(batch, stage=self.stage) - - return batch - - @property - def num_batches(self) -> int: - return self._num_batches - - def on_epoch_end(self) -> None: - self.epoch += 1 - if self.epoch % self.epochs_per_round == 0: - self.regenerate() - - def regenerate(self) -> None: - """Sample new batches of data from the joint distribution unconditionally""" - self.batches = [self.simulator.sample((self.batch_size,)) for _ in range(self.batches_per_epoch)] diff --git a/bayesflow/workflows/basic_workflow.py b/bayesflow/workflows/basic_workflow.py index 4e21c9e1e..c6d0fce52 100644 --- a/bayesflow/workflows/basic_workflow.py +++ b/bayesflow/workflows/basic_workflow.py @@ -674,6 +674,7 @@ def fit_offline( batch_size: int = 32, keep_optimizer: bool = False, validation_data: Mapping[str, np.ndarray] | int = None, + augmentations: Mapping[str, Callable] | Callable = None, **kwargs, ) -> keras.callbacks.History: """ @@ -698,6 +699,16 @@ def fit_offline( A dictionary containing validation data. If an integer is provided, that number of validation samples will be generated (if supported). By default, no validation data is used. + augmentations : dict of str to Callable or Callable, optional + Dictionary of augmentation functions to apply to each corresponding key in the batch + or a function to apply to the entire batch (possibly adding new keys). + + If you provide a dictionary of functions, each function should accept one element + of your output batch and return the corresponding transformed element. Otherwise, + your function should accept the entire dictionary output and return a dictionary. 
+
+            Note - augmentations are applied before the adapter is called and are generally
+            transforms that you only want to apply during training.
         **kwargs : dict, optional
             Additional keyword arguments passed to the underlying `_fit` method.
 
         Returns
         -------
         history : keras.callbacks.History
             A history object containing the training history, with keys such as training losses and
             metric evolution over epochs.
         """
-        dataset = OfflineDataset(data=data, batch_size=batch_size, adapter=self.adapter)
+        dataset = OfflineDataset(data=data, batch_size=batch_size, adapter=self.adapter, augmentations=augmentations)
 
         return self._fit(
             dataset, epochs, strategy="online", keep_optimizer=keep_optimizer, validation_data=validation_data, **kwargs
@@ -722,6 +733,7 @@ def fit_online(
         batch_size: int = 32,
         keep_optimizer: bool = False,
         validation_data: Mapping[str, np.ndarray] | int = None,
+        augmentations: Mapping[str, Callable] | Callable = None,
         **kwargs,
     ) -> keras.callbacks.History:
         """
@@ -743,6 +755,16 @@
             A dictionary containing validation data. If an integer is provided, that number of validation samples
             will be generated (if supported).
             By default, no validation data is used.
+        augmentations : dict of str to Callable or Callable, optional
+            Dictionary of augmentation functions to apply to each corresponding key in the batch
+            or a function to apply to the entire batch (possibly adding new keys).
+
+            If you provide a dictionary of functions, each function should accept one element
+            of your output batch and return the corresponding transformed element. Otherwise,
+            your function should accept the entire dictionary output and return a dictionary.
+
+            Note - augmentations are applied before the adapter is called and are generally
+            transforms that you only want to apply during training.
         **kwargs : dict, optional
             Additional keyword arguments passed to the underlying `_fit` method.
 
         Returns
         -------
         history : keras.callbacks.History
             A history object containing the training history, with keys such as training losses and
             metric evolution over epochs.
         """
 
         dataset = OnlineDataset(
-            simulator=self.simulator, batch_size=batch_size, num_batches=num_batches_per_epoch, adapter=self.adapter
+            simulator=self.simulator,
+            batch_size=batch_size,
+            num_batches=num_batches_per_epoch,
+            adapter=self.adapter,
+            augmentations=augmentations,
         )
 
         return self._fit(
@@ -771,6 +797,7 @@ def fit_disk(
         epochs: int = 100,
         keep_optimizer: bool = False,
         validation_data: Mapping[str, np.ndarray] | int = None,
+        augmentations: Mapping[str, Callable] | Callable = None,
         **kwargs,
     ) -> keras.callbacks.History:
         """
@@ -798,6 +825,16 @@
             A dictionary containing validation data. If an integer is provided, that number of validation samples
             will be generated (if supported).
             By default, no validation data is used.
+        augmentations : dict of str to Callable or Callable, optional
+            Dictionary of augmentation functions to apply to each corresponding key in the batch
+            or a function to apply to the entire batch (possibly adding new keys).
+
+            If you provide a dictionary of functions, each function should accept one element
+            of your output batch and return the corresponding transformed element. Otherwise,
+            your function should accept the entire dictionary output and return a dictionary.
+
+            Note - augmentations are applied before the adapter is called and are generally
+            transforms that you only want to apply during training.
         **kwargs : dict, optional
             Additional keyword arguments passed to the underlying `_fit` method.
 
         Returns
         -------
         history : keras.callbacks.History
             A history object containing the training history, with keys such as training losses and
             metric evolution over epochs.
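+
+        Examples
+        --------
+        A minimal sketch on a hypothetical `workflow` instance; the directory, key name,
+        and noise scale are illustrative::
+
+            import numpy as np
+
+            history = workflow.fit_disk(
+                root="simulations/",
+                pattern="*.pkl",
+                batch_size=32,
+                epochs=50,
+                augmentations={"x": lambda x: x + np.random.normal(0.0, 0.1, size=x.shape)},
+            )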
""" - dataset = DiskDataset(root=root, pattern=pattern, batch_size=batch_size, load_fn=load_fn, adapter=self.adapter) + dataset = DiskDataset( + root=root, + pattern=pattern, + batch_size=batch_size, + load_fn=load_fn, + adapter=self.adapter, + augmentations=augmentations, + ) return self._fit( dataset, epochs, strategy="online", keep_optimizer=keep_optimizer, validation_data=validation_data, **kwargs