Merged
Changes from 8 commits
29 changes: 29 additions & 0 deletions bayesflow/adapters/adapter.py
@@ -18,6 +18,7 @@
Keep,
Log,
MapTransform,
NNPE,
NumpyTransform,
OneHot,
Rename,
@@ -699,6 +700,34 @@ def map_dtype(self, keys: str | Sequence[str], to_dtype: str):
self.transforms.append(transform)
return self

def nnpe(
self,
keys: str | Sequence[str],
*,
spike_scale: float | None = None,
slab_scale: float | None = None,
seed: int | None = None,
):
"""Append an :py:class:`~transforms.NNPE` transform to the adapter.

Parameters
----------
keys : str or Sequence of str
The names of the variables to transform.
spike_scale : float or None
The scale of the spike (Normal) distribution. Automatically determined if None.
slab_scale : float or None
The scale of the slab (Cauchy) distribution. Automatically determined if None.
seed : int or None
The seed for the random number generator. If None, a random seed is used.
"""
if isinstance(keys, str):
keys = [keys]

transform = MapTransform({key: NNPE(spike_scale=spike_scale, slab_scale=slab_scale, seed=seed) for key in keys})
self.transforms.append(transform)
return self

def one_hot(self, keys: str | Sequence[str], num_classes: int):
"""Append a :py:class:`~transforms.OneHot` transform to the adapter.

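For orientation, here is a minimal usage sketch of the new adapter method (the batch contents, shapes, and key names are illustrative assumptions, not taken from this PR):

import numpy as np
import bayesflow as bf

# Illustrative training batch: only "x" is perturbed, "theta" is left untouched.
batch = {"x": np.random.normal(size=(32, 10)), "theta": np.random.normal(size=(32, 2))}

adapter = bf.Adapter().nnpe("x", spike_scale=0.01, slab_scale=0.25, seed=42)

noisy = adapter(batch, stage="training")    # spike-and-slab noise added to "x"
clean = adapter(batch, stage="inference")   # data passes through unchanged outside training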
1 change: 1 addition & 0 deletions bayesflow/adapters/transforms/__init__.py
@@ -12,6 +12,7 @@
from .keep import Keep
from .log import Log
from .map_transform import MapTransform
from .nnpe import NNPE
from .numpy_transform import NumpyTransform
from .one_hot import OneHot
from .rename import Rename
97 changes: 97 additions & 0 deletions bayesflow/adapters/transforms/nnpe.py
@@ -0,0 +1,97 @@
import numpy as np

from bayesflow.utils.serialization import serializable, serialize

from .elementwise_transform import ElementwiseTransform


@serializable("bayesflow.adapters")
class NNPE(ElementwiseTransform):
"""Implements noisy neural posterior estimation (NNPE) as described in [1], which adds noise following a
spike-and-slab distribution to the training data as a mild form of data augmentation to robustify against noisy
real-world data (see [1, 2] for benchmarks).

[1] Ward, D., Cannon, P., Beaumont, M., Fasiolo, M., & Schmon, S. (2022). Robust neural posterior estimation and
statistical model criticism. Advances in Neural Information Processing Systems, 35, 33845-33859.
[2] Elsemüller, L., Pratz, V., von Krause, M., Voss, A., Bürkner, P. C., & Radev, S. T. (2025). Does Unsupervised
Domain Adaptation Improve the Robustness of Amortized Bayesian Inference? A Systematic Evaluation. arXiv preprint
arXiv:2502.04949.

Parameters
----------
spike_scale : float or None
The scale of the spike (Normal) distribution. Automatically determined if None (see “Notes” section).
slab_scale : float or None
The scale of the slab (Cauchy) distribution. Automatically determined if None (see “Notes” section).
seed : int or None
The seed for the random number generator. If None, a random seed is used. Used instead of np.random.Generator
here to enable easy serialization.

Notes
-----
The spike-and-slab distribution consists of a mixture of a Normal distribution (spike) and Cauchy distribution
(slab), which are applied based on a Bernoulli random variable with p=0.5.

The scales of the spike and slab distributions can be set manually, or they are automatically determined by scaling
the default scales of [1] (which expect standardized data) by the standard deviation of the input data.

Examples
--------
>>> adapter = bf.Adapter().nnpe(["x"])
"""

DEFAULT_SLAB = 0.25
DEFAULT_SPIKE = 0.01

def __init__(self, *, spike_scale: float | None = None, slab_scale: float | None = None, seed: int | None = None):
super().__init__()
self.spike_scale = spike_scale
self.slab_scale = slab_scale
self.seed = seed
self.rng = np.random.default_rng(seed)

def forward(self, data: np.ndarray, stage: str = "inference", **kwargs) -> np.ndarray:
"""
Add spike-and-slab noise to `data` during training, using automatic scale determination if not provided (see
“Notes” section of the class docstring for details).

Parameters
----------
data : np.ndarray
Input array to be perturbed.
stage : str, default='inference'
If 'training', noise is added; else data is returned unchanged.
**kwargs
Unused keyword arguments.

Returns
-------
np.ndarray
Noisy data when `stage` is 'training', otherwise the original input.
"""
if stage != "training":
return data

# Check data validity
if not np.all(np.isfinite(data)):
raise ValueError("NNPE.forward: `data` contains NaN or infinite values.")

[Codecov / codecov/patch warning on bayesflow/adapters/transforms/nnpe.py#L77: added line was not covered by tests.]

# Automatically determine scales if not provided
if self.spike_scale is None or self.slab_scale is None:
data_std = np.std(data)
Collaborator: This will lead to different scales for each batch; I'm not sure this is desirable. If we choose to do this, we should state it more explicitly in the docstring.

Contributor: There is an alternative solution:

  • Standardize data with batch_mean and batch_std
  • Add unscaled spikes and slabs
  • Re-scale back with batch_mean and batch_std
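For clarity, a rough sketch of that alternative (the helper name, the choice of p, and the lack of a zero-std guard are assumptions; the default scales are the ones defined on the class above):

import numpy as np

def nnpe_standardized(data, rng, p=0.5, spike_scale=0.01, slab_scale=0.25):
    # Standardize with batch statistics (a zero batch_std would need a guard in practice).
    batch_mean, batch_std = np.mean(data), np.std(data)
    z = (data - batch_mean) / batch_std
    # Add unscaled spike-and-slab noise in standardized space.
    mask = rng.binomial(n=1, p=p, size=data.shape).astype(bool)
    noise = np.where(
        mask,
        rng.standard_cauchy(size=data.shape) * slab_scale,   # slab
        rng.standard_normal(size=data.shape) * spike_scale,  # spike
    )
    # Re-scale back to the original units.
    return (z + noise) * batch_std + batch_mean

Note that, since the mean shift cancels, this reduces to data + noise * batch_std, i.e. the same per-batch scale dependence as the automatic-scale branch in the current implementation.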

Member Author: Thanks for checking! Yes, this is the drawback of that solution. I still think it is preferable to more complex solutions, since the method is more about adding some noise at all rather than adding a very specific amount of noise (the default scales by Ward et al. also seem quite heuristically chosen to me). If you agree, I can add some more info to the docstring.

Collaborator: Would we want to do the automatic scaling globally or per dimension? I think this would be the main difference between what @stefanradev93 proposed and how it is implemented now, right?

Contributor: The way Lasse explained it to me, the approach explicitly wants scale(original) < scale(transformed). In that case, I think fluctuations between batches are fine, as the downstream Standardize layer (which will be part of the approximators) will take care of that.

Collaborator (@vpratz, May 26, 2025): Then the question still is how we want to automatically determine the scale, globally or per dimension? If dimensions don't have equal magnitude, we might accidentally erase some of them completely. On the other hand, some dimensions might have zero variation (e.g. in image datasets like MNIST), so we would have to decide how to deal with those...

Contributor: Good question. I would scale dimensionwise.

Member Author: I implemented it globally following the original NNPE implementation, but I agree that dimensionwise scaling would be valuable in many situations and will implement it as an option. I think dimensions with zero variation are not problematic, since in that case nothing breaks; there will simply be no noise added. Dimensionwise instead of global scaling will increase the variability of the std calculation between batches, though.

Collaborator: Thanks! Just to make sure: please set dimensionwise as the default and make global scaling the option.
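To make the dimensionwise option concrete, a sketch of what the scale determination could look like (the function name and the per_dimension flag are hypothetical; they are not necessarily what the PR ends up implementing):

import numpy as np

def determine_scale(data, default, user_scale=None, per_dimension=True):
    # Use the user-provided scale if given, otherwise derive it from the batch.
    if user_scale is not None:
        return user_scale
    if per_dimension:
        # One std per trailing data dimension; zero-variation dimensions simply receive no noise.
        data_std = np.std(data, axis=tuple(range(data.ndim - 1)), keepdims=True)
    else:
        data_std = np.std(data)
    return default * data_std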

spike_scale = self.spike_scale if self.spike_scale is not None else self.DEFAULT_SPIKE * data_std
slab_scale = self.slab_scale if self.slab_scale is not None else self.DEFAULT_SLAB * data_std

# Apply spike-and-slab noise
mixture_mask = self.rng.binomial(n=1, p=0.5, size=data.shape).astype(bool)
noise_spike = self.rng.standard_normal(size=data.shape) * spike_scale
noise_slab = self.rng.standard_cauchy(size=data.shape) * slab_scale
noise = np.where(mixture_mask, noise_slab, noise_spike)
return data + noise

def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
"""Non-invertible transform."""
return data

def get_config(self) -> dict:
return serialize({"spike_scale": self.spike_scale, "slab_scale": self.slab_scale, "seed": self.seed})
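For reference, the transform can also be exercised directly, outside an Adapter; a short sketch based on the class as added in this diff:

import numpy as np
from bayesflow.adapters.transforms import NNPE

x = np.random.normal(size=(64, 3))
t = NNPE(spike_scale=None, slab_scale=None, seed=0)

noisy = t.forward(x, stage="training")    # noise added, scales derived from np.std(x)
same = t.forward(x, stage="inference")    # non-training stages pass data through
assert np.allclose(same, x)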
53 changes: 53 additions & 0 deletions tests/test_adapters/test_adapters.py
@@ -296,3 +296,56 @@ def test_log_det_jac_exceptions(random_data):

# inverse works when concatenation is used after transforms
assert np.allclose(forward_log_det_jac["p"], -inverse_log_det_jac)


def test_nnpe(random_data):
# NNPE cannot be integrated into the adapter fixture and its tests since it modifies the input data
# and therefore breaks existing allclose checks
import numpy as np
from bayesflow.adapters import Adapter

ad = Adapter().nnpe("x1", spike_scale=1.0, slab_scale=1.0, seed=42)
result_training = ad(random_data, stage="training")
result_validation = ad(random_data, stage="validation")
result_inference = ad(random_data, stage="inference")
result_inversed = ad(random_data, inverse=True)
serialized = serialize(ad)
deserialized = deserialize(serialized)
reserialized = serialize(deserialized)

assert keras.tree.lists_to_tuples(serialized) == keras.tree.lists_to_tuples(reserialized)

# check that only x1 is changed
assert "x1" in result_training
assert not np.allclose(result_training["x1"], random_data["x1"])

# all other keys are untouched
for k, v in random_data.items():
if k == "x1":
continue
assert np.allclose(result_training[k], v)

# check that the validation and inference data as well as inversed results are unchanged
for k, v in random_data.items():
assert np.allclose(result_validation[k], v)
assert np.allclose(result_inference[k], v)
assert np.allclose(result_inversed[k], v)

# Test at least one scale is None case (automatic scale determination)
ad_partial = Adapter().nnpe("x2", slab_scale=None, spike_scale=1.0, seed=42)
result_training_partial = ad_partial(random_data, stage="training")
assert not np.allclose(result_training_partial["x2"], random_data["x2"])

# Test both scales and seed are None case (automatic scale determination)
ad_auto = Adapter().nnpe("y1", slab_scale=None, spike_scale=None, seed=None)
result_training_auto = ad_auto(random_data, stage="training")
assert not np.allclose(result_training_auto["y1"], random_data["y1"])
for k, v in random_data.items():
if k == "y1":
continue
assert np.allclose(result_training_auto[k], v)

serialized_auto = serialize(ad_auto)
deserialized_auto = deserialize(serialized_auto)
reserialized_auto = serialize(deserialized_auto)
assert keras.tree.lists_to_tuples(serialized_auto) == keras.tree.lists_to_tuples(reserialized_auto)