Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions bayesflow/adapters/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
Ungroup,
RandomSubsample,
Take,
NanToNum,
)
from .transforms.filter_transform import Predicate

Expand Down Expand Up @@ -956,3 +957,34 @@ def to_dict(self):
transform = ToDict()
self.transforms.append(transform)
return self

def nan_to_num(
    self,
    keys: str | Sequence[str],
    default_value: float = 0.0,
    return_mask: bool = False,
    mask_prefix: str = "mask",
):
    """
    Append one :py:class:`~bf.adapters.transforms.NanToNum` transform per key to the adapter.

    Parameters
    ----------
    keys : str or sequence of str
        The names of the variables to clean / mask. A single string is treated as one key.
    default_value : float
        Value to substitute wherever data is NaN. Defaults to 0.0.
    return_mask : bool
        If True, encode a binary missingness mask alongside the data. Defaults to False.
    mask_prefix : str
        Prefix for the mask key in the output dictionary. Defaults to 'mask'. The mask of a
        variable ``key`` is stored under ``f"{mask_prefix}_{key}"``. If a mask is requested and
        that key already exists in the data, a ValueError is raised to avoid overwriting
        existing masks.

    Returns
    -------
    The adapter itself, so that calls can be chained.
    """
    # A bare string is a single variable name, not a sequence of characters.
    names = [keys] if isinstance(keys, str) else keys

    # Each variable gets its own transform, all sharing the same settings.
    for name in names:
        cleaner = NanToNum(
            key=name,
            default_value=default_value,
            return_mask=return_mask,
            mask_prefix=mask_prefix,
        )
        self.transforms.append(cleaner)

    return self
1 change: 1 addition & 0 deletions bayesflow/adapters/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from .random_subsample import RandomSubsample
from .take import Take
from .ungroup import Ungroup
from .nan_to_num import NanToNum

from ...utils._docs import _add_imports_to_all

Expand Down
91 changes: 91 additions & 0 deletions bayesflow/adapters/transforms/nan_to_num.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import numpy as np

from bayesflow.utils.serialization import serializable, serialize
from .transform import Transform


@serializable("bayesflow.adapters")
class NanToNum(Transform):
    """
    Replace NaNs with a default value, and optionally encode a missing-data mask as a separate output key.

    This is based on "Missing data in amortized simulation-based neural posterior estimation" by Wang et al. (2024).

    Parameters
    ----------
    key : str
        Name of the variable whose NaNs are replaced.
    default_value : float, default=0.0
        Value to substitute wherever data is NaN.
    return_mask : bool, default=False
        If True, a mask array will be returned under a new key.
    mask_prefix : str, default='mask'
        Prefix for the mask key in the output dictionary. The mask is stored under
        ``f"{mask_prefix}_{key}"``.
    """

    def __init__(self, key: str, default_value: float = 0.0, return_mask: bool = False, mask_prefix: str = "mask"):
        super().__init__()
        self.key = key
        self.default_value = default_value
        self.return_mask = return_mask
        self.mask_prefix = mask_prefix

    def get_config(self) -> dict:
        """
        Return the serialized constructor arguments of this transform.
        """
        return serialize(
            {
                "key": self.key,
                "default_value": self.default_value,
                "return_mask": self.return_mask,
                "mask_prefix": self.mask_prefix,
            }
        )

    @property
    def mask_key(self) -> str:
        """
        Key under which the mask will be stored in the output dictionary.
        """
        return f"{self.mask_prefix}_{self.key}"

    def forward(self, data: dict[str, any], **kwargs) -> dict[str, any]:
        """
        Forward transform: fill NaNs and optionally output mask under '<mask_prefix>_<key>'.

        The input dictionary and the arrays it holds are left unmodified.

        Raises
        ------
        ValueError
            If a mask is requested and the mask key already exists in ``data``.
        """
        # Shallow copy: the dictionary is new, but the arrays are still shared with the caller.
        data = data.copy()

        # A pre-existing key can only be overwritten when a mask is actually written.
        if self.return_mask and self.mask_key in data:
            raise ValueError(
                f"Mask key '{self.mask_key}' already exists in the data. Please choose a different mask_prefix."
            )

        # Identify NaNs and fill with default value
        mask = np.isnan(data[self.key])
        # copy=True so the caller's array is not overwritten in place.
        # NOTE(review): np.nan_to_num also replaces +/-inf with very large finite numbers,
        # and those entries are NOT flagged in the mask -- confirm this is intended.
        data[self.key] = np.nan_to_num(data[self.key], copy=True, nan=self.default_value)

        if not self.return_mask:
            return data

        # Prepare mask array (1 for valid, 0 for NaN)
        mask_array = (~mask).astype(np.int8)

        # Return both the filled data and the mask under separate keys
        data[self.mask_key] = mask_array
        return data

    def inverse(self, data: dict[str, any], **kwargs) -> dict[str, any]:
        """
        Inverse transform: restore NaNs, using the mask under '<mask_prefix>_<key>' if one was returned.

        Without a mask, every entry equal to ``default_value`` is turned back into NaN, so genuine
        observations that happen to equal ``default_value`` are lost.

        The input dictionary and the arrays it holds are left unmodified.
        """
        data = data.copy()

        # Work on a private copy so that writing NaNs does not alter the caller's array.
        values = np.array(data[self.key], copy=True)

        if not self.return_mask:
            values[values == self.default_value] = np.nan  # we assume default_value is not in data
        else:
            mask_array = np.asarray(data[self.mask_key]).astype(bool)
            # Put NaNs where mask is 0
            values[~mask_array] = np.nan

        data[self.key] = values
        return data

Check warning on line 91 in bayesflow/adapters/transforms/nan_to_num.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/adapters/transforms/nan_to_num.py#L90-L91

Added lines #L90 - L91 were not covered by tests
21 changes: 21 additions & 0 deletions tests/test_adapters/test_adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,27 @@ def test_log_det_jac_exceptions(random_data):
assert np.allclose(forward_log_det_jac["p"], -inverse_log_det_jac)


def test_nan_to_num():
    """Check NaN filling and mask creation of ``Adapter.nan_to_num`` on 1-D and 2-D inputs."""
    # Case 1: plain replacement, no mask requested
    data = {"test": np.array([1.0, np.nan, 3.0])}
    adapter = bf.Adapter().nan_to_num(keys="test", default_value=-1.0, return_mask=False)
    filled = adapter.forward(data)["test"]
    np.testing.assert_array_equal(filled, np.array([1.0, -1.0, 3.0]))

    # Case 2: replacement plus mask under the default prefix
    data = {
        "test": np.array([1.0, np.nan, 3.0]),
        "test-2d": np.array([[1.0, np.nan], [np.nan, 4.0]]),
    }
    adapter = bf.Adapter().nan_to_num(keys="test", default_value=0.0, return_mask=True)
    result = adapter.forward(data)
    np.testing.assert_array_equal(result["test"], np.array([1.0, 0.0, 3.0]))
    np.testing.assert_array_equal(result["mask_test"], np.array([1.0, 0.0, 1.0]))

    # Case 3: two-dimensional input with a custom mask prefix
    adapter = bf.Adapter().nan_to_num(
        keys="test-2d",
        default_value=0.5,
        return_mask=True,
        mask_prefix="new_mask",
    )
    result = adapter.forward(data)
    np.testing.assert_array_equal(result["test-2d"], np.array([[1.0, 0.5], [0.5, 4.0]]))
    np.testing.assert_array_equal(result["new_mask_test-2d"], np.array([[1, 0], [0, 1]]))


def test_nnpe(random_data):
# NNPE cannot be integrated into the adapter fixture and its tests since it modifies the input data
# and therefore breaks existing allclose checks
Expand Down