Skip to content

Commit 1ae1abf

Browse files
committed
Add tests specified in docstring
1 parent 3c1cba6 commit 1ae1abf

File tree

4 files changed

+70
-12
lines changed

4 files changed

+70
-12
lines changed

src/array_api_extra/_delegation.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def isclose(
114114

115115

116116
def nan_to_num(
117-
x: Array,
117+
x: Array | float | complex,
118118
/,
119119
*,
120120
fill_value: int | float | complex = 0.0,
@@ -136,7 +136,7 @@ def nan_to_num(
136136
137137
Parameters
138138
----------
139-
x : array
139+
x : array, float, complex
140140
Input data.
141141
fill_value : int, float, complex, optional
142142
Value to be used to fill NaN values. If no value is passed
@@ -173,21 +173,20 @@ def nan_to_num(
173173
0.00000000e+000 +0.00000000e+000j,
174174
0.00000000e+000 +1.79769313e+308j])
175175
"""
176-
if x.ndim == 0:
177-
msg = "x must be an array."
178-
raise TypeError(msg)
179-
180176
xp = array_namespace(x) if xp is None else xp
181177

178+
# for scalars we want to output an array
179+
y = xp.asarray(x)
180+
182181
if (
183182
is_cupy_namespace(xp)
184183
or is_jax_namespace(xp)
185184
or is_numpy_namespace(xp)
186185
or is_torch_namespace(xp)
187186
):
188-
return xp.nan_to_num(x, nan=fill_value)
187+
return xp.nan_to_num(y, nan=fill_value)
189188

190-
return _funcs.nan_to_num(x, fill_value=fill_value, xp=xp)
189+
return _funcs.nan_to_num(y, fill_value=fill_value, xp=xp)
191190

192191

193192
def one_hot(

src/array_api_extra/_lib/_funcs.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -760,16 +760,16 @@ def perform_replacements(
760760
finfo = xp.finfo(x.dtype)
761761
idx_posinf = xp.isinf(x) & ~xp.signbit(x)
762762
idx_neginf = xp.isinf(x) & xp.signbit(x)
763-
x = xp.where(idx_posinf, x, finfo.max)
764-
return xp.where(idx_neginf, x, finfo.min)
763+
x = xp.where(idx_posinf, finfo.max, x)
764+
return xp.where(idx_neginf, finfo.min, x)
765765

766766
if xp.isdtype(x.dtype, "complex floating"):
767767
return perform_replacements(
768-
x,
768+
xp.real(x),
769769
fill_value,
770770
xp,
771771
) + 1j * perform_replacements(
772-
x,
772+
xp.imag(x),
773773
fill_value,
774774
xp,
775775
)

tests/conftest.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,3 +232,10 @@ def device(
232232
if library == Backend.TORCH_GPU:
233233
return xp.device("cpu")
234234
return get_device(xp.empty(0))
235+
236+
@pytest.fixture
237+
def infinity(library: Backend) -> float:
238+
"""Retrieve the positive infinity value for the given backend."""
239+
if library in (Backend.TORCH, Backend.TORCH_GPU):
240+
return 3.4028235e+38
241+
return 1.7976931348623157e+308

tests/test_funcs.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -943,6 +943,58 @@ def test_xp(self, xp: ModuleType):
943943
xp_assert_equal(kron(a, b, xp=xp), k)
944944

945945

946+
class TestNumToNan:
947+
def test_bool(self, xp: ModuleType) -> None:
948+
a = xp.asarray([True])
949+
xp_assert_equal(nan_to_num(a), a)
950+
951+
def test_scalar_pos_inf(self, xp: ModuleType, infinity: float) -> None:
952+
a = xp.inf
953+
xp_assert_equal(nan_to_num(a, xp=xp), xp.asarray(infinity))
954+
955+
def test_scalar_neg_inf(self, xp: ModuleType, infinity: float) -> None:
956+
a = -xp.inf
957+
xp_assert_equal(nan_to_num(a, xp=xp), -xp.asarray(infinity))
958+
959+
def test_scalar_nan(self, xp: ModuleType) -> None:
960+
a = xp.nan
961+
xp_assert_equal(nan_to_num(a, xp=xp), xp.asarray(0.0))
962+
963+
def test_real(self, xp: ModuleType, infinity: float) -> None:
964+
a = xp.asarray([xp.inf, -xp.inf, xp.nan, -128, 128])
965+
xp_assert_equal(
966+
nan_to_num(a),
967+
xp.asarray(
968+
[
969+
infinity,
970+
-infinity,
971+
0.0,
972+
-128,
973+
128,
974+
]
975+
),
976+
)
977+
978+
def test_complex(self, xp: ModuleType, infinity: float) -> None:
979+
a = xp.asarray(
980+
[
981+
complex(xp.inf, xp.nan),
982+
xp.nan,
983+
complex(xp.nan, xp.inf),
984+
]
985+
)
986+
xp_assert_equal(
987+
nan_to_num(a),
988+
xp.asarray(
989+
[
990+
infinity + 0j,
991+
0 + 0j,
992+
0 + 1j * infinity
993+
]
994+
),
995+
)
996+
997+
946998
class TestNUnique:
947999
def test_simple(self, xp: ModuleType):
9481000
a = xp.asarray([[1, 1], [0, 2], [2, 2]])

0 commit comments

Comments
 (0)