|
21 | 21 | expand_dims, |
22 | 22 | isclose, |
23 | 23 | kron, |
| 24 | + nan_to_num, |
24 | 25 | nunique, |
25 | 26 | one_hot, |
26 | 27 | pad, |
|
40 | 41 | lazy_xp_function(create_diagonal) |
41 | 42 | lazy_xp_function(expand_dims) |
42 | 43 | lazy_xp_function(kron) |
| 44 | +lazy_xp_function(nan_to_num) |
43 | 45 | lazy_xp_function(nunique) |
44 | 46 | lazy_xp_function(one_hot) |
45 | 47 | lazy_xp_function(pad) |
@@ -941,6 +943,140 @@ def test_xp(self, xp: ModuleType): |
941 | 943 | xp_assert_equal(kron(a, b, xp=xp), k) |
942 | 944 |
|
943 | 945 |
|
| 946 | +class TestNanToNum: |
| 947 | + def test_bool(self, xp: ModuleType) -> None: |
| 948 | + a = xp.asarray([True]) |
| 949 | + xp_assert_equal(nan_to_num(a, xp=xp), a) |
| 950 | + |
| 951 | + def test_scalar_pos_inf(self, xp: ModuleType, infinity: float) -> None: |
| 952 | + a = xp.inf |
| 953 | + xp_assert_equal(nan_to_num(a, xp=xp), xp.asarray(infinity)) |
| 954 | + |
| 955 | + def test_scalar_neg_inf(self, xp: ModuleType, infinity: float) -> None: |
| 956 | + a = -xp.inf |
| 957 | + xp_assert_equal(nan_to_num(a, xp=xp), -xp.asarray(infinity)) |
| 958 | + |
| 959 | + def test_scalar_nan(self, xp: ModuleType) -> None: |
| 960 | + a = xp.nan |
| 961 | + xp_assert_equal(nan_to_num(a, xp=xp), xp.asarray(0.0)) |
| 962 | + |
| 963 | + def test_real(self, xp: ModuleType, infinity: float) -> None: |
| 964 | + a = xp.asarray([xp.inf, -xp.inf, xp.nan, -128, 128]) |
| 965 | + xp_assert_equal( |
| 966 | + nan_to_num(a, xp=xp), |
| 967 | + xp.asarray( |
| 968 | + [ |
| 969 | + infinity, |
| 970 | + -infinity, |
| 971 | + 0.0, |
| 972 | + -128, |
| 973 | + 128, |
| 974 | + ] |
| 975 | + ), |
| 976 | + ) |
| 977 | + |
| 978 | + def test_complex(self, xp: ModuleType, infinity: float) -> None: |
| 979 | + a = xp.asarray( |
| 980 | + [ |
| 981 | + complex(xp.inf, xp.nan), |
| 982 | + xp.nan, |
| 983 | + complex(xp.nan, xp.inf), |
| 984 | + ] |
| 985 | + ) |
| 986 | + xp_assert_equal( |
| 987 | + nan_to_num(a), |
| 988 | + xp.asarray([complex(infinity, 0), complex(0, 0), complex(0, infinity)]), |
| 989 | + ) |
| 990 | + |
| 991 | + def test_empty_array(self, xp: ModuleType) -> None: |
| 992 | + a = xp.asarray([], dtype=xp.float32) # forced dtype due to torch |
| 993 | + xp_assert_equal(nan_to_num(a, xp=xp), a) |
| 994 | + assert xp.isdtype(nan_to_num(a, xp=xp).dtype, xp.float32) |
| 995 | + |
| 996 | + @pytest.mark.parametrize( |
| 997 | + ("in_vals", "fill_value", "out_vals"), |
| 998 | + [ |
| 999 | + ([1, 2, np.nan, 4], 3, [1.0, 2.0, 3.0, 4.0]), |
| 1000 | + ([1, 2, np.nan, 4], 3.0, [1.0, 2.0, 3.0, 4.0]), |
| 1001 | + ( |
| 1002 | + [ |
| 1003 | + complex(1, 1), |
| 1004 | + complex(2, 2), |
| 1005 | + complex(np.nan, 0), |
| 1006 | + complex(4, 4), |
| 1007 | + ], |
| 1008 | + 3, |
| 1009 | + [ |
| 1010 | + complex(1.0, 1.0), |
| 1011 | + complex(2.0, 2.0), |
| 1012 | + complex(3.0, 0.0), |
| 1013 | + complex(4.0, 4.0), |
| 1014 | + ], |
| 1015 | + ), |
| 1016 | + ( |
| 1017 | + [ |
| 1018 | + complex(1, 1), |
| 1019 | + complex(2, 2), |
| 1020 | + complex(0, np.nan), |
| 1021 | + complex(4, 4), |
| 1022 | + ], |
| 1023 | + 3.0, |
| 1024 | + [ |
| 1025 | + complex(1.0, 1.0), |
| 1026 | + complex(2.0, 2.0), |
| 1027 | + complex(0.0, 3.0), |
| 1028 | + complex(4.0, 4.0), |
| 1029 | + ], |
| 1030 | + ), |
| 1031 | + ( |
| 1032 | + [ |
| 1033 | + complex(1, 1), |
| 1034 | + complex(2, 2), |
| 1035 | + complex(np.nan, np.nan), |
| 1036 | + complex(4, 4), |
| 1037 | + ], |
| 1038 | + 3.0, |
| 1039 | + [ |
| 1040 | + complex(1.0, 1.0), |
| 1041 | + complex(2.0, 2.0), |
| 1042 | + complex(3.0, 3.0), |
| 1043 | + complex(4.0, 4.0), |
| 1044 | + ], |
| 1045 | + ), |
| 1046 | + ], |
| 1047 | + ) |
| 1048 | + def test_fill_value_success( |
| 1049 | + self, |
| 1050 | + xp: ModuleType, |
| 1051 | + in_vals: Array, |
| 1052 | + fill_value: int | float, |
| 1053 | + out_vals: Array, |
| 1054 | + ) -> None: |
| 1055 | + a = xp.asarray(in_vals) |
| 1056 | + xp_assert_equal( |
| 1057 | + nan_to_num(a, fill_value=fill_value, xp=xp), |
| 1058 | + xp.asarray(out_vals), |
| 1059 | + ) |
| 1060 | + |
| 1061 | + def test_fill_value_failure(self, xp: ModuleType) -> None: |
| 1062 | + a = xp.asarray( |
| 1063 | + [ |
| 1064 | + complex(1, 1), |
| 1065 | + complex(xp.nan, xp.nan), |
| 1066 | + complex(3, 3), |
| 1067 | + ] |
| 1068 | + ) |
| 1069 | + with pytest.raises( |
| 1070 | + TypeError, |
| 1071 | + match="Complex fill values are not supported", |
| 1072 | + ): |
| 1073 | + _ = nan_to_num( |
| 1074 | + a, |
| 1075 | + fill_value=complex(2, 2), # type: ignore[arg-type] # pyright: ignore[reportArgumentType] |
| 1076 | + xp=xp, |
| 1077 | + ) |
| 1078 | + |
| 1079 | + |
944 | 1080 | class TestNUnique: |
945 | 1081 | def test_simple(self, xp: ModuleType): |
946 | 1082 | a = xp.asarray([[1, 1], [0, 2], [2, 2]]) |
|
0 commit comments