Commit 5bbeefe

Enable compatibility with bayesflow-org#486 by adjusting scales automatically
1 parent 4931fac commit 5bbeefe

3 files changed: +57 -22 lines changed


bayesflow/adapters/adapter.py

Lines changed: 7 additions & 7 deletions
@@ -704,8 +704,8 @@ def nnpe(
         self,
         keys: str | Sequence[str],
         *,
-        slab_scale: float = 0.25,
-        spike_scale: float = 0.01,
+        spike_scale: float | None = None,
+        slab_scale: float | None = None,
         seed: int | None = None,
     ):
         """Append an :py:class:`~transforms.NNPE` transform to the adapter.
@@ -714,17 +714,17 @@ def nnpe(
         ----------
         keys : str or Sequence of str
             The names of the variables to transform.
-        slab_scale : float
-            The scale of the slab (Cauchy) distribution.
-        spike_scale : float
-            The scale of the spike spike (Normal) distribution.
+        spike_scale : float or None
+            The scale of the spike (Normal) distribution. Automatically determined if None.
+        slab_scale : float or None
+            The scale of the slab (Cauchy) distribution. Automatically determined if None.
         seed : int or None
             The seed for the random number generator. If None, a random seed is used.
         """
         if isinstance(keys, str):
             keys = [keys]
 
-        transform = MapTransform({key: NNPE(slab_scale=slab_scale, spike_scale=spike_scale, seed=seed) for key in keys})
+        transform = MapTransform({key: NNPE(spike_scale=spike_scale, slab_scale=slab_scale, seed=seed) for key in keys})
         self.transforms.append(transform)
         return self

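On the adapter side, both scale arguments are now optional keyword arguments. A minimal usage sketch of the updated call (key name, data shape, and variable names are illustrative, not taken from the commit):

import numpy as np
import bayesflow as bf

# Leaving both scales as None lets NNPE derive them from the data's
# standard deviation; noise is only added during the training stage.
adapter = bf.Adapter().nnpe("x", seed=42)

data = {"x": np.random.default_rng(0).normal(size=(32, 10))}
noisy = adapter(data, stage="training")    # spike-and-slab noise applied
clean = adapter(data, stage="inference")   # returned unchanged
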
bayesflow/adapters/transforms/nnpe.py

Lines changed: 30 additions & 14 deletions
@@ -19,37 +19,41 @@ class NNPE(ElementwiseTransform):
 
     Parameters
     ----------
-    slab_scale : float
-        The scale of the slab (Cauchy) distribution.
-    spike_scale : float
-        The scale of the spike spike (Normal) distribution.
+    spike_scale : float or None
+        The scale of the spike (Normal) distribution. Automatically determined if None (see “Notes” section).
+    slab_scale : float or None
+        The scale of the slab (Cauchy) distribution. Automatically determined if None (see “Notes” section).
     seed : int or None
         The seed for the random number generator. If None, a random seed is used. Used instead of np.random.Generator
         here to enable easy serialization.
 
     Notes
     -----
-    The spike-and-slab distribution consists of a mixture of a Cauchy (slab) and a Normal distribution (spike), which
-    are applied based on a Bernoulli random variable with p=0.5.
+    The spike-and-slab distribution consists of a mixture of a Normal distribution (spike) and Cauchy distribution
+    (slab), which are applied based on a Bernoulli random variable with p=0.5.
 
-    The default scales follow [1] and expect standardized data (e.g., via the `Standardize` adapter). It is therefore
-    recommended to adapt the scales when using unstandardized training data.
+    The scales of the spike and slab distributions can be set manually, or they are automatically determined by scaling
+    the default scales of [1] (which expect standardized data) by the standard deviation of the input data.
 
     Examples
     --------
     >>> adapter = bf.Adapter().nnpe(["x"])
     """
 
-    def __init__(self, *, slab_scale: float = 0.25, spike_scale: float = 0.01, seed: int = None):
+    DEFAULT_SLAB = 0.25
+    DEFAULT_SPIKE = 0.01
+
+    def __init__(self, *, spike_scale: float | None = None, slab_scale: float | None = None, seed: int | None = None):
         super().__init__()
-        self.slab_scale = slab_scale
         self.spike_scale = spike_scale
+        self.slab_scale = slab_scale
         self.seed = seed
         self.rng = np.random.default_rng(seed)
 
     def forward(self, data: np.ndarray, stage: str = "inference", **kwargs) -> np.ndarray:
         """
-        Add spike-and-slab noise (see “Notes” section of the class docstring for details) to `data` during training.
+        Add spike-and-slab noise to `data` during training, using automatic scale determination if not provided (see
+        “Notes” section of the class docstring for details).
 
         Parameters
         ----------
@@ -67,9 +71,21 @@ def forward(self, data: np.ndarray, stage: str = "inference", **kwargs) -> np.ndarray:
         """
         if stage != "training":
             return data
+
+        # Check data validity
+        if not np.all(np.isfinite(data)):
+            raise ValueError("NNPE.forward: `data` contains NaN or infinite values.")
+
+        # Automatically determine scales if not provided
+        if self.spike_scale is None or self.slab_scale is None:
+            data_std = np.std(data)
+        spike_scale = self.spike_scale if self.spike_scale is not None else self.DEFAULT_SPIKE * data_std
+        slab_scale = self.slab_scale if self.slab_scale is not None else self.DEFAULT_SLAB * data_std
+
+        # Apply spike-and-slab noise
         mixture_mask = self.rng.binomial(n=1, p=0.5, size=data.shape).astype(bool)
-        noise_slab = self.rng.standard_cauchy(size=data.shape) * self.slab_scale
-        noise_spike = self.rng.standard_normal(size=data.shape) * self.spike_scale
+        noise_spike = self.rng.standard_normal(size=data.shape) * spike_scale
+        noise_slab = self.rng.standard_cauchy(size=data.shape) * slab_scale
         noise = np.where(mixture_mask, noise_slab, noise_spike)
         return data + noise
 
@@ -78,4 +94,4 @@ def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
         return data
 
     def get_config(self) -> dict:
-        return serialize({"slab_scale": self.slab_scale, "spike_scale": self.spike_scale, "seed": self.seed})
+        return serialize({"spike_scale": self.spike_scale, "slab_scale": self.slab_scale, "seed": self.seed})

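The functional core of the change is the data-driven fallback for the scales: when a scale is None, the standardized default (0.01 for the spike, 0.25 for the slab) is multiplied by the standard deviation of the incoming data, so data with a standard deviation of 2.0 would receive automatic scales of 0.02 and 0.5. A standalone NumPy sketch of that logic (the function name is an illustrative assumption, not part of the commit):

import numpy as np

DEFAULT_SPIKE = 0.01  # spike (Normal) scale for standardized data
DEFAULT_SLAB = 0.25   # slab (Cauchy) scale for standardized data

def spike_and_slab_noise(data, spike_scale=None, slab_scale=None, seed=None):
    # Fall back to scaling the standardized defaults by the data's std
    # whenever a scale is not given explicitly.
    data_std = np.std(data)
    spike = spike_scale if spike_scale is not None else DEFAULT_SPIKE * data_std
    slab = slab_scale if slab_scale is not None else DEFAULT_SLAB * data_std

    rng = np.random.default_rng(seed)
    # A Bernoulli(p=0.5) mask picks slab (Cauchy) or spike (Normal) noise per element.
    mask = rng.binomial(n=1, p=0.5, size=data.shape).astype(bool)
    noise_spike = rng.standard_normal(size=data.shape) * spike
    noise_slab = rng.standard_cauchy(size=data.shape) * slab
    return data + np.where(mask, noise_slab, noise_spike)
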
tests/test_adapters/test_adapters.py

Lines changed: 20 additions & 1 deletion
@@ -304,7 +304,7 @@ def test_nnpe(random_data):
     import numpy as np
     from bayesflow.adapters import Adapter
 
-    ad = Adapter().nnpe("x1", slab_scale=1.0, spike_scale=1.0, seed=42)
+    ad = Adapter().nnpe("x1", spike_scale=1.0, slab_scale=1.0, seed=42)
     result_training = ad(random_data, stage="training")
     result_validation = ad(random_data, stage="validation")
     result_inference = ad(random_data, stage="inference")
@@ -330,3 +330,22 @@ def test_nnpe(random_data):
         assert np.allclose(result_validation[k], v)
         assert np.allclose(result_inference[k], v)
         assert np.allclose(result_inversed[k], v)
+
+    # Test at least one scale is None case (automatic scale determination)
+    ad_partial = Adapter().nnpe("x2", slab_scale=None, spike_scale=1.0, seed=42)
+    result_training_partial = ad_partial(random_data, stage="training")
+    assert not np.allclose(result_training_partial["x2"], random_data["x2"])
+
+    # Test both scales and seed are None case (automatic scale determination)
+    ad_auto = Adapter().nnpe("y1", slab_scale=None, spike_scale=None, seed=None)
+    result_training_auto = ad_auto(random_data, stage="training")
+    assert not np.allclose(result_training_auto["y1"], random_data["y1"])
+    for k, v in random_data.items():
+        if k == "y1":
+            continue
+        assert np.allclose(result_training_auto[k], v)
+
+    serialized_auto = serialize(ad_auto)
+    deserialized_auto = deserialize(serialized_auto)
+    reserialized_auto = serialize(deserialized_auto)
+    assert keras.tree.lists_to_tuples(serialized_auto) == keras.tree.lists_to_tuples(serialize(reserialized_auto))
