diff --git a/docs/api-reference.md b/docs/api-reference.md index 38d0d26e..61e09e2d 100644 --- a/docs/api-reference.md +++ b/docs/api-reference.md @@ -16,6 +16,7 @@ expand_dims isclose kron + nan_to_num nunique one_hot pad diff --git a/src/array_api_extra/__init__.py b/src/array_api_extra/__init__.py index caa0526f..bcb0b3bd 100644 --- a/src/array_api_extra/__init__.py +++ b/src/array_api_extra/__init__.py @@ -1,6 +1,6 @@ """Extra array functions built on top of the array API standard.""" -from ._delegation import isclose, one_hot, pad +from ._delegation import isclose, nan_to_num, one_hot, pad from ._lib._at import at from ._lib._funcs import ( apply_where, @@ -33,6 +33,7 @@ "isclose", "kron", "lazy_apply", + "nan_to_num", "nunique", "one_hot", "pad", diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py index 756841c8..2c061e36 100644 --- a/src/array_api_extra/_delegation.py +++ b/src/array_api_extra/_delegation.py @@ -18,7 +18,7 @@ from ._lib._utils._helpers import asarrays from ._lib._utils._typing import Array, DType -__all__ = ["isclose", "one_hot", "pad"] +__all__ = ["isclose", "nan_to_num", "one_hot", "pad"] def isclose( @@ -113,6 +113,85 @@ def isclose( return _funcs.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan, xp=xp) +def nan_to_num( + x: Array | float | complex, + /, + *, + fill_value: int | float = 0.0, + xp: ModuleType | None = None, +) -> Array: + """ + Replace NaN with zero and infinity with large finite numbers (default behaviour). + + If `x` is inexact, NaN is replaced by zero or by the user defined value in the + `fill_value` keyword, infinity is replaced by the largest finite floating + point value representable by ``x.dtype``, and -infinity is replaced by the + most negative finite floating point value representable by ``x.dtype``. + + For complex dtypes, the above is applied to each of the real and + imaginary components of `x` separately. + + Parameters + ---------- + x : array | float | complex + Input data. + fill_value : int | float, optional + Value to be used to fill NaN values. If no value is passed + then NaN values will be replaced with 0.0. + xp : array_namespace, optional + The standard-compatible namespace for `x`. Default: infer. + + Returns + ------- + array + `x`, with the non-finite values replaced. + + See Also + -------- + array_api.isnan : Shows which elements are Not a Number (NaN). + + Examples + -------- + >>> import array_api_extra as xpx + >>> import array_api_strict as xp + >>> xpx.nan_to_num(xp.inf) + 1.7976931348623157e+308 + >>> xpx.nan_to_num(-xp.inf) + -1.7976931348623157e+308 + >>> xpx.nan_to_num(xp.nan) + 0.0 + >>> x = xp.asarray([xp.inf, -xp.inf, xp.nan, -128, 128]) + >>> xpx.nan_to_num(x) + array([ 1.79769313e+308, -1.79769313e+308, 0.00000000e+000, # may vary + -1.28000000e+002, 1.28000000e+002]) + >>> y = xp.asarray([complex(xp.inf, xp.nan), xp.nan, complex(xp.nan, xp.inf)]) + array([ 1.79769313e+308, -1.79769313e+308, 0.00000000e+000, # may vary + -1.28000000e+002, 1.28000000e+002]) + >>> xpx.nan_to_num(y) + array([ 1.79769313e+308 +0.00000000e+000j, # may vary + 0.00000000e+000 +0.00000000e+000j, + 0.00000000e+000 +1.79769313e+308j]) + """ + if isinstance(fill_value, complex): + msg = "Complex fill values are not supported." + raise TypeError(msg) + + xp = array_namespace(x) if xp is None else xp + + # for scalars we want to output an array + y = xp.asarray(x) + + if ( + is_cupy_namespace(xp) + or is_jax_namespace(xp) + or is_numpy_namespace(xp) + or is_torch_namespace(xp) + ): + return xp.nan_to_num(y, nan=fill_value) + + return _funcs.nan_to_num(y, fill_value=fill_value, xp=xp) + + def one_hot( x: Array, /, diff --git a/src/array_api_extra/_lib/_funcs.py b/src/array_api_extra/_lib/_funcs.py index 05db6251..cbcbe0ff 100644 --- a/src/array_api_extra/_lib/_funcs.py +++ b/src/array_api_extra/_lib/_funcs.py @@ -738,6 +738,47 @@ def kron( return xp.reshape(result, res_shape) +def nan_to_num( # numpydoc ignore=PR01,RT01 + x: Array, + /, + fill_value: int | float = 0.0, + *, + xp: ModuleType, +) -> Array: + """See docstring in `array_api_extra._delegation.py`.""" + + def perform_replacements( # numpydoc ignore=PR01,RT01 + x: Array, + fill_value: int | float, + xp: ModuleType, + ) -> Array: + """Internal function to perform the replacements.""" + x = xp.where(xp.isnan(x), fill_value, x) + + # convert infinities to finite values + finfo = xp.finfo(x.dtype) + idx_posinf = xp.isinf(x) & ~xp.signbit(x) + idx_neginf = xp.isinf(x) & xp.signbit(x) + x = xp.where(idx_posinf, finfo.max, x) + return xp.where(idx_neginf, finfo.min, x) + + if xp.isdtype(x.dtype, "complex floating"): + return perform_replacements( + xp.real(x), + fill_value, + xp, + ) + 1j * perform_replacements( + xp.imag(x), + fill_value, + xp, + ) + + if xp.isdtype(x.dtype, "numeric"): + return perform_replacements(x, fill_value, xp) + + return x + + def nunique(x: Array, /, *, xp: ModuleType | None = None) -> Array: """ Count the number of unique elements in an array. diff --git a/tests/conftest.py b/tests/conftest.py index 76fb7650..df703b97 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -232,3 +232,11 @@ def device( if library == Backend.TORCH_GPU: return xp.device("cpu") return get_device(xp.empty(0)) + + +@pytest.fixture +def infinity(library: Backend) -> float: + """Retrieve the positive infinity value for the given backend.""" + if library in (Backend.TORCH, Backend.TORCH_GPU): + return 3.4028235e38 + return 1.7976931348623157e308 diff --git a/tests/test_funcs.py b/tests/test_funcs.py index 769b4119..8304b848 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -21,6 +21,7 @@ expand_dims, isclose, kron, + nan_to_num, nunique, one_hot, pad, @@ -40,6 +41,7 @@ lazy_xp_function(create_diagonal) lazy_xp_function(expand_dims) lazy_xp_function(kron) +lazy_xp_function(nan_to_num) lazy_xp_function(nunique) lazy_xp_function(one_hot) lazy_xp_function(pad) @@ -941,6 +943,140 @@ def test_xp(self, xp: ModuleType): xp_assert_equal(kron(a, b, xp=xp), k) +class TestNanToNum: + def test_bool(self, xp: ModuleType) -> None: + a = xp.asarray([True]) + xp_assert_equal(nan_to_num(a, xp=xp), a) + + def test_scalar_pos_inf(self, xp: ModuleType, infinity: float) -> None: + a = xp.inf + xp_assert_equal(nan_to_num(a, xp=xp), xp.asarray(infinity)) + + def test_scalar_neg_inf(self, xp: ModuleType, infinity: float) -> None: + a = -xp.inf + xp_assert_equal(nan_to_num(a, xp=xp), -xp.asarray(infinity)) + + def test_scalar_nan(self, xp: ModuleType) -> None: + a = xp.nan + xp_assert_equal(nan_to_num(a, xp=xp), xp.asarray(0.0)) + + def test_real(self, xp: ModuleType, infinity: float) -> None: + a = xp.asarray([xp.inf, -xp.inf, xp.nan, -128, 128]) + xp_assert_equal( + nan_to_num(a, xp=xp), + xp.asarray( + [ + infinity, + -infinity, + 0.0, + -128, + 128, + ] + ), + ) + + def test_complex(self, xp: ModuleType, infinity: float) -> None: + a = xp.asarray( + [ + complex(xp.inf, xp.nan), + xp.nan, + complex(xp.nan, xp.inf), + ] + ) + xp_assert_equal( + nan_to_num(a), + xp.asarray([complex(infinity, 0), complex(0, 0), complex(0, infinity)]), + ) + + def test_empty_array(self, xp: ModuleType) -> None: + a = xp.asarray([], dtype=xp.float32) # forced dtype due to torch + xp_assert_equal(nan_to_num(a, xp=xp), a) + assert xp.isdtype(nan_to_num(a, xp=xp).dtype, xp.float32) + + @pytest.mark.parametrize( + ("in_vals", "fill_value", "out_vals"), + [ + ([1, 2, np.nan, 4], 3, [1.0, 2.0, 3.0, 4.0]), + ([1, 2, np.nan, 4], 3.0, [1.0, 2.0, 3.0, 4.0]), + ( + [ + complex(1, 1), + complex(2, 2), + complex(np.nan, 0), + complex(4, 4), + ], + 3, + [ + complex(1.0, 1.0), + complex(2.0, 2.0), + complex(3.0, 0.0), + complex(4.0, 4.0), + ], + ), + ( + [ + complex(1, 1), + complex(2, 2), + complex(0, np.nan), + complex(4, 4), + ], + 3.0, + [ + complex(1.0, 1.0), + complex(2.0, 2.0), + complex(0.0, 3.0), + complex(4.0, 4.0), + ], + ), + ( + [ + complex(1, 1), + complex(2, 2), + complex(np.nan, np.nan), + complex(4, 4), + ], + 3.0, + [ + complex(1.0, 1.0), + complex(2.0, 2.0), + complex(3.0, 3.0), + complex(4.0, 4.0), + ], + ), + ], + ) + def test_fill_value_success( + self, + xp: ModuleType, + in_vals: Array, + fill_value: int | float, + out_vals: Array, + ) -> None: + a = xp.asarray(in_vals) + xp_assert_equal( + nan_to_num(a, fill_value=fill_value, xp=xp), + xp.asarray(out_vals), + ) + + def test_fill_value_failure(self, xp: ModuleType) -> None: + a = xp.asarray( + [ + complex(1, 1), + complex(xp.nan, xp.nan), + complex(3, 3), + ] + ) + with pytest.raises( + TypeError, + match="Complex fill values are not supported", + ): + _ = nan_to_num( + a, + fill_value=complex(2, 2), # type: ignore[arg-type] # pyright: ignore[reportArgumentType] + xp=xp, + ) + + class TestNUnique: def test_simple(self, xp: ModuleType): a = xp.asarray([[1, 1], [0, 2], [2, 2]])