Skip to content

Commit b5c946b

Browse files
committed
changed name to return_mask
1 parent a2eadd9 commit b5c946b

File tree

3 files changed

+12
-12
lines changed

3 files changed

+12
-12
lines changed

bayesflow/adapters/adapter.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -924,7 +924,7 @@ def nan_to_num(
924924
self,
925925
keys: str | Sequence[str],
926926
default_value: float = 0.0,
927-
encode_mask: bool = False,
927+
return_mask: bool = False,
928928
):
929929
"""
930930
Append :py:class:`~bf.adapters.transforms.NanToNum` transform to the adapter.
@@ -935,12 +935,12 @@ def nan_to_num(
935935
The names of the variables to clean / mask.
936936
default_value : float
937937
Value to substitute wherever data is NaN.
938-
encode_mask : bool
938+
return_mask : bool
939939
If True, encode a binary missingness mask alongside the data.
940940
"""
941941
if isinstance(keys, str):
942942
keys = [keys]
943943

944944
for key in keys:
945-
self.transforms.append(NanToNum(key=key, default_value=default_value, encode_mask=encode_mask))
945+
self.transforms.append(NanToNum(key=key, default_value=default_value, return_mask=return_mask))
946946
return self

bayesflow/adapters/transforms/nan_to_num.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,22 +15,22 @@ class NanToNum(Transform):
1515
----------
1616
default_value : float
1717
Value to substitute wherever data is NaN.
18-
encode_mask : bool, default=False
18+
return_mask : bool, default=False
1919
If True, a mask array will be returned under a new key.
2020
"""
2121

22-
def __init__(self, key: str, default_value: float = 0.0, encode_mask: bool = False):
22+
def __init__(self, key: str, default_value: float = 0.0, return_mask: bool = False):
2323
super().__init__()
2424
self.key = key
2525
self.default_value = default_value
26-
self.encode_mask = encode_mask
26+
self.return_mask = return_mask
2727

2828
def get_config(self) -> dict:
2929
return serialize(
3030
{
3131
"key": self.key,
3232
"default_value": self.default_value,
33-
"encode_mask": self.encode_mask,
33+
"return_mask": self.return_mask,
3434
}
3535
)
3636

@@ -51,7 +51,7 @@ def forward(self, data: dict[str, any], **kwargs) -> dict[str, any]:
5151
mask = np.isnan(data[self.key])
5252
data[self.key] = np.nan_to_num(data[self.key], copy=False, nan=self.default_value)
5353

54-
if not self.encode_mask:
54+
if not self.return_mask:
5555
return data
5656

5757
# Prepare mask array (1 for valid, 0 for NaN)
@@ -70,7 +70,7 @@ def inverse(self, data: dict[str, any], **kwargs) -> dict[str, any]:
7070
# Retrieve mask and values to reconstruct NaNs
7171
values = data[self.key]
7272

73-
if not self.encode_mask:
73+
if not self.return_mask:
7474
values[values == self.default_value] = np.nan # we assume default_value is not in data
7575
else:
7676
mask_array = data[self.mask_key].astype(bool)

tests/test_adapters/test_adapters.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -301,19 +301,19 @@ def test_log_det_jac_exceptions(random_data):
301301
def test_nan_to_num():
302302
arr = {"test": np.array([1.0, np.nan, 3.0])}
303303
# test without mask
304-
transform = bf.Adapter().nan_to_num(keys="test", default_value=-1.0, encode_mask=False)
304+
transform = bf.Adapter().nan_to_num(keys="test", default_value=-1.0, return_mask=False)
305305
out = transform.forward(arr)["test"]
306306
np.testing.assert_array_equal(out, np.array([1.0, -1.0, 3.0]))
307307

308308
# test with mask
309309
arr = {"test": np.array([1.0, np.nan, 3.0]), "test-2d": np.array([[1.0, np.nan], [np.nan, 4.0]])}
310-
transform = bf.Adapter().nan_to_num(keys="test", default_value=0.0, encode_mask=True)
310+
transform = bf.Adapter().nan_to_num(keys="test", default_value=0.0, return_mask=True)
311311
out = transform.forward(arr)
312312
np.testing.assert_array_equal(out["test"], np.array([1.0, 0.0, 3.0]))
313313
np.testing.assert_array_equal(out["mask_test"], np.array([1.0, 0.0, 1.0]))
314314

315315
# test two-d array
316-
transform = bf.Adapter().nan_to_num(keys="test-2d", default_value=0.5, encode_mask=True)
316+
transform = bf.Adapter().nan_to_num(keys="test-2d", default_value=0.5, return_mask=True)
317317
out = transform.forward(arr)
318318
np.testing.assert_array_equal(out["test-2d"], np.array([[1.0, 0.5], [0.5, 4.0]]))
319319
np.testing.assert_array_equal(out["mask_test-2d"], np.array([[1, 0], [0, 1]]))

0 commit comments

Comments
 (0)