diff --git a/bayesflow/adapters/adapter.py b/bayesflow/adapters/adapter.py index 900025583..fa84a9b4f 100644 --- a/bayesflow/adapters/adapter.py +++ b/bayesflow/adapters/adapter.py @@ -31,6 +31,7 @@ Ungroup, RandomSubsample, Take, + NanToNum, ) from .transforms.filter_transform import Predicate @@ -956,3 +957,34 @@ def to_dict(self): transform = ToDict() self.transforms.append(transform) return self + + def nan_to_num( + self, + keys: str | Sequence[str], + default_value: float = 0.0, + return_mask: bool = False, + mask_prefix: str = "mask", + ): + """ + Append :py:class:`~bf.adapters.transforms.NanToNum` transform to the adapter. + + Parameters + ---------- + keys : str or sequence of str + The names of the variables to clean / mask. + default_value : float + Value to substitute wherever data is NaN. Defaults to 0.0. + return_mask : bool + If True, encode a binary missingness mask alongside the data. Defaults to False. + mask_prefix : str + Prefix for the mask key in the output dictionary. Defaults to 'mask_'. If the mask key already exists, + a ValueError is raised to avoid overwriting existing masks. + """ + if isinstance(keys, str): + keys = [keys] + + for key in keys: + self.transforms.append( + NanToNum(key=key, default_value=default_value, return_mask=return_mask, mask_prefix=mask_prefix) + ) + return self diff --git a/bayesflow/adapters/transforms/__init__.py b/bayesflow/adapters/transforms/__init__.py index df8918091..bf449ab33 100644 --- a/bayesflow/adapters/transforms/__init__.py +++ b/bayesflow/adapters/transforms/__init__.py @@ -29,6 +29,7 @@ from .random_subsample import RandomSubsample from .take import Take from .ungroup import Ungroup +from .nan_to_num import NanToNum from ...utils._docs import _add_imports_to_all diff --git a/bayesflow/adapters/transforms/nan_to_num.py b/bayesflow/adapters/transforms/nan_to_num.py new file mode 100644 index 000000000..fe715a174 --- /dev/null +++ b/bayesflow/adapters/transforms/nan_to_num.py @@ -0,0 +1,91 @@ +import numpy as np + +from bayesflow.utils.serialization import serializable, serialize +from .transform import Transform + + +@serializable("bayesflow.adapters") +class NanToNum(Transform): + """ + Replace NaNs with a default value, and optionally encode a missing-data mask as a separate output key. + + This is based on "Missing data in amortized simulation-based neural posterior estimation" by Wang et al. (2024). + + Parameters + ---------- + default_value : float + Value to substitute wherever data is NaN. + return_mask : bool, default=False + If True, a mask array will be returned under a new key. + mask_prefix : str, default='mask_' + Prefix for the mask key in the output dictionary. + """ + + def __init__(self, key: str, default_value: float = 0.0, return_mask: bool = False, mask_prefix: str = "mask"): + super().__init__() + self.key = key + self.default_value = default_value + self.return_mask = return_mask + self.mask_prefix = mask_prefix + + def get_config(self) -> dict: + return serialize( + { + "key": self.key, + "default_value": self.default_value, + "return_mask": self.return_mask, + "mask_prefix": self.mask_prefix, + } + ) + + @property + def mask_key(self) -> str: + """ + Key under which the mask will be stored in the output dictionary. + """ + return f"{self.mask_prefix}_{self.key}" + + def forward(self, data: dict[str, any], **kwargs) -> dict[str, any]: + """ + Forward transform: fill NaNs and optionally output mask under 'mask_'. + """ + data = data.copy() + + # Check if the mask key already exists in the data + if self.mask_key in data.keys(): + raise ValueError( + f"Mask key '{self.mask_key}' already exists in the data. Please choose a different mask_prefix." + ) + + # Identify NaNs and fill with default value + mask = np.isnan(data[self.key]) + data[self.key] = np.nan_to_num(data[self.key], copy=False, nan=self.default_value) + + if not self.return_mask: + return data + + # Prepare mask array (1 for valid, 0 for NaN) + mask_array = (~mask).astype(np.int8) + + # Return both the filled data and the mask under separate keys + data[self.mask_key] = mask_array + return data + + def inverse(self, data: dict[str, any], **kwargs) -> dict[str, any]: + """ + Inverse transform: restore NaNs using the mask under 'mask_'. + """ + data = data.copy() + + # Retrieve mask and values to reconstruct NaNs + values = data[self.key] + + if not self.return_mask: + values[values == self.default_value] = np.nan # we assume default_value is not in data + else: + mask_array = data[self.mask_key].astype(bool) + # Put NaNs where mask is 0 + values[~mask_array] = np.nan + + data[self.key] = values + return data diff --git a/tests/test_adapters/test_adapters.py b/tests/test_adapters/test_adapters.py index 7e4314f83..23721a938 100644 --- a/tests/test_adapters/test_adapters.py +++ b/tests/test_adapters/test_adapters.py @@ -298,6 +298,27 @@ def test_log_det_jac_exceptions(random_data): assert np.allclose(forward_log_det_jac["p"], -inverse_log_det_jac) +def test_nan_to_num(): + arr = {"test": np.array([1.0, np.nan, 3.0])} + # test without mask + transform = bf.Adapter().nan_to_num(keys="test", default_value=-1.0, return_mask=False) + out = transform.forward(arr)["test"] + np.testing.assert_array_equal(out, np.array([1.0, -1.0, 3.0])) + + # test with mask + arr = {"test": np.array([1.0, np.nan, 3.0]), "test-2d": np.array([[1.0, np.nan], [np.nan, 4.0]])} + transform = bf.Adapter().nan_to_num(keys="test", default_value=0.0, return_mask=True) + out = transform.forward(arr) + np.testing.assert_array_equal(out["test"], np.array([1.0, 0.0, 3.0])) + np.testing.assert_array_equal(out["mask_test"], np.array([1.0, 0.0, 1.0])) + + # test two-d array + transform = bf.Adapter().nan_to_num(keys="test-2d", default_value=0.5, return_mask=True, mask_prefix="new_mask") + out = transform.forward(arr) + np.testing.assert_array_equal(out["test-2d"], np.array([[1.0, 0.5], [0.5, 4.0]])) + np.testing.assert_array_equal(out["new_mask_test-2d"], np.array([[1, 0], [0, 1]])) + + def test_nnpe(random_data): # NNPE cannot be integrated into the adapter fixture and its tests since it modifies the input data # and therefore breaks existing allclose checks