Skip to content

Commit a2891f7

Browse files
ENH: Implement nan_to_num function (#398)
Co-authored-by: Lucas Colley <[email protected]>
1 parent b8426ef commit a2891f7

File tree

6 files changed

+268
-2
lines changed

6 files changed

+268
-2
lines changed

docs/api-reference.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
expand_dims
1717
isclose
1818
kron
19+
nan_to_num
1920
nunique
2021
one_hot
2122
pad

src/array_api_extra/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Extra array functions built on top of the array API standard."""
22

3-
from ._delegation import isclose, one_hot, pad
3+
from ._delegation import isclose, nan_to_num, one_hot, pad
44
from ._lib._at import at
55
from ._lib._funcs import (
66
apply_where,
@@ -33,6 +33,7 @@
3333
"isclose",
3434
"kron",
3535
"lazy_apply",
36+
"nan_to_num",
3637
"nunique",
3738
"one_hot",
3839
"pad",

src/array_api_extra/_delegation.py

Lines changed: 80 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from ._lib._utils._helpers import asarrays
1919
from ._lib._utils._typing import Array, DType
2020

21-
__all__ = ["isclose", "one_hot", "pad"]
21+
__all__ = ["isclose", "nan_to_num", "one_hot", "pad"]
2222

2323

2424
def isclose(
@@ -113,6 +113,85 @@ def isclose(
113113
return _funcs.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan, xp=xp)
114114

115115

116+
def nan_to_num(
    x: Array | float | complex,
    /,
    *,
    fill_value: int | float = 0.0,
    xp: ModuleType | None = None,
) -> Array:
    """
    Replace NaN with zero and infinity with large finite numbers (default behaviour).

    If `x` is inexact, NaN is replaced by zero or by the user defined value in the
    `fill_value` keyword, infinity is replaced by the largest finite floating
    point value representable by ``x.dtype``, and -infinity is replaced by the
    most negative finite floating point value representable by ``x.dtype``.

    For complex dtypes, the above is applied to each of the real and
    imaginary components of `x` separately.

    Parameters
    ----------
    x : array | float | complex
        Input data.
    fill_value : int | float, optional
        Value to be used to fill NaN values. If no value is passed
        then NaN values will be replaced with 0.0.
    xp : array_namespace, optional
        The standard-compatible namespace for `x`. Default: infer.

    Returns
    -------
    array
        `x`, with the non-finite values replaced.

    Raises
    ------
    TypeError
        If `fill_value` is a complex number.

    See Also
    --------
    array_api.isnan : Shows which elements are Not a Number (NaN).

    Examples
    --------
    >>> import array_api_extra as xpx
    >>> import array_api_strict as xp
    >>> xpx.nan_to_num(xp.inf, xp=xp)
    1.7976931348623157e+308
    >>> xpx.nan_to_num(-xp.inf, xp=xp)
    -1.7976931348623157e+308
    >>> xpx.nan_to_num(xp.nan, xp=xp)
    0.0
    >>> x = xp.asarray([xp.inf, -xp.inf, xp.nan, -128, 128])
    >>> xpx.nan_to_num(x)
    array([ 1.79769313e+308, -1.79769313e+308,  0.00000000e+000, # may vary
           -1.28000000e+002,  1.28000000e+002])
    >>> y = xp.asarray([complex(xp.inf, xp.nan), xp.nan, complex(xp.nan, xp.inf)])
    >>> xpx.nan_to_num(y)
    array([  1.79769313e+308 +0.00000000e+000j, # may vary
             0.00000000e+000 +0.00000000e+000j,
             0.00000000e+000 +1.79769313e+308j])
    """
    # Reject complex fill values before doing any work.
    if isinstance(fill_value, complex):
        msg = "Complex fill values are not supported."
        raise TypeError(msg)

    if xp is None:
        xp = array_namespace(x)

    # Python scalars are wrapped so that we operate on (and output) an array.
    arr = xp.asarray(x)

    # These namespaces ship their own `nan_to_num`; defer to it.
    native_checks = (
        is_cupy_namespace,
        is_jax_namespace,
        is_numpy_namespace,
        is_torch_namespace,
    )
    if any(has_native(xp) for has_native in native_checks):
        return xp.nan_to_num(arr, nan=fill_value)

    # Everything else goes through the generic array-API implementation.
    return _funcs.nan_to_num(arr, fill_value=fill_value, xp=xp)
193+
194+
116195
def one_hot(
117196
x: Array,
118197
/,

src/array_api_extra/_lib/_funcs.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -738,6 +738,47 @@ def kron(
738738
return xp.reshape(result, res_shape)
739739

740740

741+
def nan_to_num( # numpydoc ignore=PR01,RT01
742+
x: Array,
743+
/,
744+
fill_value: int | float = 0.0,
745+
*,
746+
xp: ModuleType,
747+
) -> Array:
748+
"""See docstring in `array_api_extra._delegation.py`."""
749+
750+
def perform_replacements( # numpydoc ignore=PR01,RT01
751+
x: Array,
752+
fill_value: int | float,
753+
xp: ModuleType,
754+
) -> Array:
755+
"""Internal function to perform the replacements."""
756+
x = xp.where(xp.isnan(x), fill_value, x)
757+
758+
# convert infinities to finite values
759+
finfo = xp.finfo(x.dtype)
760+
idx_posinf = xp.isinf(x) & ~xp.signbit(x)
761+
idx_neginf = xp.isinf(x) & xp.signbit(x)
762+
x = xp.where(idx_posinf, finfo.max, x)
763+
return xp.where(idx_neginf, finfo.min, x)
764+
765+
if xp.isdtype(x.dtype, "complex floating"):
766+
return perform_replacements(
767+
xp.real(x),
768+
fill_value,
769+
xp,
770+
) + 1j * perform_replacements(
771+
xp.imag(x),
772+
fill_value,
773+
xp,
774+
)
775+
776+
if xp.isdtype(x.dtype, "numeric"):
777+
return perform_replacements(x, fill_value, xp)
778+
779+
return x
780+
781+
741782
def nunique(x: Array, /, *, xp: ModuleType | None = None) -> Array:
742783
"""
743784
Count the number of unique elements in an array.

tests/conftest.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,3 +232,11 @@ def device(
232232
if library == Backend.TORCH_GPU:
233233
return xp.device("cpu")
234234
return get_device(xp.empty(0))
235+
236+
237+
@pytest.fixture
def infinity(library: Backend) -> float:
    """
    Return the largest finite float expected in place of +inf for the backend.

    The torch backends get 3.4028235e38 (the float32 maximum, presumably
    because torch creates float32 arrays by default -- confirm); every other
    backend gets 1.7976931348623157e308 (the float64 maximum).
    """
    is_torch = library in (Backend.TORCH, Backend.TORCH_GPU)
    return 3.4028235e38 if is_torch else 1.7976931348623157e308

tests/test_funcs.py

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
expand_dims,
2222
isclose,
2323
kron,
24+
nan_to_num,
2425
nunique,
2526
one_hot,
2627
pad,
@@ -40,6 +41,7 @@
4041
lazy_xp_function(create_diagonal)
4142
lazy_xp_function(expand_dims)
4243
lazy_xp_function(kron)
44+
lazy_xp_function(nan_to_num)
4345
lazy_xp_function(nunique)
4446
lazy_xp_function(one_hot)
4547
lazy_xp_function(pad)
@@ -941,6 +943,140 @@ def test_xp(self, xp: ModuleType):
941943
xp_assert_equal(kron(a, b, xp=xp), k)
942944

943945

946+
class TestNanToNum:
    """Tests for `nan_to_num` across the available array backends."""

    def test_bool(self, xp: ModuleType) -> None:
        """Boolean input is returned unchanged."""
        a = xp.asarray([True])
        actual = nan_to_num(a, xp=xp)
        xp_assert_equal(actual, a)

    def test_scalar_pos_inf(self, xp: ModuleType, infinity: float) -> None:
        """A +inf scalar becomes the largest finite value."""
        expected = xp.asarray(infinity)
        xp_assert_equal(nan_to_num(xp.inf, xp=xp), expected)

    def test_scalar_neg_inf(self, xp: ModuleType, infinity: float) -> None:
        """A -inf scalar becomes the most negative finite value."""
        expected = -xp.asarray(infinity)
        xp_assert_equal(nan_to_num(-xp.inf, xp=xp), expected)

    def test_scalar_nan(self, xp: ModuleType) -> None:
        """A NaN scalar becomes 0.0 by default."""
        expected = xp.asarray(0.0)
        xp_assert_equal(nan_to_num(xp.nan, xp=xp), expected)

    def test_real(self, xp: ModuleType, infinity: float) -> None:
        """Non-finite entries of a real array are replaced; others are kept."""
        a = xp.asarray([xp.inf, -xp.inf, xp.nan, -128, 128])
        expected = xp.asarray([infinity, -infinity, 0.0, -128, 128])
        xp_assert_equal(nan_to_num(a, xp=xp), expected)

    def test_complex(self, xp: ModuleType, infinity: float) -> None:
        """Real and imaginary components are handled separately."""
        a = xp.asarray(
            [
                complex(xp.inf, xp.nan),
                xp.nan,
                complex(xp.nan, xp.inf),
            ]
        )
        expected = xp.asarray(
            [complex(infinity, 0), complex(0, 0), complex(0, infinity)]
        )
        # `xp` is deliberately not passed here: the namespace is inferred.
        xp_assert_equal(nan_to_num(a), expected)

    def test_empty_array(self, xp: ModuleType) -> None:
        """An empty array round-trips with its dtype intact."""
        a = xp.asarray([], dtype=xp.float32)  # forced dtype due to torch
        xp_assert_equal(nan_to_num(a, xp=xp), a)
        result_dtype = nan_to_num(a, xp=xp).dtype
        assert xp.isdtype(result_dtype, xp.float32)

    @pytest.mark.parametrize(
        ("in_vals", "fill_value", "out_vals"),
        [
            # real input, integer and float fill values
            ([1, 2, np.nan, 4], 3, [1.0, 2.0, 3.0, 4.0]),
            ([1, 2, np.nan, 4], 3.0, [1.0, 2.0, 3.0, 4.0]),
            # complex input, NaN in the real component only
            (
                [complex(1, 1), complex(2, 2), complex(np.nan, 0), complex(4, 4)],
                3,
                [
                    complex(1.0, 1.0),
                    complex(2.0, 2.0),
                    complex(3.0, 0.0),
                    complex(4.0, 4.0),
                ],
            ),
            # complex input, NaN in the imaginary component only
            (
                [complex(1, 1), complex(2, 2), complex(0, np.nan), complex(4, 4)],
                3.0,
                [
                    complex(1.0, 1.0),
                    complex(2.0, 2.0),
                    complex(0.0, 3.0),
                    complex(4.0, 4.0),
                ],
            ),
            # complex input, NaN in both components
            (
                [
                    complex(1, 1),
                    complex(2, 2),
                    complex(np.nan, np.nan),
                    complex(4, 4),
                ],
                3.0,
                [
                    complex(1.0, 1.0),
                    complex(2.0, 2.0),
                    complex(3.0, 3.0),
                    complex(4.0, 4.0),
                ],
            ),
        ],
    )
    def test_fill_value_success(
        self,
        xp: ModuleType,
        in_vals: Array,
        fill_value: int | float,
        out_vals: Array,
    ) -> None:
        """A real `fill_value` replaces NaN in real and complex inputs."""
        a = xp.asarray(in_vals)
        actual = nan_to_num(a, fill_value=fill_value, xp=xp)
        xp_assert_equal(actual, xp.asarray(out_vals))

    def test_fill_value_failure(self, xp: ModuleType) -> None:
        """A complex `fill_value` is rejected with a `TypeError`."""
        a = xp.asarray([complex(1, 1), complex(xp.nan, xp.nan), complex(3, 3)])
        with pytest.raises(
            TypeError,
            match="Complex fill values are not supported",
        ):
            _ = nan_to_num(
                a,
                fill_value=complex(2, 2),  # type: ignore[arg-type]  # pyright: ignore[reportArgumentType]
                xp=xp,
            )
1078+
1079+
9441080
class TestNUnique:
9451081
def test_simple(self, xp: ModuleType):
9461082
a = xp.asarray([[1, 1], [0, 2], [2, 2]])

0 commit comments

Comments
 (0)