Skip to content

Commit 1b2aeae

Browse files
committed
add replace nan adapter
1 parent a7f9162 commit 1b2aeae

File tree

4 files changed

+145
-0
lines changed

4 files changed

+145
-0
lines changed

bayesflow/adapters/adapter.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
Standardize,
2626
ToArray,
2727
Transform,
28+
ReplaceNaN,
2829
)
2930
from .transforms.filter_transform import Predicate
3031

@@ -770,3 +771,33 @@ def to_dict(self):
770771
transform = ToDict()
771772
self.transforms.append(transform)
772773
return self
774+
775+
def replace_nan(
    self,
    keys: str | Sequence[str],
    default_value: float = 0.0,
    encode_mask: bool = False,
    axis: int | None = None,
):
    """
    Append a :py:class:`~bf.adapters.transforms.ReplaceNaN` transform to the adapter.

    One independent ``ReplaceNaN`` instance is created per key and all of them
    are bundled into a single ``MapTransform`` that is appended to the
    adapter's list of transforms.

    Parameters
    ----------
    keys : str or sequence of str
        The names of the variables to clean / mask. A single string is
        treated as a one-element list.
    default_value : float
        Value to substitute wherever data is NaN.
    encode_mask : bool
        If True, encode a binary missingness mask alongside the data.
    axis : int or None
        Axis at which to expand for mask encoding (if enabled).

    Returns
    -------
    The adapter itself, so that calls can be chained.
    """
    names = [keys] if isinstance(keys, str) else keys

    # Build one transform per variable; each key gets its own instance.
    per_key = {}
    for name in names:
        per_key[name] = ReplaceNaN(
            default_value=default_value,
            encode_mask=encode_mask,
            axis=axis,
        )

    self.transforms.append(MapTransform(per_key))
    return self

bayesflow/adapters/transforms/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from .to_array import ToArray
2424
from .to_dict import ToDict
2525
from .transform import Transform
26+
from .replace_nan import ReplaceNaN
2627

2728
from ...utils._docs import _add_imports_to_all
2829

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
import numpy as np
2+
3+
from bayesflow.utils.serialization import serializable, serialize
4+
from .elementwise_transform import ElementwiseTransform
5+
6+
7+
@serializable
class ReplaceNaN(ElementwiseTransform):
    """
    Replace NaNs with a default value, and optionally encode a missing-data mask.

    This is based on "Missing data in amortized simulation-based neural posterior estimation" by Wang et al. (2024).

    Parameters
    ----------
    default_value : float
        Value to substitute wherever data is NaN.
    encode_mask : bool, default=False
        If True, the forward pass will expand the array by one new axis and
        concatenate a binary mask (0 for originally-NaN entries, 1 otherwise).
    axis : int or None
        Axis along which to add the new dimension for mask encoding.
        If None, defaults to `data.ndim` (i.e., a new trailing axis).

    Notes
    -----
    With ``encode_mask=False`` the positions of the replaced entries are not
    stored anywhere, so ``inverse`` returns its input unchanged.

    Examples
    --------
    >>> a = np.array([1.0, np.nan, 3.0])
    >>> r_nan = bf.adapters.transforms.ReplaceNaN(default_value=0.0)
    >>> r_nan.forward(a)
    array([1., 0., 3.])

    >>> # With mask encoding along a new last axis:
    >>> r_nan = bf.adapters.transforms.ReplaceNaN(default_value=-1.0, encode_mask=True, axis=-1)
    >>> enc = r_nan.forward(a)
    >>> enc.shape
    (3, 2)

    It is recommended to precede this with a ToArray transform if your data
    might not already be a NumPy array.
    """

    def __init__(
        self,
        *,
        default_value: float = 0.0,
        encode_mask: bool = False,
        axis: int | None = None,
    ):
        super().__init__()
        self.default_value = default_value
        self.encode_mask = encode_mask
        self.axis = axis

    def get_config(self) -> dict:
        """Return the serialized constructor arguments of this transform."""
        return serialize(
            {
                "default_value": self.default_value,
                "encode_mask": self.encode_mask,
                "axis": self.axis,
            }
        )

    def forward(self, data: np.ndarray, **kwargs) -> np.ndarray:
        """
        Replace NaN entries by ``default_value`` and optionally append a mask.

        Parameters
        ----------
        data : np.ndarray
            Input array, possibly containing NaNs.

        Returns
        -------
        np.ndarray
            The filled array if ``encode_mask`` is False. Otherwise an array
            with one additional axis of length 2 at position ``axis``:
            index 0 holds the filled values, index 1 holds the mask
            (1 = observed, 0 = originally NaN).
        """
        # Create mask of where data is NaN
        missing = np.isnan(data)
        # Fill NaNs with the default value
        filled = np.where(missing, self.default_value, data)

        if not self.encode_mask:
            return filled

        # Decide where to insert the new axis. `filled` is always an ndarray,
        # so its ndim is available even when `data` was only array-like.
        ax = self.axis if self.axis is not None else filled.ndim
        # Expand dims for both filled data and mask
        filled_exp = np.expand_dims(filled, axis=ax)
        # Invert the NaN indicator so that 1 marks observed entries.
        mask_exp = 1 - np.expand_dims(missing.astype(np.int8), axis=ax)
        # Concatenate along that axis: [..., value, mask]
        return np.concatenate([filled_exp, mask_exp], axis=ax)

    def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
        """
        Undo :py:meth:`forward`.

        Parameters
        ----------
        data : np.ndarray
            Array produced by ``forward``.

        Returns
        -------
        np.ndarray
            ``data`` unchanged if ``encode_mask`` is False. Otherwise the value
            channel with NaN restored wherever the mask channel is 0; the
            encoding axis is removed.
        """
        if not self.encode_mask:
            # No mask was encoded, so nothing to undo
            return data

        ax = self.axis if self.axis is not None else np.ndim(data) - 1
        # Extract the two "channels"
        values = np.take(data, indices=0, axis=ax)
        observed = np.take(data, indices=1, axis=ax).astype(bool)
        # forward() writes 1 for observed entries and 0 for missing ones, so
        # NaNs belong where the mask channel is 0. np.where also avoids
        # assigning NaN into an integer-typed array in place.
        return np.where(observed, values, np.nan)

tests/test_adapters/test_adapters.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,3 +230,25 @@ def test_to_dict_transform():
230230

231231
# category should have 5 one-hot categories, even though it was only passed 4
232232
assert processed["category"].shape[-1] == 5
233+
234+
235+
def test_replace_nan():
    """Check the forward pass of ``Adapter.replace_nan`` with and without mask encoding."""
    data = {
        "test": np.array([1.0, np.nan, 3.0]),
        "test-2d": np.array([[1.0, np.nan], [np.nan, 4.0]]),
    }

    # Plain replacement: NaNs become the default value, shape is unchanged.
    adapter = bf.Adapter().replace_nan(keys="test", default_value=-1.0, encode_mask=False)
    result = adapter.forward(data)["test"]
    expected_plain = np.array([1.0, -1.0, 3.0])
    np.testing.assert_array_equal(result, expected_plain)

    # Mask encoding with the default axis: a trailing (value, mask) pair per entry.
    adapter = bf.Adapter().replace_nan(keys="test", default_value=0.0, encode_mask=True)
    result = adapter.forward(data)["test"]
    expected_pairs = np.array([[1.0, 1.0], [0.0, 0.0], [3.0, 1.0]])
    np.testing.assert_array_equal(result, expected_pairs)

    # Two-dimensional input with the encoding axis in front:
    # original shape (2, 2) -> new shape (2, 2, 2) when expanding at axis=0.
    adapter = bf.Adapter().replace_nan(keys="test-2d", default_value=0.5, encode_mask=True, axis=0)
    result = adapter.forward(data)["test-2d"]
    filled_channel, mask_channel = result[0], result[1]
    # Channel 0 along axis 0 should be the filled values
    np.testing.assert_array_equal(filled_channel, np.array([[1.0, 0.5], [0.5, 4.0]]))
    # Channel 1 along axis 0 should be the mask
    np.testing.assert_array_equal(mask_channel, np.array([[1, 0], [0, 1]]))

0 commit comments

Comments
 (0)