Skip to content

Commit 3c1cba6

Browse files
committed
Implement nan_to_num function
1 parent 5227063 commit 3c1cba6

File tree

5 files changed

+125
-2
lines changed

5 files changed

+125
-2
lines changed

docs/api-reference.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
expand_dims
1717
isclose
1818
kron
19+
nan_to_num
1920
nunique
2021
one_hot
2122
pad

src/array_api_extra/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Extra array functions built on top of the array API standard."""
22

3-
from ._delegation import isclose, one_hot, pad
3+
from ._delegation import isclose, nan_to_num, one_hot, pad
44
from ._lib._at import at
55
from ._lib._funcs import (
66
apply_where,
@@ -33,6 +33,7 @@
3333
"isclose",
3434
"kron",
3535
"lazy_apply",
36+
"nan_to_num",
3637
"nunique",
3738
"one_hot",
3839
"pad",

src/array_api_extra/_delegation.py

Lines changed: 78 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from ._lib._utils._helpers import asarrays
1919
from ._lib._utils._typing import Array, DType
2020

21-
__all__ = ["isclose", "one_hot", "pad"]
21+
__all__ = ["isclose", "nan_to_num", "one_hot", "pad"]
2222

2323

2424
def isclose(
@@ -113,6 +113,83 @@ def isclose(
113113
return _funcs.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan, xp=xp)
114114

115115

116+
def nan_to_num(
117+
x: Array,
118+
/,
119+
*,
120+
fill_value: int | float | complex = 0.0,
121+
xp: ModuleType | None = None
122+
) -> Array:
123+
"""
124+
Replace NaN with zero and infinity with large finite numbers (default
125+
behaviour).
126+
127+
If `x` is inexact, NaN is replaced by zero or by the user defined value in
128+
`nan` keyword, infinity is replaced by the largest finite floating point
129+
values representable by ``x.dtype`` and -infinity is replaced by the most
130+
negative finite floating point values representable by ``x.dtype``.
131+
132+
For complex dtypes, the above is applied to each of the real and
133+
imaginary components of `x` separately.
134+
135+
If `x` is not inexact, then no replacements are made.
136+
137+
Parameters
138+
----------
139+
x : array
140+
Input data.
141+
fill_value : int, float, complex, optional
142+
Value to be used to fill NaN values. If no value is passed
143+
then NaN values will be replaced with 0.0.
144+
145+
Returns
146+
-------
147+
array
148+
`x`, with the non-finite values replaced.
149+
150+
See Also
151+
--------
152+
array_api.isnan : Shows which elements are Not a Number (NaN).
153+
154+
Examples
155+
--------
156+
>>> import array_api_extra as xpx
157+
>>> import array_api_strict as xp
158+
>>> xpx.nan_to_num(xp.inf)
159+
1.7976931348623157e+308
160+
>>> xpx.nan_to_num(-xp.inf)
161+
-1.7976931348623157e+308
162+
>>> xpx.nan_to_num(xp.nan)
163+
0.0
164+
>>> x = xp.array([xp.inf, -xp.inf, xp.nan, -128, 128])
165+
>>> xpx.nan_to_num(x)
166+
array([ 1.79769313e+308, -1.79769313e+308, 0.00000000e+000, # may vary
167+
-1.28000000e+002, 1.28000000e+002])
168+
>>> y = xp.array([complex(xp.inf, xp.nan), xp.nan, complex(xp.nan, xp.inf)])
169+
array([ 1.79769313e+308, -1.79769313e+308, 0.00000000e+000, # may vary
170+
-1.28000000e+002, 1.28000000e+002])
171+
>>> xpx.nan_to_num(y)
172+
array([ 1.79769313e+308 +0.00000000e+000j, # may vary
173+
0.00000000e+000 +0.00000000e+000j,
174+
0.00000000e+000 +1.79769313e+308j])
175+
"""
176+
if x.ndim == 0:
177+
msg = "x must be an array."
178+
raise TypeError(msg)
179+
180+
xp = array_namespace(x) if xp is None else xp
181+
182+
if (
183+
is_cupy_namespace(xp)
184+
or is_jax_namespace(xp)
185+
or is_numpy_namespace(xp)
186+
or is_torch_namespace(xp)
187+
):
188+
return xp.nan_to_num(x, nan=fill_value)
189+
190+
return _funcs.nan_to_num(x, fill_value=fill_value, xp=xp)
191+
192+
116193
def one_hot(
117194
x: Array,
118195
/,

src/array_api_extra/_lib/_funcs.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -738,6 +738,48 @@ def kron(
738738
return xp.reshape(result, res_shape)
739739

740740

741+
def nan_to_num(
742+
x: Array,
743+
/,
744+
*,
745+
fill_value: int | float | complex = 0.0,
746+
xp: ModuleType | None = None,
747+
) -> Array:
748+
"""See docstring in `array_api_extra._delegation.py`."""
749+
xp = array_namespace(x) if xp is None else xp
750+
751+
def perform_replacements(
752+
x: Array,
753+
fill_value: int | float | complex,
754+
xp: ModuleType,
755+
) -> Array:
756+
"""Internal function to perform the replacements."""
757+
x = xp.where(xp.isnan(x), fill_value, x)
758+
759+
# convert infinities to finite values
760+
finfo = xp.finfo(x.dtype)
761+
idx_posinf = xp.isinf(x) & ~xp.signbit(x)
762+
idx_neginf = xp.isinf(x) & xp.signbit(x)
763+
x = xp.where(idx_posinf, x, finfo.max)
764+
return xp.where(idx_neginf, x, finfo.min)
765+
766+
if xp.isdtype(x.dtype, "complex floating"):
767+
return perform_replacements(
768+
x,
769+
fill_value,
770+
xp,
771+
) + 1j * perform_replacements(
772+
x,
773+
fill_value,
774+
xp,
775+
)
776+
777+
if xp.isdtype(x.dtype, "numeric"):
778+
return perform_replacements(x, fill_value, xp)
779+
780+
return x
781+
782+
741783
def nunique(x: Array, /, *, xp: ModuleType | None = None) -> Array:
742784
"""
743785
Count the number of unique elements in an array.

tests/test_funcs.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
expand_dims,
2222
isclose,
2323
kron,
24+
nan_to_num,
2425
nunique,
2526
one_hot,
2627
pad,
@@ -40,6 +41,7 @@
4041
lazy_xp_function(create_diagonal)
4142
lazy_xp_function(expand_dims)
4243
lazy_xp_function(kron)
44+
lazy_xp_function(nan_to_num)
4345
lazy_xp_function(nunique)
4446
lazy_xp_function(one_hot)
4547
lazy_xp_function(pad)

0 commit comments

Comments
 (0)