@@ -291,23 +291,22 @@ def test_log_det_jac_exceptions(random_data):
291291 assert np .allclose (forward_log_det_jac ["p" ], - inverse_log_det_jac )
292292
293293
294- def test_replace_nan ():
295- arr = {"test" : np .array ([1.0 , np .nan , 3.0 ]), "test-2d" : np . array ([[ 1.0 , np . nan ], [ np . nan , 4.0 ]]) }
294+ def test_nan_to_num ():
295+ arr = {"test" : np .array ([1.0 , np .nan , 3.0 ])}
296296 # test without mask
297- transform = bf .Adapter ().replace_nan (keys = "test" , default_value = - 1.0 , encode_mask = False )
297+ transform = bf .Adapter ().nan_to_num (keys = "test" , default_value = - 1.0 , encode_mask = False )
298298 out = transform .forward (arr )["test" ]
299299 np .testing .assert_array_equal (out , np .array ([1.0 , - 1.0 , 3.0 ]))
300300
301301 # test with mask
302- transform = bf .Adapter ().replace_nan (keys = "test" , default_value = 0.0 , encode_mask = True )
303- out = transform .forward (arr )["test" ]
304- np .testing .assert_array_equal (out , np .array ([[1.0 , 1.0 ], [0.0 , 0.0 ], [3.0 , 1.0 ]]))
302+ arr = {"test" : np .array ([1.0 , np .nan , 3.0 ]), "test-2d" : np .array ([[1.0 , np .nan ], [np .nan , 4.0 ]])}
303+ transform = bf .Adapter ().nan_to_num (keys = "test" , default_value = 0.0 , encode_mask = True )
304+ out = transform .forward (arr )
305+ np .testing .assert_array_equal (out ["test" ], np .array ([1.0 , 0.0 , 3.0 ]))
306+ np .testing .assert_array_equal (out ["_mask_test" ], np .array ([1.0 , 0.0 , 1.0 ]))
305307
306308 # test two-d array
307- transform = bf .Adapter ().replace_nan (keys = "test-2d" , default_value = 0.5 , encode_mask = True , axis = 0 )
308- out = transform .forward (arr )["test-2d" ]
309- # Original shape (2,2) -> new shape (2,2,2) when expanding at axis=0
310- # Channel 0 along axis 0 should be the filled values
311- np .testing .assert_array_equal (out [0 ], np .array ([[1.0 , 0.5 ], [0.5 , 4.0 ]]))
312- # Channel 1 along axis 0 should be the mask
313- np .testing .assert_array_equal (out [1 ], np .array ([[1 , 0 ], [0 , 1 ]]))
309+ transform = bf .Adapter ().nan_to_num (keys = "test-2d" , default_value = 0.5 , encode_mask = True )
310+ out = transform .forward (arr )
311+ np .testing .assert_array_equal (out ["test-2d" ], np .array ([[1.0 , 0.5 ], [0.5 , 4.0 ]]))
312+ np .testing .assert_array_equal (out ["_mask_test-2d" ], np .array ([[1 , 0 ], [0 , 1 ]]))
0 commit comments