Merged

Changes from 6 commits

21 commits
e4e6da4
Merge branch 'dev' of https://github.com/bayesflow-org/bayesflow into…
stefanradev93 Jun 22, 2025
7b27f14
Merge branch 'dev' of https://github.com/bayesflow-org/bayesflow into…
stefanradev93 Jul 1, 2025
c9feff2
Merge branch 'dev' of https://github.com/bayesflow-org/bayesflow into…
stefanradev93 Jul 2, 2025
0ea79d7
Merge branch 'dev' of https://github.com/bayesflow-org/bayesflow into…
stefanradev93 Jul 11, 2025
766c57f
Remove stateful adapter features
stefanradev93 Jul 11, 2025
bfdf194
Fix tests
stefanradev93 Jul 11, 2025
85b884b
Fix typo
elseml Jul 11, 2025
82e7345
Remove nnpe from adapter
stefanradev93 Jul 12, 2025
8483fb8
Bring back notes [skip ci]
stefanradev93 Jul 12, 2025
b66a553
Remove unncessary restriction to kwargs only [skip ci]
stefanradev93 Jul 12, 2025
12d5ebc
Remove old super call [skip ci]
stefanradev93 Jul 12, 2025
6c4bcfc
Robustify type [skip ci]
stefanradev93 Jul 12, 2025
bd41d96
remove standardize from multimodal sim notebook [no ci]
vpratz Jul 13, 2025
afcae17
add draft module docstring to augmentations module [no ci]
vpratz Jul 13, 2025
e40624d
adapt and run neurocognitive modeling notebook [no ci]
vpratz Jul 13, 2025
f5fba59
adapt cCM playground notebook [no ci]
vpratz Jul 13, 2025
3e30aa5
adapt signature of Adapter.standardize
vpratz Jul 13, 2025
62b3a43
add parameters missed in previous commit
vpratz Jul 13, 2025
fe47bbf
Minor NNPE polishing
elseml Jul 14, 2025
9e37c44
remove stage in docstring from OnlineDataset
vpratz Jul 14, 2025
27171e3
Merge remote-tracking branch 'upstream/dev' into stateless-adapters
vpratz Jul 14, 2025
30 changes: 12 additions & 18 deletions bayesflow/adapters/adapter.py
@@ -1,4 +1,4 @@
- from collections.abc import Callable, MutableSequence, Sequence, Mapping
+ from collections.abc import Callable, MutableSequence, Sequence

import numpy as np

@@ -87,16 +87,14 @@ def get_config(self) -> dict:
return serialize(config)

def forward(
- self, data: dict[str, any], *, stage: str = "inference", log_det_jac: bool = False, **kwargs
+ self, data: dict[str, any], *, log_det_jac: bool = False, **kwargs
) -> dict[str, np.ndarray] | tuple[dict[str, np.ndarray], dict[str, np.ndarray]]:
"""Apply the transforms in the forward direction.

Parameters
----------
- data : dict
+ data : dict[str, any]
The data to be transformed.
- stage : str, one of ["training", "validation", "inference"]
- The stage the function is called in.
log_det_jac: bool, optional
Whether to return the log determinant of the Jacobian of the transforms.
**kwargs : dict
@@ -110,28 +108,26 @@ def forward(
data = data.copy()
if not log_det_jac:
for transform in self.transforms:
- data = transform(data, stage=stage, **kwargs)
+ data = transform(data, **kwargs)
return data

log_det_jac = {}
for transform in self.transforms:
- transformed_data = transform(data, stage=stage, **kwargs)
+ transformed_data = transform(data, **kwargs)
log_det_jac = transform.log_det_jac(data, log_det_jac, **kwargs)
data = transformed_data

return data, log_det_jac

def inverse(
- self, data: dict[str, np.ndarray], *, stage: str = "inference", log_det_jac: bool = False, **kwargs
+ self, data: dict[str, any], *, log_det_jac: bool = False, **kwargs
) -> dict[str, np.ndarray] | tuple[dict[str, np.ndarray], dict[str, np.ndarray]]:
"""Apply the transforms in the inverse direction.

Parameters
----------
- data : dict
+ data : dict[str, any]
The data to be transformed.
- stage : str, one of ["training", "validation", "inference"]
- The stage the function is called in.
log_det_jac: bool, optional
Whether to return the log determinant of the Jacobian of the transforms.
**kwargs : dict
@@ -145,18 +141,18 @@ def inverse(
data = data.copy()
if not log_det_jac:
for transform in reversed(self.transforms):
- data = transform(data, stage=stage, inverse=True, **kwargs)
+ data = transform(data, inverse=True, **kwargs)
return data

log_det_jac = {}
for transform in reversed(self.transforms):
- data = transform(data, stage=stage, inverse=True, **kwargs)
+ data = transform(data, inverse=True, **kwargs)
log_det_jac = transform.log_det_jac(data, log_det_jac, inverse=True, **kwargs)

return data, log_det_jac

def __call__(
- self, data: Mapping[str, any], *, inverse: bool = False, stage="inference", **kwargs
+ self, data: dict[str, any], *, inverse: bool = False, **kwargs
) -> dict[str, np.ndarray] | tuple[dict[str, np.ndarray], dict[str, np.ndarray]]:
"""Apply the transforms in the given direction.

@@ -166,8 +162,6 @@ def __call__(
The data to be transformed.
inverse : bool, optional
If False, apply the forward transform, else apply the inverse transform (default False).
- stage : str, one of ["training", "validation", "inference"]
- The stage the function is called in.
**kwargs
Additional keyword arguments passed to each transform.

@@ -177,9 +171,9 @@ def __call__(
The transformed data or tuple of transformed data and log determinant of the Jacobian.
"""
if inverse:
- return self.inverse(data, stage=stage, **kwargs)
+ return self.inverse(data, **kwargs)

- return self.forward(data, stage=stage, **kwargs)
+ return self.forward(data, **kwargs)

def __repr__(self):
result = ""
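With `stage` gone from `forward`, `inverse`, and `__call__`, the adapter call is stage-independent and behaves identically wherever it is invoked. A minimal sketch of the resulting call pattern, using the fixed-statistics `standardize` transform from later in this diff (the numbers are arbitrary):

    import numpy as np
    import bayesflow as bf

    # Stateless adapter: no stage argument, same behavior in training,
    # validation, and inference.
    adapter = bf.adapters.Adapter().standardize(include="beta", mean=5.0, std=10.0)

    data = {"beta": np.random.normal(loc=5.0, scale=10.0, size=(32, 2))}

    standardized = adapter(data)                     # forward direction
    recovered = adapter(standardized, inverse=True)  # inverse direction

    assert np.allclose(recovered["beta"], data["beta"])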
32 changes: 3 additions & 29 deletions bayesflow/adapters/transforms/nnpe.py
@@ -37,23 +37,6 @@ class NNPE(ElementwiseTransform):
The seed for the random number generator. If None, a random seed is used. Used instead of np.random.Generator
here to enable easy serialization.

- Notes
- -----
- The spike-and-slab distribution consists of a mixture of a Normal distribution (spike) and Cauchy distribution
- (slab), which are applied based on a Bernoulli random variable with p=0.5.
-
- The scales of the spike and slab distributions can be set manually, or they are automatically determined by scaling
- the default scales of [1] (which expect standardized data) by the standard deviation of the input data.
- For automatic determination, the standard deviation is determined either globally (if `per_dimension=False`) or per
- dimension of the last axis of the input data (if `per_dimension=True`). Note that automatic scale determination is
- applied batch-wise in the forward method, which means that determined scales can vary between batches due to varying
- standard deviations in the batch input data.
-
- The original implementation in [1] can be recovered by applying the following settings on standardized data:
- - `spike_scale=0.01`
- - `slab_scale=0.25`
- - `per_dimension=False`

Examples
--------
>>> adapter = bf.Adapter().nnpe(["x"])
@@ -136,27 +119,18 @@ def _resolve_scale(
raise ValueError(f"{name}: expected scalar, got array of shape {arr.shape}")
return arr

- def forward(self, data: np.ndarray, stage: str = "inference", **kwargs) -> np.ndarray:
+ def forward(self, data: np.ndarray, **kwargs) -> np.ndarray:
"""
- Add spike‐and‐slab noise to `data` during training, using automatic scale determination if not provided (see
- “Notes” section of the class docstring for details).
+ Add spike‐and‐slab noise to `data` using automatic scale determination if not provided.
+ See “Notes” section of the class docstring for details).

Parameters
----------
data : np.ndarray
Input array to be perturbed.
- stage : str, default='inference'
- If 'training', noise is added; else data is returned unchanged.
**kwargs
Unused keyword arguments.

Returns
-------
np.ndarray
- Noisy data when `stage` is 'training', otherwise the original input.
"""
if stage != "training":
return data

# Check data validity
if not np.all(np.isfinite(data)):
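Because the stage gate is removed, `NNPE.forward` now injects noise unconditionally; restricting it to training becomes the caller's job (e.g., via the dataset `augmentations` hook in `disk_dataset.py` below). A rough sketch, assuming `NNPE` stays importable from `bayesflow.adapters.transforms` and reusing the default scales listed in the removed Notes section:

    import numpy as np
    from bayesflow.adapters.transforms import NNPE

    # Scales for standardized data, per the removed Notes; `seed` makes the
    # injected noise reproducible across runs.
    nnpe = NNPE(spike_scale=0.01, slab_scale=0.25, per_dimension=False, seed=42)

    x = np.random.standard_normal((32, 10))
    noisy = nnpe.forward(x)  # noise is always applied now; no stage check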
98 changes: 13 additions & 85 deletions bayesflow/adapters/transforms/standardize.py
@@ -1,6 +1,3 @@
- from collections.abc import Sequence
- import warnings
-
import numpy as np

from bayesflow.utils.serialization import serializable, serialize
@@ -11,120 +8,51 @@
@serializable("bayesflow.adapters")
class Standardize(ElementwiseTransform):
"""
- Transform that when applied standardizes data using typical z-score standardization
- i.e. for some unstandardized data x the standardized version z would be
+ Transform that when applied standardizes data using typical z-score standardization with
+ fixed means and std, i.e. for some unstandardized data x the standardized version z would be

>>> z = (x - mean(x)) / std(x)

+ Important: Ensure dynamic standarization (employed by BayesFlow approximators) has been
+ turned off when using this transform.

Parameters
----------
- mean : int or float, optional
- Specify a mean if known but will be estimated from data when not provided
- std : int or float, optional
- Specify a standard devation if known but will be estimated from data when not provided
- axis : int, optional
- A specific axis along which standardization should take place. By default
- standardization happens individually for each dimension
- momentum : float in (0,1)
- The momentum during training
+ mean : int or float
+ Specifies the mean (location) of the transform.
+ std : int or float
+ Specifies the standard deviation (scale) of the transform.

Examples
--------
- 1) Standardize all variables using their individually estimated mean and stds.
-
- >>> adapter = (
-     bf.adapters.Adapter()
-     .standardize()
- )
-
-
- 2) Standardize all with same known mean and std.
-
- >>> adapter = (
-     bf.adapters.Adapter()
-     .standardize(mean = 5, sd = 10)
- )
-
-
- 3) Mix of fixed and estimated means/stds. Suppose we have priors for "beta" and "sigma" where we
- know the means and stds. However for all other variables, the means and stds are unknown.
- Then standardize should be used in several stages specifying which variables to include or exclude.
-
- >>> adapter = (
-     bf.adapters.Adapter()
-     # mean fixed, std estimated
-     .standardize(include = "beta", mean = 1)
-     # both mean and SD fixed
-     .standardize(include = "sigma", mean = 0.6, sd = 3)
-     # both means and stds estimated for all other variables
-     .standardize(exclude = ["beta", "sigma"])
- )
+ bf.adapters.Adapter().standardize(include="beta", mean=5, std=10)
"""

def __init__(
self,
- mean: int | float | np.ndarray = None,
- std: int | float | np.ndarray = None,
- axis: int | Sequence[int] = None,
- momentum: float | None = 0.99,
+ mean: int | float | np.ndarray,
+ std: int | float | np.ndarray,
):
super().__init__()

- if mean is None or std is None:
- warnings.warn(
- "Dynamic standardization is deprecated and will be removed in later versions."
- "Instead, use the standardize argument of the approximator / workflow instance or provide "
- "fixed mean and std arguments. You may incur some redundant computations if you keep this transform.",
- FutureWarning,
- )

self.mean = mean
self.std = std

- if isinstance(axis, Sequence):
- # numpy hates lists
- axis = tuple(axis)
- self.axis = axis
- self.momentum = momentum

def get_config(self) -> dict:
config = {
"mean": self.mean,
"std": self.std,
"axis": self.axis,
"momentum": self.momentum,
}
return serialize(config)

- def forward(self, data: np.ndarray, stage: str = "inference", **kwargs) -> np.ndarray:
- if self.axis is None:
- self.axis = tuple(range(data.ndim - 1))
-
- if self.mean is None:
- self.mean = np.mean(data, axis=self.axis, keepdims=True)
- else:
- if self.momentum is not None and stage == "training":
- self.mean = self.momentum * self.mean + (1.0 - self.momentum) * np.mean(
- data, axis=self.axis, keepdims=True
- )
-
- if self.std is None:
- self.std = np.std(data, axis=self.axis, keepdims=True, ddof=1)
- else:
- if self.momentum is not None and stage == "training":
- self.std = self.momentum * self.std + (1.0 - self.momentum) * np.std(
- data, axis=self.axis, keepdims=True, ddof=1
- )

+ def forward(self, data: np.ndarray, **kwargs) -> np.ndarray:
mean = np.broadcast_to(self.mean, data.shape)
std = np.broadcast_to(self.std, data.shape)

return (data - mean) / std

def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
- if self.mean is None or self.std is None:
- raise RuntimeError("Cannot call `inverse` before calling `forward` at least once.")

mean = np.broadcast_to(self.mean, data.shape)
std = np.broadcast_to(self.std, data.shape)

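Both statistics are now required up front; data-driven moments belong to the approximator (per the removed deprecation warning, via its `standardize` argument). A minimal migration sketch; deriving the statistics from pilot simulations is a hypothetical choice, not part of this diff:

    import numpy as np
    import bayesflow as bf

    # Estimate fixed statistics once, e.g. from pilot simulations,
    # instead of letting the transform track running moments.
    pilot_beta = np.random.standard_normal((1_000, 2))
    beta_mean = float(pilot_beta.mean())
    beta_std = float(pilot_beta.std(ddof=1))

    adapter = bf.adapters.Adapter().standardize(include="beta", mean=beta_mean, std=beta_std)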
4 changes: 2 additions & 2 deletions bayesflow/approximators/continuous_approximator.py
@@ -476,7 +476,7 @@ def _prepare_data(
Handles inputs containing only conditions, only inference_variables, or both.
Optionally tracks log-determinant Jacobian (ldj) of transformations.
"""
- adapted = self.adapter(data, strict=False, stage="inference", log_det_jac=log_det_jac, **kwargs)
+ adapted = self.adapter(data, strict=False, log_det_jac=log_det_jac, **kwargs)

if log_det_jac:
data, ldj = adapted
@@ -565,7 +565,7 @@ def summarize(self, data: Mapping[str, np.ndarray], **kwargs) -> np.ndarray:
if self.summary_network is None:
raise ValueError("A summary network is required to compute summaries.")

- data_adapted = self.adapter(data, strict=False, stage="inference", **kwargs)
+ data_adapted = self.adapter(data, strict=False, **kwargs)
if "summary_variables" not in data_adapted or data_adapted["summary_variables"] is None:
raise ValueError("Summary variables are required to compute summaries.")

4 changes: 2 additions & 2 deletions bayesflow/approximators/model_comparison_approximator.py
@@ -390,7 +390,7 @@ def predict(
probs = not logits

# Apply adapter transforms to raw simulated / real quantities
- conditions = self.adapter(conditions, strict=False, stage="inference", **kwargs)
+ conditions = self.adapter(conditions, strict=False, **kwargs)

# Ensure only keys relevant for sampling are present in the conditions dictionary
conditions = {k: v for k, v in conditions.items() if k in self.CONDITION_KEYS}
@@ -429,7 +429,7 @@ def summarize(self, data: Mapping[str, np.ndarray], **kwargs) -> np.ndarray:
if self.summary_network is None:
raise ValueError("A summary network is required to compute summaries.")

- data_adapted = self.adapter(data, strict=False, stage="inference", **kwargs)
+ data_adapted = self.adapter(data, strict=False, **kwargs)
if "summary_variables" not in data_adapted or data_adapted["summary_variables"] is None:
raise ValueError("Summary variables are required to compute summaries.")

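The `stage` keyword was only threaded through these internal adapter calls, so the public methods are invoked exactly as before. A fragment rather than a runnable script, assuming a trained approximator instance:

    # Assuming `approximator` is trained and `data` is a dict of raw arrays;
    # the adapter runs internally without any stage flag.
    summaries = approximator.summarize(data)

    # ModelComparisonApproximator only: posterior model probabilities
    # (pass logits=True for raw scores, per the `probs = not logits` line above).
    model_probs = approximator.predict(conditions=data)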
6 changes: 1 addition & 5 deletions bayesflow/datasets/disk_dataset.py
@@ -35,7 +35,6 @@
batch_size: int,
load_fn: Callable = None,
adapter: Adapter | None,
stage: str = "training",
augmentations: Callable | Mapping[str, Callable] | Sequence[Callable] = None,
shuffle: bool = True,
**kwargs,
@@ -56,8 +55,6 @@
Function to load a single file into a sample. Defaults to `pickle_load`.
adapter : Adapter or None
Optional adapter to transform the loaded batch.
stage : str, default="training"
Current stage (e.g., "training", "validation", etc.) used by the adapter.
augmentations : Callable or Mapping[str, Callable] or Sequence[Callable], optional
A single augmentation function, dictionary of augmentation functions, or sequence of augmentation functions
to apply to the batch.
@@ -80,7 +77,6 @@
self.load_fn = load_fn or pickle_load
self.adapter = adapter
self.files = list(map(str, self.root.glob(pattern)))
- self.stage = stage

self.augmentations = augmentations or []
self._shuffle = shuffle
@@ -111,7 +107,7 @@
raise RuntimeError(f"Could not apply augmentations of type {type(self.augmentations)}.")

if self.adapter is not None:
- batch = self.adapter(batch, stage=self.stage)
+ batch = self.adapter(batch)

Codecov / codecov/patch annotation on bayesflow/datasets/disk_dataset.py line 110: added line `batch = self.adapter(batch)` was not covered by tests.

return batch

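With the dataset no longer storing a stage, training-only perturbations move into the `augmentations` argument, which runs before the adapter. A sketch under assumptions: the `root`/`pattern` constructor arguments and the import paths mirror the current class (inferred from `self.root.glob(pattern)` above, not verified here), and a dict of augmentations is applied per data key:

    import bayesflow as bf
    from bayesflow.adapters.transforms import NNPE
    from bayesflow.datasets import DiskDataset

    # NNPE noise on "x" for every loaded batch; because the dataset is only
    # iterated during training, this replaces the old stage == "training" check.
    dataset = DiskDataset(
        root="data/simulations",  # hypothetical path
        pattern="*.pkl",
        batch_size=32,
        adapter=bf.adapters.Adapter(),
        augmentations={"x": NNPE(seed=42).forward},
    )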