Skip to content

Commit 3a8e313

Browse files
committed
improve
1 parent 4b776f0 commit 3a8e313

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

bayesflow/adapters/transforms/nan_to_num.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,11 @@ def mask_key(self) -> str:
3939
"""
4040
Key under which the mask will be stored in the output dictionary.
4141
"""
42-
return f"_mask_{self.key}" if self.key else "_mask"
42+
return f"mask_{self.key}" if self.key else "mask"
4343

4444
def forward(self, data: dict[str, any], **kwargs) -> dict[str, any]:
4545
"""
46-
Forward transform: fill NaNs and optionally output mask under '_mask_<key>'.
46+
Forward transform: fill NaNs and optionally output mask under 'mask_<key>'.
4747
"""
4848
data = data.copy()
4949

@@ -63,7 +63,7 @@ def forward(self, data: dict[str, any], **kwargs) -> dict[str, any]:
6363

6464
def inverse(self, data: dict[str, any], **kwargs) -> dict[str, any]:
6565
"""
66-
Inverse transform: restore NaNs using the mask under '_mask_<key>'.
66+
Inverse transform: restore NaNs using the mask under 'mask_<key>'.
6767
"""
6868
data = data.copy()
6969

tests/test_adapters/test_adapters.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -303,10 +303,10 @@ def test_nan_to_num():
303303
transform = bf.Adapter().nan_to_num(keys="test", default_value=0.0, encode_mask=True)
304304
out = transform.forward(arr)
305305
np.testing.assert_array_equal(out["test"], np.array([1.0, 0.0, 3.0]))
306-
np.testing.assert_array_equal(out["_mask_test"], np.array([1.0, 0.0, 1.0]))
306+
np.testing.assert_array_equal(out["mask_test"], np.array([1.0, 0.0, 1.0]))
307307

308308
# test two-d array
309309
transform = bf.Adapter().nan_to_num(keys="test-2d", default_value=0.5, encode_mask=True)
310310
out = transform.forward(arr)
311311
np.testing.assert_array_equal(out["test-2d"], np.array([[1.0, 0.5], [0.5, 4.0]]))
312-
np.testing.assert_array_equal(out["_mask_test-2d"], np.array([[1, 0], [0, 1]]))
312+
np.testing.assert_array_equal(out["mask_test-2d"], np.array([[1, 0], [0, 1]]))

0 commit comments

Comments
 (0)