Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api-reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
expand_dims
isclose
kron
nan_to_num
nunique
one_hot
pad
Expand Down
3 changes: 2 additions & 1 deletion src/array_api_extra/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Extra array functions built on top of the array API standard."""

from ._delegation import isclose, one_hot, pad
from ._delegation import isclose, nan_to_num, one_hot, pad
from ._lib._at import at
from ._lib._funcs import (
apply_where,
Expand Down Expand Up @@ -33,6 +33,7 @@
"isclose",
"kron",
"lazy_apply",
"nan_to_num",
"nunique",
"one_hot",
"pad",
Expand Down
78 changes: 77 additions & 1 deletion src/array_api_extra/_delegation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from ._lib._utils._helpers import asarrays
from ._lib._utils._typing import Array, DType

__all__ = ["isclose", "one_hot", "pad"]
__all__ = ["isclose", "nan_to_num", "one_hot", "pad"]


def isclose(
Expand Down Expand Up @@ -113,6 +113,82 @@ def isclose(
return _funcs.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan, xp=xp)


def nan_to_num(
x: Array | float | complex,
/,
*,
fill_value: int | float | complex = 0.0,
xp: ModuleType | None = None
) -> Array:
"""
Replace NaN with zero and infinity with large finite numbers (default
behaviour).

If `x` is inexact, NaN is replaced by zero or by the user defined value in
`nan` keyword, infinity is replaced by the largest finite floating point
values representable by ``x.dtype`` and -infinity is replaced by the most
negative finite floating point values representable by ``x.dtype``.

For complex dtypes, the above is applied to each of the real and
imaginary components of `x` separately.

If `x` is not inexact, then no replacements are made.

Parameters
----------
x : array, float, complex
Input data.
fill_value : int, float, complex, optional
Value to be used to fill NaN values. If no value is passed
then NaN values will be replaced with 0.0.

Returns
-------
array
`x`, with the non-finite values replaced.

See Also
--------
array_api.isnan : Shows which elements are Not a Number (NaN).

Examples
--------
>>> import array_api_extra as xpx
>>> import array_api_strict as xp
>>> xpx.nan_to_num(xp.inf)
1.7976931348623157e+308
>>> xpx.nan_to_num(-xp.inf)
-1.7976931348623157e+308
>>> xpx.nan_to_num(xp.nan)
0.0
>>> x = xp.array([xp.inf, -xp.inf, xp.nan, -128, 128])
>>> xpx.nan_to_num(x)
array([ 1.79769313e+308, -1.79769313e+308, 0.00000000e+000, # may vary
-1.28000000e+002, 1.28000000e+002])
>>> y = xp.array([complex(xp.inf, xp.nan), xp.nan, complex(xp.nan, xp.inf)])
array([ 1.79769313e+308, -1.79769313e+308, 0.00000000e+000, # may vary
-1.28000000e+002, 1.28000000e+002])
>>> xpx.nan_to_num(y)
array([ 1.79769313e+308 +0.00000000e+000j, # may vary
0.00000000e+000 +0.00000000e+000j,
0.00000000e+000 +1.79769313e+308j])
"""
xp = array_namespace(x) if xp is None else xp

# for scalars we want to output an array
y = xp.asarray(x)

if (
is_cupy_namespace(xp)
or is_jax_namespace(xp)
or is_numpy_namespace(xp)
or is_torch_namespace(xp)
):
return xp.nan_to_num(y, nan=fill_value)

return _funcs.nan_to_num(y, fill_value=fill_value, xp=xp)


def one_hot(
x: Array,
/,
Expand Down
42 changes: 42 additions & 0 deletions src/array_api_extra/_lib/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -738,6 +738,48 @@ def kron(
return xp.reshape(result, res_shape)


def nan_to_num(
x: Array,
/,
*,
fill_value: int | float | complex = 0.0,
xp: ModuleType | None = None,
) -> Array:
"""See docstring in `array_api_extra._delegation.py`."""
xp = array_namespace(x) if xp is None else xp

def perform_replacements(
x: Array,
fill_value: int | float | complex,
xp: ModuleType,
) -> Array:
"""Internal function to perform the replacements."""
x = xp.where(xp.isnan(x), fill_value, x)

# convert infinities to finite values
finfo = xp.finfo(x.dtype)
idx_posinf = xp.isinf(x) & ~xp.signbit(x)
idx_neginf = xp.isinf(x) & xp.signbit(x)
x = xp.where(idx_posinf, finfo.max, x)
return xp.where(idx_neginf, finfo.min, x)

if xp.isdtype(x.dtype, "complex floating"):
return perform_replacements(
xp.real(x),
fill_value,
xp,
) + 1j * perform_replacements(
xp.imag(x),
fill_value,
xp,
)

if xp.isdtype(x.dtype, "numeric"):
return perform_replacements(x, fill_value, xp)

return x


def nunique(x: Array, /, *, xp: ModuleType | None = None) -> Array:
"""
Count the number of unique elements in an array.
Expand Down
7 changes: 7 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,3 +232,10 @@ def device(
if library == Backend.TORCH_GPU:
return xp.device("cpu")
return get_device(xp.empty(0))

@pytest.fixture
def infinity(library: Backend) -> float:
"""Retrieve the positive infinity value for the given backend."""
if library in (Backend.TORCH, Backend.TORCH_GPU):
return 3.4028235e+38
return 1.7976931348623157e+308
54 changes: 54 additions & 0 deletions tests/test_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
expand_dims,
isclose,
kron,
nan_to_num,
nunique,
one_hot,
pad,
Expand All @@ -40,6 +41,7 @@
lazy_xp_function(create_diagonal)
lazy_xp_function(expand_dims)
lazy_xp_function(kron)
lazy_xp_function(nan_to_num)
lazy_xp_function(nunique)
lazy_xp_function(one_hot)
lazy_xp_function(pad)
Expand Down Expand Up @@ -941,6 +943,58 @@ def test_xp(self, xp: ModuleType):
xp_assert_equal(kron(a, b, xp=xp), k)


class TestNumToNan:
def test_bool(self, xp: ModuleType) -> None:
a = xp.asarray([True])
xp_assert_equal(nan_to_num(a), a)

def test_scalar_pos_inf(self, xp: ModuleType, infinity: float) -> None:
a = xp.inf
xp_assert_equal(nan_to_num(a, xp=xp), xp.asarray(infinity))

def test_scalar_neg_inf(self, xp: ModuleType, infinity: float) -> None:
a = -xp.inf
xp_assert_equal(nan_to_num(a, xp=xp), -xp.asarray(infinity))

def test_scalar_nan(self, xp: ModuleType) -> None:
a = xp.nan
xp_assert_equal(nan_to_num(a, xp=xp), xp.asarray(0.0))

def test_real(self, xp: ModuleType, infinity: float) -> None:
a = xp.asarray([xp.inf, -xp.inf, xp.nan, -128, 128])
xp_assert_equal(
nan_to_num(a),
xp.asarray(
[
infinity,
-infinity,
0.0,
-128,
128,
]
),
)

def test_complex(self, xp: ModuleType, infinity: float) -> None:
a = xp.asarray(
[
complex(xp.inf, xp.nan),
xp.nan,
complex(xp.nan, xp.inf),
]
)
xp_assert_equal(
nan_to_num(a),
xp.asarray(
[
infinity + 0j,
0 + 0j,
0 + 1j * infinity
]
),
)


class TestNUnique:
def test_simple(self, xp: ModuleType):
a = xp.asarray([[1, 1], [0, 2], [2, 2]])
Expand Down
Loading