Skip to content

Commit 4931fac

Browse files
committed
Adjust class name and add docstring to forward method
1 parent 6c8744b commit 4931fac

File tree

3 files changed

+23
-5
lines changed

3 files changed

+23
-5
lines changed

bayesflow/adapters/adapter.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
Keep,
1919
Log,
2020
MapTransform,
21-
Nnpe,
21+
NNPE,
2222
NumpyTransform,
2323
OneHot,
2424
Rename,
@@ -708,7 +708,7 @@ def nnpe(
708708
spike_scale: float = 0.01,
709709
seed: int | None = None,
710710
):
711-
"""Append an :py:class:`~transforms.Nnpe` transform to the adapter.
711+
"""Append an :py:class:`~transforms.NNPE` transform to the adapter.
712712
713713
Parameters
714714
----------
@@ -724,7 +724,7 @@ def nnpe(
724724
if isinstance(keys, str):
725725
keys = [keys]
726726

727-
transform = MapTransform({key: Nnpe(slab_scale=slab_scale, spike_scale=spike_scale, seed=seed) for key in keys})
727+
transform = MapTransform({key: NNPE(slab_scale=slab_scale, spike_scale=spike_scale, seed=seed) for key in keys})
728728
self.transforms.append(transform)
729729
return self
730730

bayesflow/adapters/transforms/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from .keep import Keep
1313
from .log import Log
1414
from .map_transform import MapTransform
15-
from .nnpe import Nnpe
15+
from .nnpe import NNPE
1616
from .numpy_transform import NumpyTransform
1717
from .one_hot import OneHot
1818
from .rename import Rename

bayesflow/adapters/transforms/nnpe.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77

88
@serializable("bayesflow.adapters")
9-
class Nnpe(ElementwiseTransform):
9+
class NNPE(ElementwiseTransform):
1010
"""Implements noisy neural posterior estimation (NNPE) as described in [1], which adds noise following a
1111
spike-and-slab distribution to the training data as a mild form of data augmentation to robustify against noisy
1212
real-world data (see [1, 2] for benchmarks).
@@ -48,6 +48,23 @@ def __init__(self, *, slab_scale: float = 0.25, spike_scale: float = 0.01, seed:
4848
self.rng = np.random.default_rng(seed)
4949

5050
def forward(self, data: np.ndarray, stage: str = "inference", **kwargs) -> np.ndarray:
51+
"""
52+
Add spike‐and‐slab noise (see “Notes” section of the class docstring for details) to `data` during training.
53+
54+
Parameters
55+
----------
56+
data : np.ndarray
57+
Input array to be perturbed.
58+
stage : str, default='inference'
59+
If 'training', noise is added; else data is returned unchanged.
60+
**kwargs
61+
Unused keyword arguments.
62+
63+
Returns
64+
-------
65+
np.ndarray
66+
Noisy data when `stage` is 'training', otherwise the original input.
67+
"""
5168
if stage != "training":
5269
return data
5370
mixture_mask = self.rng.binomial(n=1, p=0.5, size=data.shape).astype(bool)
@@ -57,6 +74,7 @@ def forward(self, data: np.ndarray, stage: str = "inference", **kwargs) -> np.nd
5774
return data + noise
5875

5976
def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
77+
"""Non-invertible transform."""
6078
return data
6179

6280
def get_config(self) -> dict:

0 commit comments

Comments
 (0)