Merged
Changes from 9 commits
38 changes: 38 additions & 0 deletions bayesflow/adapters/adapter.py
@@ -18,6 +18,7 @@
Keep,
Log,
MapTransform,
NNPE,
NumpyTransform,
OneHot,
Rename,
@@ -699,6 +700,43 @@ def map_dtype(self, keys: str | Sequence[str], to_dtype: str):
self.transforms.append(transform)
return self

def nnpe(
self,
keys: str | Sequence[str],
*,
spike_scale: float | np.ndarray | None = None,
slab_scale: float | np.ndarray | None = None,
per_dimension: bool = True,
seed: int | None = None,
):
"""Append an :py:class:`~transforms.NNPE` transform to the adapter.

Parameters
----------
keys : str or Sequence of str
The names of the variables to transform.
spike_scale : float or np.ndarray or None, default=None
The scale of the spike (Normal) distribution. Automatically determined if None.
slab_scale : float or np.ndarray or None, default=None
The scale of the slab (Cauchy) distribution. Automatically determined if None.
per_dimension : bool, default=True
If True, noise is applied per dimension of the last axis of the input data.
If False, noise is applied globally.
seed : int or None
The seed for the random number generator. If None, a random seed is used.
"""
if isinstance(keys, str):
keys = [keys]

transform = MapTransform(
{
key: NNPE(spike_scale=spike_scale, slab_scale=slab_scale, per_dimension=per_dimension, seed=seed)
for key in keys
}
)
self.transforms.append(transform)
return self

def one_hot(self, keys: str | Sequence[str], num_classes: int):
"""Append a :py:class:`~transforms.OneHot` transform to the adapter.

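For context, a minimal usage sketch of the new method on hypothetical data (names are illustrative; noise is only added when stage="training", per the NNPE transform below):

import numpy as np
import bayesflow as bf

adapter = bf.Adapter().nnpe("x", seed=42)  # scales determined automatically
batch = {"x": np.random.standard_normal((64, 10))}
noisy = adapter(batch, stage="training")   # spike-and-slab noise added
clean = adapter(batch, stage="inference")  # returned unchanged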
1 change: 1 addition & 0 deletions bayesflow/adapters/transforms/__init__.py
@@ -12,6 +12,7 @@
from .keep import Keep
from .log import Log
from .map_transform import MapTransform
from .nnpe import NNPE
from .numpy_transform import NumpyTransform
from .one_hot import OneHot
from .rename import Rename
185 changes: 185 additions & 0 deletions bayesflow/adapters/transforms/nnpe.py
@@ -0,0 +1,185 @@
import numpy as np

from bayesflow.utils.serialization import serializable, serialize

from .elementwise_transform import ElementwiseTransform


@serializable("bayesflow.adapters")
class NNPE(ElementwiseTransform):
"""Implements noisy neural posterior estimation (NNPE) as described in [1], which adds noise following a
spike-and-slab distribution to the training data as a mild form of data augmentation to robustify against noisy
real-world data (see [1, 2] for benchmarks). Adds the options of automatic noise scale determination and
dimensionwise noise application to the original implementation in [1] to provide more flexibility in dealing with
unstandardized and heterogeneous data.

[1] Ward, D., Cannon, P., Beaumont, M., Fasiolo, M., & Schmon, S. (2022). Robust neural posterior estimation and
statistical model criticism. Advances in Neural Information Processing Systems, 35, 33845-33859.
[2] Elsemüller, L., Pratz, V., von Krause, M., Voss, A., Bürkner, P. C., & Radev, S. T. (2025). Does Unsupervised
Domain Adaptation Improve the Robustness of Amortized Bayesian Inference? A Systematic Evaluation. arXiv preprint
arXiv:2502.04949.

Parameters
----------
spike_scale : float or np.ndarray or None, default=None
The scale of the spike (Normal) distribution. Automatically determined if None (see “Notes” section).
Expects a float if `per_dimension=False` or a 1D array of length `data.shape[-1]` if `per_dimension=True`.
slab_scale : float or np.ndarray or None, default=None
The scale of the slab (Cauchy) distribution. Automatically determined if None (see “Notes” section).
Expects a float if `per_dimension=False` or a 1D array of length `data.shape[-1]` if `per_dimension=True`.
per_dimension : bool, default=True
If True, noise is applied per dimension of the last axis of the input data. If False, noise is applied globally.
Thus, if per_dimension=True, any provided scales must be arrays with shape (n_dimensions,) and automatic
scale determination occurs separately per dimension. If per_dimension=False, provided scales must be floats and
automatic scale determination occurs globally. The original implementation in [1] uses global application
(i.e., per_dimension=False), whereas dimensionwise application is recommended if the data dimensions are heterogeneous.
seed : int or None
The seed for the random number generator. If None, a random seed is used. An integer seed is stored
instead of a np.random.Generator instance to keep the transform easily serializable.

Notes
-----
The spike-and-slab distribution consists of a mixture of a Normal distribution (spike) and Cauchy distribution
(slab), which are applied based on a Bernoulli random variable with p=0.5.

The scales of the spike and slab distributions can be set manually, or they are automatically determined by scaling
the default scales of [1] (which expect standardized data) by the standard deviation of the input data.
For automatic determination, the standard deviation is determined either globally (if `per_dimension=False`) or per
dimension of the last axis of the input data (if `per_dimension=True`). Note that automatic scale determination is
applied batch-wise in the forward method, which means that determined scales can vary between batches due to varying
standard deviations in the batch input data.

The original implementation in [1] can be recovered by applying the following settings on standardized data:
- `spike_scale=0.01`
- `slab_scale=0.25`
- `per_dimension=False`

Examples
--------
>>> adapter = bf.Adapter().nnpe(["x"])
"""

DEFAULT_SPIKE = 0.01
DEFAULT_SLAB = 0.25

def __init__(
self,
*,
spike_scale: float | np.ndarray | None = None,
slab_scale: float | np.ndarray | None = None,
per_dimension: bool = True,
seed: int | None = None,
):
super().__init__()
self.spike_scale = spike_scale
self.slab_scale = slab_scale
self.per_dimension = per_dimension
self.seed = seed
self.rng = np.random.default_rng(seed)

def _resolve_scale(
self,
name: str,
passed: float | np.ndarray | None,
default: float,
data: np.ndarray,
) -> np.ndarray | float:
"""
Determine the spike/slab scale:
- If `passed` is None: automatic determination via default * std(data) (per-dimension or global).
- Else: validate and cast `passed` to the correct shape/type.

Parameters
----------
name : str
Identifier for error messages (e.g., 'spike_scale' or 'slab_scale').
passed : float or np.ndarray or None
User-specified scale. If None, compute as default * std(data).
If self.per_dimension is True, this may be a 1D array of length data.shape[-1].
default : float
Default multiplier from [1] to apply to the standard deviation of the data.
data : np.ndarray
Data array to compute standard deviation from.

Returns
-------
float or np.ndarray
The resolved scale, either as a scalar (if per_dimension=False) or a 1D array of length data.shape[-1]
(if per_dimension=True).
"""

# Compute the std (and the expected scale shape) dimensionwise or globally
if self.per_dimension:
axes = tuple(range(data.ndim - 1))
std = np.std(data, axis=axes)
expected_shape = (data.shape[-1],)
else:
std = np.std(data)
expected_shape = None

# If no scale is passed, determine scale automatically given the dimensionwise or global std
if passed is None:
return default * std

# If a scale is passed, check if the passed shape matches the expected shape
else:
if self.per_dimension:
arr = np.asarray(passed, dtype=float)
if arr.shape != expected_shape or arr.ndim != 1:
raise ValueError(f"{name}: expected array of shape {expected_shape}, got {arr.shape}")
return arr

else:
try:
scalar = float(passed)
except Exception:
Collaborator comment: Please use more specific exceptions here, I think ValueError and TypeError are the relevant ones.
raise TypeError(f"{name}: expected scalar float, got {type(passed).__name__}")
return scalar
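# A narrowing along the lines of the review comment above (a sketch, not part
# of this diff): float() raises TypeError for non-numeric types and ValueError
# for unparsable strings, so catching exactly those two would suffice:
#   try:
#       scalar = float(passed)
#   except (TypeError, ValueError):
#       raise TypeError(f"{name}: expected scalar float, got {type(passed).__name__}")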

def forward(self, data: np.ndarray, stage: str = "inference", **kwargs) -> np.ndarray:
"""
Add spike-and-slab noise to `data` during training, using automatic scale determination if not provided (see
“Notes” section of the class docstring for details).

Parameters
----------
data : np.ndarray
Input array to be perturbed.
stage : str, default='inference'
If 'training', noise is added; else data is returned unchanged.
**kwargs
Unused keyword arguments.

Returns
-------
np.ndarray
Noisy data when `stage` is 'training', otherwise the original input.
"""
if stage != "training":
return data

# Check data validity
if not np.all(np.isfinite(data)):
raise ValueError("NNPE.forward: `data` contains NaN or infinite values.")

spike_scale = self._resolve_scale("spike_scale", self.spike_scale, self.DEFAULT_SPIKE, data)
slab_scale = self._resolve_scale("slab_scale", self.slab_scale, self.DEFAULT_SLAB, data)

# Apply spike-and-slab noise
mixture_mask = self.rng.binomial(n=1, p=0.5, size=data.shape).astype(bool)
noise_spike = self.rng.standard_normal(size=data.shape) * spike_scale
noise_slab = self.rng.standard_cauchy(size=data.shape) * slab_scale
noise = np.where(mixture_mask, noise_slab, noise_spike)
return data + noise

def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
"""Non-invertible transform."""
return data

def get_config(self) -> dict:
return serialize(
{
"spike_scale": self.spike_scale,
"slab_scale": self.slab_scale,
"per_dimension": self.per_dimension,
"seed": self.seed,
}
)
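For reference, the spike-and-slab mechanism implemented in forward() can be reproduced standalone. A minimal NumPy sketch, assuming standardized data and the default scales from [1] (with automatic determination, these defaults would instead be multiplied by the batch standard deviation):

import numpy as np

rng = np.random.default_rng(42)
data = rng.standard_normal((32, 16, 2))  # standardized toy data
spike_scale, slab_scale = 0.01, 0.25     # defaults from [1]

# Bernoulli(p=0.5) mask: True selects the slab (Cauchy), False the spike (Normal)
mask = rng.binomial(n=1, p=0.5, size=data.shape).astype(bool)
noise = np.where(
    mask,
    rng.standard_cauchy(size=data.shape) * slab_scale,
    rng.standard_normal(size=data.shape) * spike_scale,
)
noisy = data + noise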
80 changes: 80 additions & 0 deletions tests/test_adapters/test_adapters.py
@@ -296,3 +296,83 @@ def test_log_det_jac_exceptions(random_data):

# inverse works when concatenation is used after transforms
assert np.allclose(forward_log_det_jac["p"], -inverse_log_det_jac)


def test_nnpe(random_data):
# NNPE cannot be integrated into the adapter fixture and its tests since it modifies the input data
# and therefore breaks existing allclose checks
import numpy as np
from bayesflow.adapters import Adapter

ad = Adapter().nnpe("x1", spike_scale=1.0, slab_scale=1.0, seed=42)
result_training = ad(random_data, stage="training")
result_validation = ad(random_data, stage="validation")
result_inference = ad(random_data, stage="inference")
result_inversed = ad(random_data, inverse=True)
serialized = serialize(ad)
deserialized = deserialize(serialized)
reserialized = serialize(deserialized)

assert keras.tree.lists_to_tuples(serialized) == keras.tree.lists_to_tuples(reserialized)

# check that only x1 is changed
assert "x1" in result_training
assert not np.allclose(result_training["x1"], random_data["x1"])

# all other keys are untouched
for k, v in random_data.items():
if k == "x1":
continue
assert np.allclose(result_training[k], v)

# check that the validation and inference data as well as the inverse results are unchanged
for k, v in random_data.items():
assert np.allclose(result_validation[k], v)
assert np.allclose(result_inference[k], v)
assert np.allclose(result_inversed[k], v)

# Test the case where at least one scale is None (automatic scale determination)
ad_partial = Adapter().nnpe("x2", slab_scale=None, spike_scale=1.0, seed=42)
result_training_partial = ad_partial(random_data, stage="training")
assert not np.allclose(result_training_partial["x2"], random_data["x2"])

# Test the case where both scales and the seed are None (automatic scale determination)
ad_auto = Adapter().nnpe("y1", slab_scale=None, spike_scale=None, seed=None)
result_training_auto = ad_auto(random_data, stage="training")
assert not np.allclose(result_training_auto["y1"], random_data["y1"])
for k, v in random_data.items():
if k == "y1":
continue
assert np.allclose(result_training_auto[k], v)

serialized_auto = serialize(ad_auto)
deserialized_auto = deserialize(serialized_auto)
reserialized_auto = serialize(deserialized_auto)
assert keras.tree.lists_to_tuples(serialized_auto) == keras.tree.lists_to_tuples(reserialized_auto)

# Test dimensionwise versus global noise application (per_dimension=True vs per_dimension=False)
# Create data with second dimension having higher variance
data_shape = (32, 16, 1)
rng = np.random.default_rng(42)
zero = np.ones(shape=data_shape)  # constant values: zero variance in the first dimension
high = rng.normal(0, 100.0, size=data_shape)  # high variance in the second dimension
var_data = {"x": np.concatenate([zero, high], axis=-1)}

# Apply dimensionwise and global adapters with automatic slab scale determination
ad_partial_global = Adapter().nnpe("x", spike_scale=0, slab_scale=None, per_dimension=False, seed=42)
ad_partial_dim = Adapter().nnpe("x", spike_scale=[0, 1], slab_scale=None, per_dimension=True, seed=42)
res_dim = ad_partial_dim(var_data, stage="training")
res_glob = ad_partial_global(var_data, stage="training")

# Compute standard deviations of noise per last axis dimension
noise_dim = res_dim["x"] - var_data["x"]
noise_glob = res_glob["x"] - var_data["x"]
std_dim = np.std(noise_dim, axis=(0, 1))
std_glob = np.std(noise_glob, axis=(0, 1))

# Dimensionwise application should assign zero noise to the zero-variance dimension, global application some noise
assert std_dim[0] == 0
assert std_glob[0] > 0
# Both should assign noise to high-variance dimension
assert std_dim[1] > 0
assert std_glob[1] > 0
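
As a sanity check on the assertions above, the automatic rule in _resolve_scale implies per-dimension slab scales of DEFAULT_SLAB * std per last-axis dimension. A quick sketch (assuming the same var_data construction as in the test):

import numpy as np

rng = np.random.default_rng(42)
zero = np.ones((32, 16, 1))
high = rng.normal(0, 100.0, size=(32, 16, 1))
x = np.concatenate([zero, high], axis=-1)

slab_scales = 0.25 * np.std(x, axis=(0, 1))
# slab_scales[0] == 0.0 (constant dimension), slab_scales[1] is roughly 25, so
# under per_dimension=True the zero-variance dimension receives no slab noise.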