Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions bayesflow/adapters/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
Standardize,
ToArray,
Transform,
ReplaceNaN,
)
from .transforms.filter_transform import Predicate

Expand Down Expand Up @@ -791,3 +792,33 @@ def to_dict(self):
transform = ToDict()
self.transforms.append(transform)
return self

def replace_nan(
    self,
    keys: str | Sequence[str],
    default_value: float = 0.0,
    encode_mask: bool = False,
    axis: int | None = None,
):
    """
    Append a :py:class:`~bf.adapters.transforms.ReplaceNaN` transform to the adapter.

    One ``ReplaceNaN`` instance is created per key and the collection is
    wrapped in a single ``MapTransform``.

    Parameters
    ----------
    keys : str or sequence of str
        The names of the variables to clean / mask. A single string is
        treated as a one-element list.
    default_value : float, optional
        Value to substitute wherever data is NaN. Default is 0.0.
    encode_mask : bool, optional
        If True, encode a binary missingness mask alongside the data.
        Default is False.
    axis : int or None, optional
        Axis at which to expand for mask encoding (if enabled).
        Default is None.

    Returns
    -------
    The adapter itself, so that calls can be chained.
    """
    names = [keys] if isinstance(keys, str) else keys

    # Build an independent transform for every requested variable.
    per_key = {}
    for name in names:
        per_key[name] = ReplaceNaN(default_value=default_value, encode_mask=encode_mask, axis=axis)

    self.transforms.append(MapTransform(per_key))
    return self
1 change: 1 addition & 0 deletions bayesflow/adapters/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from .to_array import ToArray
from .to_dict import ToDict
from .transform import Transform
from .replace_nan import ReplaceNaN

from ...utils._docs import _add_imports_to_all

Expand Down
91 changes: 91 additions & 0 deletions bayesflow/adapters/transforms/replace_nan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import numpy as np

from bayesflow.utils.serialization import serializable, serialize
from .elementwise_transform import ElementwiseTransform


@serializable
class ReplaceNaN(ElementwiseTransform):
    """
    Replace NaNs with a default value, and optionally encode a missing-data mask.

    This is based on "Missing data in amortized simulation-based neural posterior estimation" by Wang et al. (2024).

    Parameters
    ----------
    default_value : float
        Value to substitute wherever data is NaN.
    encode_mask : bool, default=False
        If True, the forward pass will expand the array by one new axis and
        concatenate a binary mask (0 for originally-NaN entries, 1 otherwise).
    axis : int or None
        Axis along which to add the new dimension for mask encoding.
        If None, defaults to `data.ndim` (i.e., a new trailing axis).

    Examples
    --------
    >>> a = np.array([1.0, np.nan, 3.0])
    >>> r_nan = bf.adapters.transforms.ReplaceNaN(default_value=0.0)
    >>> r_nan.forward(a)
    array([1., 0., 3.])

    >>> # With mask encoding along a new last axis:
    >>> r_nan = bf.adapters.transforms.ReplaceNaN(default_value=-1.0, encode_mask=True, axis=-1)
    >>> enc = r_nan.forward(a)
    >>> enc.shape
    (3, 2)

    >>> # The inverse uses the encoded mask to put the NaNs back:
    >>> r_nan.inverse(enc)
    array([ 1., nan,  3.])

    It's recommended to precede this with a ToArray transform if your data
    might not already be a NumPy array.
    """

    def __init__(
        self,
        *,
        default_value: float = 0.0,
        encode_mask: bool = False,
        axis: int | None = None,
    ):
        super().__init__()
        self.default_value = default_value
        self.encode_mask = encode_mask
        self.axis = axis

    def get_config(self) -> dict:
        """Return the serialized constructor arguments of this transform."""
        return serialize(
            {
                "default_value": self.default_value,
                "encode_mask": self.encode_mask,
                "axis": self.axis,
            }
        )

    def forward(self, data: np.ndarray, **kwargs) -> np.ndarray:
        """
        Replace NaN entries of `data` with `default_value`.

        Parameters
        ----------
        data : np.ndarray
            Array that may contain NaN entries.
        **kwargs
            Accepted for interface compatibility; not used here.

        Returns
        -------
        np.ndarray
            The filled array. If `encode_mask` is True, the result has one
            additional axis of size 2 holding ``[filled value, mask]``, where
            the mask is 1 for observed entries and 0 for originally-NaN ones.
        """
        # Create mask of where data is NaN
        mask = np.isnan(data)
        # Fill NaNs with the default value
        filled = np.where(mask, self.default_value, data)

        if not self.encode_mask:
            return filled

        # Decide where to insert the new axis
        ax = self.axis if self.axis is not None else data.ndim
        # Expand dims for both filled data and mask
        filled_exp = np.expand_dims(filled, axis=ax)
        # Flip the NaN indicator so that 1 means "observed" and 0 means "missing"
        mask_exp = 1 - np.expand_dims(mask.astype(np.int8), axis=ax)
        # Concatenate along that axis: [..., value, mask]
        return np.concatenate([filled_exp, mask_exp], axis=ax)

    def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
        """
        Undo `forward` by restoring NaNs from the encoded mask.

        Parameters
        ----------
        data : np.ndarray
            Array produced by `forward`. When `encode_mask` is True it must
            carry the ``[value, mask]`` pair along the encoding axis.
        **kwargs
            Accepted for interface compatibility; not used here.

        Returns
        -------
        np.ndarray
            If `encode_mask` is False, `data` is returned unchanged because the
            positions of the replaced NaNs are not recorded. Otherwise the value
            channel is returned with NaN wherever the mask channel is 0.
        """
        if not self.encode_mask:
            # No mask was encoded, so nothing to undo
            return data

        # With axis=None, forward appended a trailing axis, which is the last
        # axis of the encoded array.
        ax = self.axis if self.axis is not None else data.ndim - 1
        # Extract the two "channels"
        values = np.take(data, indices=0, axis=ax)
        observed = np.take(data, indices=1, axis=ax).astype(bool)

        # forward() stores 1 for observed entries and 0 for missing ones, so
        # NaNs belong where the mask is 0. np.where builds a new array, so
        # `data` is never modified and 0-d results are handled as well.
        return np.where(observed, values, np.nan)

Check warning on line 91 in bayesflow/adapters/transforms/replace_nan.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/adapters/transforms/replace_nan.py#L90-L91

Added lines #L90 - L91 were not covered by tests
22 changes: 22 additions & 0 deletions tests/test_adapters/test_adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,3 +289,25 @@ def test_log_det_jac_exceptions(random_data):

# inverse works when concatenation is used after transforms
assert np.allclose(forward_log_det_jac["p"], -inverse_log_det_jac)


def test_replace_nan():
    """Check NaN replacement through the adapter, with and without mask encoding."""
    data = {
        "test": np.array([1.0, np.nan, 3.0]),
        "test-2d": np.array([[1.0, np.nan], [np.nan, 4.0]]),
    }

    # Plain replacement: the NaN entry becomes the default value.
    adapter = bf.Adapter().replace_nan(keys="test", default_value=-1.0, encode_mask=False)
    result = adapter.forward(data)["test"]
    expected = np.array([1.0, -1.0, 3.0])
    np.testing.assert_array_equal(result, expected)

    # Mask encoding with the default axis: each entry becomes a (value, mask) pair.
    adapter = bf.Adapter().replace_nan(keys="test", default_value=0.0, encode_mask=True)
    result = adapter.forward(data)["test"]
    expected = np.array([[1.0, 1.0], [0.0, 0.0], [3.0, 1.0]])
    np.testing.assert_array_equal(result, expected)

    # Mask encoding on a 2-D array with the new axis placed first:
    # shape (2, 2) -> (2, 2, 2), indexed as [channel, row, column].
    adapter = bf.Adapter().replace_nan(keys="test-2d", default_value=0.5, encode_mask=True, axis=0)
    result = adapter.forward(data)["test-2d"]
    filled_channel, mask_channel = result[0], result[1]
    # Channel 0 holds the filled values.
    np.testing.assert_array_equal(filled_channel, np.array([[1.0, 0.5], [0.5, 4.0]]))
    # Channel 1 holds the mask (1 = observed, 0 = originally NaN).
    np.testing.assert_array_equal(mask_channel, np.array([[1, 0], [0, 1]]))