Skip to content

Commit e076518

Browse files
committed
add mask naming
1 parent b5c946b commit e076518

File tree

3 files changed

+23
-7
lines changed

3 files changed

+23
-7
lines changed

bayesflow/adapters/adapter.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -925,6 +925,7 @@ def nan_to_num(
925925
keys: str | Sequence[str],
926926
default_value: float = 0.0,
927927
return_mask: bool = False,
928+
mask_prefix: str = "mask",
928929
):
929930
"""
930931
Append :py:class:`~bf.adapters.transforms.NanToNum` transform to the adapter.
@@ -934,13 +935,18 @@ def nan_to_num(
934935
keys : str or sequence of str
935936
The names of the variables to clean / mask.
936937
default_value : float
937-
Value to substitute wherever data is NaN.
938+
Value to substitute wherever data is NaN. Defaults to 0.0.
938939
return_mask : bool
939-
If True, encode a binary missingness mask alongside the data.
940+
If True, encode a binary missingness mask alongside the data. Defaults to False.
941+
mask_prefix : str
942+
Prefix for the mask key in the output dictionary; the mask is stored under '<mask_prefix>_<key>'. Defaults to 'mask'. If the mask key already exists,
943+
a ValueError is raised to avoid overwriting existing masks.
940944
"""
941945
if isinstance(keys, str):
942946
keys = [keys]
943947

944948
for key in keys:
945-
self.transforms.append(NanToNum(key=key, default_value=default_value, return_mask=return_mask))
949+
self.transforms.append(
950+
NanToNum(key=key, default_value=default_value, return_mask=return_mask, mask_prefix=mask_prefix)
951+
)
946952
return self

bayesflow/adapters/transforms/nan_to_num.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,20 +17,24 @@ class NanToNum(Transform):
1717
Value to substitute wherever data is NaN.
1818
return_mask : bool, default=False
1919
If True, a mask array will be returned under a new key.
20+
mask_prefix : str, default='mask'
21+
Prefix for the mask key in the output dictionary.
2022
"""
2123

22-
def __init__(self, key: str, default_value: float = 0.0, return_mask: bool = False):
24+
def __init__(self, key: str, default_value: float = 0.0, return_mask: bool = False, mask_prefix: str = "mask"):
2325
super().__init__()
2426
self.key = key
2527
self.default_value = default_value
2628
self.return_mask = return_mask
29+
self.mask_prefix = mask_prefix
2730

2831
def get_config(self) -> dict:
2932
return serialize(
3033
{
3134
"key": self.key,
3235
"default_value": self.default_value,
3336
"return_mask": self.return_mask,
37+
"mask_prefix": self.mask_prefix,
3438
}
3539
)
3640

@@ -39,14 +43,20 @@ def mask_key(self) -> str:
3943
"""
4044
Key under which the mask will be stored in the output dictionary.
4145
"""
42-
return f"mask_{self.key}" if self.key else "mask"
46+
return f"{self.mask_prefix}_{self.key}"
4347

4448
def forward(self, data: dict[str, any], **kwargs) -> dict[str, any]:
4549
"""
4650
Forward transform: fill NaNs and optionally output mask under '<mask_prefix>_<key>'.
4751
"""
4852
data = data.copy()
4953

54+
# Check if the mask key already exists in the data
55+
if self.mask_key in data.keys():
56+
raise ValueError(
57+
f"Mask key '{self.mask_key}' already exists in the data. Please choose a different mask_prefix."
58+
)
59+
5060
# Identify NaNs and fill with default value
5161
mask = np.isnan(data[self.key])
5262
data[self.key] = np.nan_to_num(data[self.key], copy=False, nan=self.default_value)

tests/test_adapters/test_adapters.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,7 @@ def test_nan_to_num():
313313
np.testing.assert_array_equal(out["mask_test"], np.array([1.0, 0.0, 1.0]))
314314

315315
# test two-d array
316-
transform = bf.Adapter().nan_to_num(keys="test-2d", default_value=0.5, return_mask=True)
316+
transform = bf.Adapter().nan_to_num(keys="test-2d", default_value=0.5, return_mask=True, mask_prefix="new_mask")
317317
out = transform.forward(arr)
318318
np.testing.assert_array_equal(out["test-2d"], np.array([[1.0, 0.5], [0.5, 4.0]]))
319-
np.testing.assert_array_equal(out["mask_test-2d"], np.array([[1, 0], [0, 1]]))
319+
np.testing.assert_array_equal(out["new_mask_test-2d"], np.array([[1, 0], [0, 1]]))

0 commit comments

Comments
 (0)