Skip to content

Commit 4b776f0

Browse files
committed
update test
1 parent 93cf09f commit 4b776f0

File tree

2 files changed

+16
-13
lines changed

2 files changed

+16
-13
lines changed

bayesflow/adapters/transforms/nan_to_num.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ def forward(self, data: dict[str, any], **kwargs) -> dict[str, any]:
4545
"""
4646
Forward transform: fill NaNs and optionally output mask under '_mask_<key>'.
4747
"""
48+
data = data.copy()
49+
4850
# Identify NaNs and fill with default value
4951
mask = np.isnan(data[self.key])
5052
data[self.key] = np.nan_to_num(data[self.key], copy=False, nan=self.default_value)
@@ -63,6 +65,8 @@ def inverse(self, data: dict[str, any], **kwargs) -> dict[str, any]:
6365
"""
6466
Inverse transform: restore NaNs using the mask under '_mask_<key>'.
6567
"""
68+
data = data.copy()
69+
6670
# Retrieve mask and values to reconstruct NaNs
6771
values = data[self.key]
6872

tests/test_adapters/test_adapters.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -291,23 +291,22 @@ def test_log_det_jac_exceptions(random_data):
291291
assert np.allclose(forward_log_det_jac["p"], -inverse_log_det_jac)
292292

293293

294-
def test_replace_nan():
295-
arr = {"test": np.array([1.0, np.nan, 3.0]), "test-2d": np.array([[1.0, np.nan], [np.nan, 4.0]])}
294+
def test_nan_to_num():
295+
arr = {"test": np.array([1.0, np.nan, 3.0])}
296296
# test without mask
297-
transform = bf.Adapter().replace_nan(keys="test", default_value=-1.0, encode_mask=False)
297+
transform = bf.Adapter().nan_to_num(keys="test", default_value=-1.0, encode_mask=False)
298298
out = transform.forward(arr)["test"]
299299
np.testing.assert_array_equal(out, np.array([1.0, -1.0, 3.0]))
300300

301301
# test with mask
302-
transform = bf.Adapter().replace_nan(keys="test", default_value=0.0, encode_mask=True)
303-
out = transform.forward(arr)["test"]
304-
np.testing.assert_array_equal(out, np.array([[1.0, 1.0], [0.0, 0.0], [3.0, 1.0]]))
302+
arr = {"test": np.array([1.0, np.nan, 3.0]), "test-2d": np.array([[1.0, np.nan], [np.nan, 4.0]])}
303+
transform = bf.Adapter().nan_to_num(keys="test", default_value=0.0, encode_mask=True)
304+
out = transform.forward(arr)
305+
np.testing.assert_array_equal(out["test"], np.array([1.0, 0.0, 3.0]))
306+
np.testing.assert_array_equal(out["_mask_test"], np.array([1.0, 0.0, 1.0]))
305307

306308
# test two-d array
307-
transform = bf.Adapter().replace_nan(keys="test-2d", default_value=0.5, encode_mask=True, axis=0)
308-
out = transform.forward(arr)["test-2d"]
309-
# Original shape (2,2) -> new shape (2,2,2) when expanding at axis=0
310-
# Channel 0 along axis 0 should be the filled values
311-
np.testing.assert_array_equal(out[0], np.array([[1.0, 0.5], [0.5, 4.0]]))
312-
# Channel 1 along axis 0 should be the mask
313-
np.testing.assert_array_equal(out[1], np.array([[1, 0], [0, 1]]))
309+
transform = bf.Adapter().nan_to_num(keys="test-2d", default_value=0.5, encode_mask=True)
310+
out = transform.forward(arr)
311+
np.testing.assert_array_equal(out["test-2d"], np.array([[1.0, 0.5], [0.5, 4.0]]))
312+
np.testing.assert_array_equal(out["_mask_test-2d"], np.array([[1, 0], [0, 1]]))

0 commit comments

Comments
 (0)