Merged
1 change: 0 additions & 1 deletion bayesflow/datasets/__init__.py
@@ -7,7 +7,6 @@
from .offline_dataset import OfflineDataset
from .online_dataset import OnlineDataset
from .disk_dataset import DiskDataset
-from .rounds_dataset import RoundsDataset

from ..utils._docs import _add_imports_to_all

58 changes: 52 additions & 6 deletions bayesflow/datasets/disk_dataset.py
@@ -1,8 +1,12 @@
-import keras
-import numpy as np
+from collections.abc import Mapping, Callable

import os
import pathlib as pl

+import numpy as np

+import keras

from bayesflow.adapters import Adapter
from bayesflow.utils import tree_stack, pickle_load

@@ -29,11 +33,43 @@
*,
pattern: str = "*.pkl",
batch_size: int,
-load_fn: callable = None,
+load_fn: Callable = None,
adapter: Adapter | None,
stage: str = "training",
+augmentations: Mapping[str, Callable] | Callable = None,
**kwargs,
):
"""
Initialize a DiskDataset instance for offline training using a set of simulations that
do not fit on disk.
Parameters
----------
root : os.PathLike
Root directory containing the sample files.
pattern : str, default="*.pkl"
Glob pattern to match sample files.
batch_size : int
Number of samples per batch.
load_fn : Callable, optional
Function to load a single file into a sample. Defaults to `pickle_load`.
adapter : Adapter or None
Optional adapter to transform the loaded batch.
stage : str, default="training"
Current stage (e.g., "training", "validation", etc.) used by the adapter.
augmentations : dict of str to Callable or Callable, optional
Dictionary of augmentation functions to apply to each corresponding key in the batch
or a function to apply to the entire batch (possibly adding new keys).
If you provide a dictionary of functions, each function should accept one element
of your output batch and return the corresponding transformed element. Otherwise,
your function should accept the entire dictionary output and return a dictionary.
Note - augmentations are applied before the adapter is called and are generally
transforms that you only want to apply during training.
**kwargs
Additional keyword arguments passed to the base `PyDataset`.
"""
super().__init__(**kwargs)
self.batch_size = batch_size
self.root = pl.Path(root)
@@ -42,6 +78,8 @@
self.files = list(map(str, self.root.glob(pattern)))
self.stage = stage

+self.augmentations = augmentations

self.shuffle()

def __getitem__(self, item) -> dict[str, np.ndarray]:
@@ -50,12 +88,20 @@

files = self.files[item * self.batch_size : (item + 1) * self.batch_size]

-batch = []
-for file in files:
-    batch.append(self.load_fn(file))
+batch = [self.load_fn(file) for file in files]


batch = tree_stack(batch)

if self.augmentations is None:
    pass
elif isinstance(self.augmentations, Mapping):
    for key, fn in self.augmentations.items():
        batch[key] = fn(batch[key])
Contributor: This only works for shallow batch trees. However, I think that is fine, since this is the only use-case so far.

elif isinstance(self.augmentations, Callable):
    batch = self.augmentations(batch)

else:
    raise RuntimeError(f"Could not apply augmentations of type {type(self.augmentations)}.")


if self.adapter is not None:
    batch = self.adapter(batch, stage=self.stage)
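As the reviewer notes above, the per-key dispatch only works for shallow batch trees, because each `fn` receives whatever sits at `batch[key]`, nested or not. A hypothetical deep-tree variant (not part of this PR) could map each augmentation over the array leaves of its subtree, e.g. via `keras.tree`:

```python
import keras

def apply_augmentations_deep(batch: dict, augmentations: dict) -> dict:
    # Hypothetical variant: apply fn to every array leaf under batch[key],
    # so nested structures such as batch["x"]["raw"] are also transformed.
    for key, fn in augmentations.items():
        batch[key] = keras.tree.map_structure(fn, batch[key])
    return batch
```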

62 changes: 60 additions & 2 deletions bayesflow/datasets/offline_dataset.py
@@ -1,4 +1,4 @@
-from collections.abc import Mapping
+from collections.abc import Mapping, Callable

import numpy as np

@@ -23,8 +23,37 @@
num_samples: int = None,
*,
stage: str = "training",
+augmentations: Mapping[str, Callable] | Callable = None,
**kwargs,
):
"""
Initialize an OfflineDataset instance for offline training with optional data augmentations.

Parameters
----------
data : Mapping[str, np.ndarray]
Pre-simulated data stored in a dictionary, where each key maps to a NumPy array.
batch_size : int
Number of samples per batch.
adapter : Adapter or None
Optional adapter to transform the batch.
num_samples : int, optional
Number of samples in the dataset. If None, it will be inferred from the data.
stage : str, default="training"
Current stage (e.g., "training", "validation", etc.) used by the adapter.
augmentations : dict of str to Callable or Callable, optional
Dictionary of augmentation functions to apply to each corresponding key in the batch
or a function to apply to the entire batch (possibly adding new keys).

If you provide a dictionary of functions, each function should accept one element
of your output batch and return the corresponding transformed element. Otherwise,
your function should accept the entire dictionary output and return a dictionary.

Note - augmentations are applied before the adapter is called and are generally
transforms that you only want to apply during training.
**kwargs
Additional keyword arguments passed to the base `PyDataset`.
"""
super().__init__(**kwargs)
self.batch_size = batch_size
self.data = data
@@ -39,10 +68,29 @@

self.indices = np.arange(self.num_samples, dtype="int64")

+self.augmentations = augmentations

self.shuffle()

def __getitem__(self, item: int) -> dict[str, np.ndarray]:
"""Get a batch of pre-simulated data"""
"""
Load a batch of data from disk.

Parameters
----------
item : int
Index of the batch to retrieve.

Returns
-------
dict of str to np.ndarray
A batch of loaded (and optionally augmented/adapted) data.

Raises
------
IndexError
If the requested batch index is out of range.
"""
if not 0 <= item < self.num_batches:
    raise IndexError(f"Index {item} is out of bounds for dataset with {self.num_batches} batches.")

@@ -54,6 +102,16 @@
for key, value in self.data.items()
}

if self.augmentations is None:
    pass
elif isinstance(self.augmentations, Mapping):
    for key, fn in self.augmentations.items():
        batch[key] = fn(batch[key])
elif isinstance(self.augmentations, Callable):
    batch = self.augmentations(batch)

else:
    raise RuntimeError(f"Could not apply augmentations of type {type(self.augmentations)}.")


if self.adapter is not None:
    batch = self.adapter(batch, stage=self.stage)

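Similarly, a sketch of the whole-batch callable form for `OfflineDataset`; the keys `theta` and `x` are made up for illustration:

```python
import numpy as np

from bayesflow.datasets import OfflineDataset

rng = np.random.default_rng(0)
data = {"theta": rng.normal(size=(1000, 2)), "x": rng.normal(size=(1000, 50))}

def augment(batch: dict) -> dict:
    # Whole-batch augmentation: receives the full dict and may add new keys.
    batch = dict(batch)
    batch["x"] = batch["x"] + rng.normal(scale=0.1, size=batch["x"].shape)
    return batch

dataset = OfflineDataset(data=data, batch_size=64, adapter=None, augmentations=augment)
batch = dataset[0]  # augmentation runs before any adapter would
```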
57 changes: 56 additions & 1 deletion bayesflow/datasets/online_dataset.py
@@ -1,3 +1,5 @@
+from collections.abc import Mapping, Callable

import keras
import numpy as np

@@ -7,7 +9,7 @@

class OnlineDataset(keras.utils.PyDataset):
"""
-A dataset that is generated on-the-fly.
+A dataset that generates simulations on-the-fly.
"""

def __init__(
@@ -18,19 +20,72 @@
adapter: Adapter | None,
*,
stage: str = "training",
+augmentations: Mapping[str, Callable] | Callable = None,
**kwargs,
):
"""
Initialize an OnlineDataset instance for infinite stream training.

Parameters
----------
simulator : Simulator
A simulator object with a `.sample(batch_shape)` method to generate data.
batch_size : int
Number of samples per batch.
num_batches : int
Total number of batches in the dataset.
adapter : Adapter or None
Optional adapter to transform the simulated batch.
stage : str, default="training"
Current stage (e.g., "training", "validation", etc.) used by the adapter.
augmentations : dict of str to Callable or Callable, optional
Dictionary of augmentation functions to apply to each corresponding key in the batch
or a function to apply to the entire batch (possibly adding new keys).

If you provide a dictionary of functions, each function should accept one element
of your output batch and return the corresponding transformed element. Otherwise,
your function should accept the entire dictionary output and return a dictionary.

Note - augmentations are applied before the adapter is called and are generally
transforms that you only want to apply during training.
**kwargs
Additional keyword arguments passed to the base `PyDataset`.
"""
super().__init__(**kwargs)

self.batch_size = batch_size
self._num_batches = num_batches
self.adapter = adapter
self.simulator = simulator
self.stage = stage
+self.augmentations = augmentations

def __getitem__(self, item: int) -> dict[str, np.ndarray]:
"""
Generate one batch of data.

Parameters
----------
item : int
Index of the batch. Required by signature, but not used.

Returns
-------
dict of str to np.ndarray
A batch of simulated (and optionally augmented/adapted) data.
"""
batch = self.simulator.sample((self.batch_size,))

if self.augmentations is None:
    pass
elif isinstance(self.augmentations, Mapping):
    for key, fn in self.augmentations.items():
        batch[key] = fn(batch[key])
elif isinstance(self.augmentations, Callable):
    batch = self.augmentations(batch)

else:
    raise RuntimeError(f"Could not apply augmentations of type {type(self.augmentations)}.")


if self.adapter is not None:
    batch = self.adapter(batch, stage=self.stage)

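And the online case; `GaussianSimulator` below is a stand-in with the `.sample(batch_shape)` interface the docstring assumes, not a class shipped by the library:

```python
import numpy as np

from bayesflow.datasets import OnlineDataset

class GaussianSimulator:
    # Minimal stand-in: any object exposing .sample(batch_shape) -> dict works.
    def sample(self, batch_shape: tuple) -> dict:
        theta = np.random.normal(size=(*batch_shape, 2))
        x = theta.sum(axis=-1, keepdims=True) + np.random.normal(size=(*batch_shape, 10))
        return {"theta": theta, "x": x}

dataset = OnlineDataset(
    simulator=GaussianSimulator(),
    batch_size=32,
    num_batches=100,
    adapter=None,
    augmentations=lambda batch: {**batch, "x": 2.0 * batch["x"]},  # callable form
)

batch = dataset[0]  # fresh simulations each call; the index argument is ignored
```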
66 changes: 0 additions & 66 deletions bayesflow/datasets/rounds_dataset.py
Contributor: There have been use-cases for this dataset. Why should we remove it?

Contributor Author: There have never been use cases in amortized inference. It can be used for sequential inference, but if we want to do actual sequential inference, it needs to change dramatically anyway.

Collaborator: I can see use cases for this (for example, you want to use some "early stopping" while also wanting to avoid overfitting in case you have to train longer). Could we do a quick user survey in one of our channels to check if anyone is using it before we remove it without deprecating it properly?

Contributor Author: Feel free to ask around.
This file was deleted.

Loading