Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions bayesflow/adapters/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -706,6 +706,7 @@ def nnpe(
*,
spike_scale: float | None = None,
slab_scale: float | None = None,
per_dimension: bool = True,
seed: int | None = None,
):
"""Append an :py:class:`~transforms.NNPE` transform to the adapter.
Expand All @@ -714,17 +715,25 @@ def nnpe(
----------
keys : str or Sequence of str
The names of the variables to transform.
spike_scale : float or None
spike_scale : float or np.ndarray or None, default=None
The scale of the spike (Normal) distribution. Automatically determined if None.
slab_scale : float or None
slab_scale : float or np.ndarray or None, default=None
The scale of the slab (Cauchy) distribution. Automatically determined if None.
per_dimension : bool, default=True
If true, noise is applied per dimension of the last axis of the input data.
If false, noise is applied globally.
seed : int or None
The seed for the random number generator. If None, a random seed is used.
"""
if isinstance(keys, str):
keys = [keys]

transform = MapTransform({key: NNPE(spike_scale=spike_scale, slab_scale=slab_scale, seed=seed) for key in keys})
transform = MapTransform(
{
key: NNPE(spike_scale=spike_scale, slab_scale=slab_scale, per_dimension=per_dimension, seed=seed)
for key in keys
}
)
self.transforms.append(transform)
return self

Expand Down
110 changes: 99 additions & 11 deletions bayesflow/adapters/transforms/nnpe.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
class NNPE(ElementwiseTransform):
"""Implements noisy neural posterior estimation (NNPE) as described in [1], which adds noise following a
spike-and-slab distribution to the training data as a mild form of data augmentation to robustify against noisy
real-world data (see [1, 2] for benchmarks).
real-world data (see [1, 2] for benchmarks). Adds the options of automatic noise scale determination and
dimensionwise noise application to the original implementation in [1] to provide more flexibility in dealing with
unstandardized and heterogeneous data.

[1] Ward, D., Cannon, P., Beaumont, M., Fasiolo, M., & Schmon, S. (2022). Robust neural posterior estimation and
statistical model criticism. Advances in Neural Information Processing Systems, 35, 33845-33859.
Expand All @@ -19,10 +21,18 @@

Parameters
----------
spike_scale : float or None
spike_scale : float or np.ndarray or None, default=None
The scale of the spike (Normal) distribution. Automatically determined if None (see “Notes” section).
slab_scale : float or None
Expects a float if `per_dimension=False` or a 1D array of length `data.shape[-1]` if `per_dimension=True`.
slab_scale : float or np.ndarray or None, default=None
The scale of the slab (Cauchy) distribution. Automatically determined if None (see “Notes” section).
Expects a float if `per_dimension=False` or a 1D array of length `data.shape[-1]` if `per_dimension=True`.
per_dimension : bool, default=True
If true, noise is applied per dimension of the last axis of the input data. If false, noise is applied globally.
Thus, if per_dimension=True, any provided scales must be arrays with shape (n_dimensions,) and automatic
scale determination occurs separately per dimension. If per_dimension=False, provided scales must be floats and
automatic scale determination occurs globally. The original implementation in [1] uses global application
(i.e., per_dimension=False), whereas dimensionwise is recommended if the data dimensions are heterogeneous.
seed : int or None
The seed for the random number generator. If None, a random seed is used. Used instead of np.random.Generator
here to enable easy serialization.
Expand All @@ -34,22 +44,96 @@

The scales of the spike and slab distributions can be set manually, or they are automatically determined by scaling
the default scales of [1] (which expect standardized data) by the standard deviation of the input data.
For automatic determination, the standard deviation is determined either globally (if `per_dimension=False`) or per
dimension of the last axis of the input data (if `per_dimension=True`). Note that automatic scale determination is
applied batch-wise in the forward method, which means that determined scales can vary between batches due to varying
standard deviations in the batch input data.

The original implementation in [1] can be recovered by applying the following settings on standardized data:
- `spike_scale=0.01`
- `slab_scale=0.25`
- `per_dimension=False`

Examples
--------
>>> adapter = bf.Adapter().nnpe(["x"])
"""

DEFAULT_SLAB = 0.25
DEFAULT_SPIKE = 0.01
DEFAULT_SLAB = 0.25

def __init__(self, *, spike_scale: float | None = None, slab_scale: float | None = None, seed: int | None = None):
def __init__(
self,
*,
spike_scale: float | np.ndarray | None = None,
slab_scale: float | np.ndarray | None = None,
per_dimension: bool = True,
seed: int | None = None,
):
super().__init__()
self.spike_scale = spike_scale
self.slab_scale = slab_scale
self.per_dimension = per_dimension
self.seed = seed
self.rng = np.random.default_rng(seed)

def _resolve_scale(
self,
name: str,
passed: float | np.ndarray | None,
default: float,
data: np.ndarray,
) -> np.ndarray | float:
"""
Determine spike/slab scale:
- If passed is None: Automatic determination via default * std(data) (per‐dimension or global).
- Else: validate & cast passed to the correct shape/type.

Parameters
----------
name : str
Identifier for error messages (e.g., 'spike_scale' or 'slab_scale').
passed : float or np.ndarray or None
User-specified scale. If None, compute as default * std(data).
If self.per_dimension is True, this may be a 1D array of length data.shape[-1].
default : float
Default multiplier from [1] to apply to the standard deviation of the data.
data : np.ndarray
Data array to compute standard deviation from.

Returns
-------
float or np.ndarray
The resolved scale, either as a scalar (if per_dimension=False) or an 1D array of length data.shape[-1]
(if per_dimension=True).
"""

# Get std and (expected shape) dimensionwise or globally
if self.per_dimension:
axes = tuple(range(data.ndim - 1))
std = np.std(data, axis=axes)
expected_shape = (data.shape[-1],)
else:
std = np.std(data)
expected_shape = None

Check warning on line 118 in bayesflow/adapters/transforms/nnpe.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/adapters/transforms/nnpe.py#L117-L118

Added lines #L117 - L118 were not covered by tests

# If no scale is passed, determine scale automatically given the dimensionwise or global std
if passed is None:
return default * std

Check warning on line 122 in bayesflow/adapters/transforms/nnpe.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/adapters/transforms/nnpe.py#L122

Added line #L122 was not covered by tests
# If a scale is passed, check if the passed shape matches the expected shape
else:
if self.per_dimension:
arr = np.asarray(passed, dtype=float)
if arr.shape != expected_shape or arr.ndim != 1:
raise ValueError(f"{name}: expected array of shape {expected_shape}, got {arr.shape}")
return arr

Check warning on line 129 in bayesflow/adapters/transforms/nnpe.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/adapters/transforms/nnpe.py#L129

Added line #L129 was not covered by tests
else:
try:
scalar = float(passed)
except Exception:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use more specific exceptions here, I think ValueError and TypeError are the relevant ones.

raise TypeError(f"{name}: expected scalar float, got {type(passed).__name__}")
return scalar

Check warning on line 135 in bayesflow/adapters/transforms/nnpe.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/adapters/transforms/nnpe.py#L131-L135

Added lines #L131 - L135 were not covered by tests

def forward(self, data: np.ndarray, stage: str = "inference", **kwargs) -> np.ndarray:
"""
Add spike‐and‐slab noise to `data` during training, using automatic scale determination if not provided (see
Expand All @@ -70,28 +154,32 @@
Noisy data when `stage` is 'training', otherwise the original input.
"""
if stage != "training":
return data

Check warning on line 157 in bayesflow/adapters/transforms/nnpe.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/adapters/transforms/nnpe.py#L157

Added line #L157 was not covered by tests

# Check data validity
if not np.all(np.isfinite(data)):
raise ValueError("NNPE.forward: `data` contains NaN or infinite values.")

Check warning on line 161 in bayesflow/adapters/transforms/nnpe.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/adapters/transforms/nnpe.py#L161

Added line #L161 was not covered by tests

# Automatically determine scales if not provided
if self.spike_scale is None or self.slab_scale is None:
data_std = np.std(data)
spike_scale = self.spike_scale if self.spike_scale is not None else self.DEFAULT_SPIKE * data_std
slab_scale = self.slab_scale if self.slab_scale is not None else self.DEFAULT_SLAB * data_std
spike_scale = self._resolve_scale("spike_scale", self.spike_scale, self.DEFAULT_SPIKE, data)
slab_scale = self._resolve_scale("slab_scale", self.slab_scale, self.DEFAULT_SLAB, data)

Check warning on line 164 in bayesflow/adapters/transforms/nnpe.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/adapters/transforms/nnpe.py#L164

Added line #L164 was not covered by tests

# Apply spike-and-slab noise
mixture_mask = self.rng.binomial(n=1, p=0.5, size=data.shape).astype(bool)
noise_spike = self.rng.standard_normal(size=data.shape) * spike_scale
noise_slab = self.rng.standard_cauchy(size=data.shape) * slab_scale
noise = np.where(mixture_mask, noise_slab, noise_spike)
return data + noise

Check warning on line 171 in bayesflow/adapters/transforms/nnpe.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/adapters/transforms/nnpe.py#L167-L171

Added lines #L167 - L171 were not covered by tests

def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
"""Non-invertible transform."""
return data

Check warning on line 175 in bayesflow/adapters/transforms/nnpe.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/adapters/transforms/nnpe.py#L175

Added line #L175 was not covered by tests

def get_config(self) -> dict:
return serialize({"spike_scale": self.spike_scale, "slab_scale": self.slab_scale, "seed": self.seed})
return serialize(

Check warning on line 178 in bayesflow/adapters/transforms/nnpe.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/adapters/transforms/nnpe.py#L178

Added line #L178 was not covered by tests
{
"spike_scale": self.spike_scale,
"slab_scale": self.slab_scale,
"per_dimension": self.per_dimension,
"seed": self.seed,
}
)
27 changes: 27 additions & 0 deletions tests/test_adapters/test_adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,3 +349,30 @@ def test_nnpe(random_data):
deserialized_auto = deserialize(serialized_auto)
reserialized_auto = serialize(deserialized_auto)
assert keras.tree.lists_to_tuples(serialized_auto) == keras.tree.lists_to_tuples(serialize(reserialized_auto))

# Test dimensionwise versus global noise application (per_dimension=True vs per_dimension=False)
# Create data with second dimension having higher variance
data_shape = (32, 16, 1)
rng = np.random.default_rng(42)
zero = np.ones(shape=data_shape)
high = rng.normal(0, 100.0, size=data_shape)
var_data = {"x": np.concatenate([zero, high], axis=-1)}

# Apply dimensionwise and global adapters with automatic slab_scale scale determination
ad_partial_global = Adapter().nnpe("x", spike_scale=0, slab_scale=None, per_dimension=False, seed=42)
ad_partial_dim = Adapter().nnpe("x", spike_scale=[0, 1], slab_scale=None, per_dimension=True, seed=42)
res_dim = ad_partial_dim(var_data, stage="training")
res_glob = ad_partial_global(var_data, stage="training")

# Compute standard deviations of noise per last axis dimension
noise_dim = res_dim["x"] - var_data["x"]
noise_glob = res_glob["x"] - var_data["x"]
std_dim = np.std(noise_dim, axis=(0, 1))
std_glob = np.std(noise_glob, axis=(0, 1))

# Dimensionwise should assign zero noise, global some noise to zero-variance dimension
assert std_dim[0] == 0
assert std_glob[0] > 0
# Both should assign noise to high-variance dimension
assert std_dim[1] > 0
assert std_glob[1] > 0
Loading