From 3c1cba6482cff42528487940ffd9eb1dd28d4350 Mon Sep 17 00:00:00 2001 From: "Patrick J. Roddy" Date: Fri, 22 Aug 2025 15:27:08 +0200 Subject: [PATCH 01/21] Implement `nan_to_num` function --- docs/api-reference.md | 1 + src/array_api_extra/__init__.py | 3 +- src/array_api_extra/_delegation.py | 79 +++++++++++++++++++++++++++++- src/array_api_extra/_lib/_funcs.py | 42 ++++++++++++++++ tests/test_funcs.py | 2 + 5 files changed, 125 insertions(+), 2 deletions(-) diff --git a/docs/api-reference.md b/docs/api-reference.md index 38d0d26e..61e09e2d 100644 --- a/docs/api-reference.md +++ b/docs/api-reference.md @@ -16,6 +16,7 @@ expand_dims isclose kron + nan_to_num nunique one_hot pad diff --git a/src/array_api_extra/__init__.py b/src/array_api_extra/__init__.py index ddfc715e..7863f335 100644 --- a/src/array_api_extra/__init__.py +++ b/src/array_api_extra/__init__.py @@ -1,6 +1,6 @@ """Extra array functions built on top of the array API standard.""" -from ._delegation import isclose, one_hot, pad +from ._delegation import isclose, nan_to_num, one_hot, pad from ._lib._at import at from ._lib._funcs import ( apply_where, @@ -33,6 +33,7 @@ "isclose", "kron", "lazy_apply", + "nan_to_num", "nunique", "one_hot", "pad", diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py index 756841c8..8ad8970a 100644 --- a/src/array_api_extra/_delegation.py +++ b/src/array_api_extra/_delegation.py @@ -18,7 +18,7 @@ from ._lib._utils._helpers import asarrays from ._lib._utils._typing import Array, DType -__all__ = ["isclose", "one_hot", "pad"] +__all__ = ["isclose", "nan_to_num", "one_hot", "pad"] def isclose( @@ -113,6 +113,83 @@ def isclose( return _funcs.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan, xp=xp) +def nan_to_num( + x: Array, + /, + *, + fill_value: int | float | complex = 0.0, + xp: ModuleType | None = None +) -> Array: + """ + Replace NaN with zero and infinity with large finite numbers (default + behaviour). + + If `x` is inexact, NaN is replaced by zero or by the user defined value in + `nan` keyword, infinity is replaced by the largest finite floating point + values representable by ``x.dtype`` and -infinity is replaced by the most + negative finite floating point values representable by ``x.dtype``. + + For complex dtypes, the above is applied to each of the real and + imaginary components of `x` separately. + + If `x` is not inexact, then no replacements are made. + + Parameters + ---------- + x : array + Input data. + fill_value : int, float, complex, optional + Value to be used to fill NaN values. If no value is passed + then NaN values will be replaced with 0.0. + + Returns + ------- + array + `x`, with the non-finite values replaced. + + See Also + -------- + array_api.isnan : Shows which elements are Not a Number (NaN). + + Examples + -------- + >>> import array_api_extra as xpx + >>> import array_api_strict as xp + >>> xpx.nan_to_num(xp.inf) + 1.7976931348623157e+308 + >>> xpx.nan_to_num(-xp.inf) + -1.7976931348623157e+308 + >>> xpx.nan_to_num(xp.nan) + 0.0 + >>> x = xp.array([xp.inf, -xp.inf, xp.nan, -128, 128]) + >>> xpx.nan_to_num(x) + array([ 1.79769313e+308, -1.79769313e+308, 0.00000000e+000, # may vary + -1.28000000e+002, 1.28000000e+002]) + >>> y = xp.array([complex(xp.inf, xp.nan), xp.nan, complex(xp.nan, xp.inf)]) + array([ 1.79769313e+308, -1.79769313e+308, 0.00000000e+000, # may vary + -1.28000000e+002, 1.28000000e+002]) + >>> xpx.nan_to_num(y) + array([ 1.79769313e+308 +0.00000000e+000j, # may vary + 0.00000000e+000 +0.00000000e+000j, + 0.00000000e+000 +1.79769313e+308j]) + """ + if x.ndim == 0: + msg = "x must be an array." + raise TypeError(msg) + + xp = array_namespace(x) if xp is None else xp + + if ( + is_cupy_namespace(xp) + or is_jax_namespace(xp) + or is_numpy_namespace(xp) + or is_torch_namespace(xp) + ): + return xp.nan_to_num(x, nan=fill_value) + + return _funcs.nan_to_num(x, fill_value=fill_value, xp=xp) + + def one_hot( x: Array, /, diff --git a/src/array_api_extra/_lib/_funcs.py b/src/array_api_extra/_lib/_funcs.py index 05db6251..a4e46a28 100644 --- a/src/array_api_extra/_lib/_funcs.py +++ b/src/array_api_extra/_lib/_funcs.py @@ -738,6 +738,48 @@ def kron( return xp.reshape(result, res_shape) +def nan_to_num( + x: Array, + /, + *, + fill_value: int | float | complex = 0.0, + xp: ModuleType | None = None, +) -> Array: + """See docstring in `array_api_extra._delegation.py`.""" + xp = array_namespace(x) if xp is None else xp + + def perform_replacements( + x: Array, + fill_value: int | float | complex, + xp: ModuleType, + ) -> Array: + """Internal function to perform the replacements.""" + x = xp.where(xp.isnan(x), fill_value, x) + + # convert infinities to finite values + finfo = xp.finfo(x.dtype) + idx_posinf = xp.isinf(x) & ~xp.signbit(x) + idx_neginf = xp.isinf(x) & xp.signbit(x) + x = xp.where(idx_posinf, x, finfo.max) + return xp.where(idx_neginf, x, finfo.min) + + if xp.isdtype(x.dtype, "complex floating"): + return perform_replacements( + x, + fill_value, + xp, + ) + 1j * perform_replacements( + x, + fill_value, + xp, + ) + + if xp.isdtype(x.dtype, "numeric"): + return perform_replacements(x, fill_value, xp) + + return x + + def nunique(x: Array, /, *, xp: ModuleType | None = None) -> Array: """ Count the number of unique elements in an array. diff --git a/tests/test_funcs.py b/tests/test_funcs.py index 769b4119..acfacc22 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -21,6 +21,7 @@ expand_dims, isclose, kron, + nan_to_num, nunique, one_hot, pad, @@ -40,6 +41,7 @@ lazy_xp_function(create_diagonal) lazy_xp_function(expand_dims) lazy_xp_function(kron) +lazy_xp_function(nan_to_num) lazy_xp_function(nunique) lazy_xp_function(one_hot) lazy_xp_function(pad) From 1ae1abfb2256ceffcb950952af58295f18627c05 Mon Sep 17 00:00:00 2001 From: "Patrick J. Roddy" Date: Fri, 22 Aug 2025 17:06:47 +0200 Subject: [PATCH 02/21] Add tests specified in docstring --- src/array_api_extra/_delegation.py | 15 ++++----- src/array_api_extra/_lib/_funcs.py | 8 ++--- tests/conftest.py | 7 ++++ tests/test_funcs.py | 52 ++++++++++++++++++++++++++++++ 4 files changed, 70 insertions(+), 12 deletions(-) diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py index 8ad8970a..fb3ca234 100644 --- a/src/array_api_extra/_delegation.py +++ b/src/array_api_extra/_delegation.py @@ -114,7 +114,7 @@ def isclose( def nan_to_num( - x: Array, + x: Array | float | complex, /, *, fill_value: int | float | complex = 0.0, @@ -136,7 +136,7 @@ def nan_to_num( Parameters ---------- - x : array + x : array, float, complex Input data. fill_value : int, float, complex, optional Value to be used to fill NaN values. If no value is passed @@ -173,21 +173,20 @@ def nan_to_num( 0.00000000e+000 +0.00000000e+000j, 0.00000000e+000 +1.79769313e+308j]) """ - if x.ndim == 0: - msg = "x must be an array." - raise TypeError(msg) - xp = array_namespace(x) if xp is None else xp + # for scalars we want to output an array + y = xp.asarray(x) + if ( is_cupy_namespace(xp) or is_jax_namespace(xp) or is_numpy_namespace(xp) or is_torch_namespace(xp) ): - return xp.nan_to_num(x, nan=fill_value) + return xp.nan_to_num(y, nan=fill_value) - return _funcs.nan_to_num(x, fill_value=fill_value, xp=xp) + return _funcs.nan_to_num(y, fill_value=fill_value, xp=xp) def one_hot( diff --git a/src/array_api_extra/_lib/_funcs.py b/src/array_api_extra/_lib/_funcs.py index a4e46a28..bdc0dc57 100644 --- a/src/array_api_extra/_lib/_funcs.py +++ b/src/array_api_extra/_lib/_funcs.py @@ -760,16 +760,16 @@ def perform_replacements( finfo = xp.finfo(x.dtype) idx_posinf = xp.isinf(x) & ~xp.signbit(x) idx_neginf = xp.isinf(x) & xp.signbit(x) - x = xp.where(idx_posinf, x, finfo.max) - return xp.where(idx_neginf, x, finfo.min) + x = xp.where(idx_posinf, finfo.max, x) + return xp.where(idx_neginf, finfo.min, x) if xp.isdtype(x.dtype, "complex floating"): return perform_replacements( - x, + xp.real(x), fill_value, xp, ) + 1j * perform_replacements( - x, + xp.imag(x), fill_value, xp, ) diff --git a/tests/conftest.py b/tests/conftest.py index 76fb7650..c87e6104 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -232,3 +232,10 @@ def device( if library == Backend.TORCH_GPU: return xp.device("cpu") return get_device(xp.empty(0)) + +@pytest.fixture +def infinity(library: Backend) -> float: + """Retrieve the positive infinity value for the given backend.""" + if library in (Backend.TORCH, Backend.TORCH_GPU): + return 3.4028235e+38 + return 1.7976931348623157e+308 diff --git a/tests/test_funcs.py b/tests/test_funcs.py index acfacc22..45692d0f 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -943,6 +943,58 @@ def test_xp(self, xp: ModuleType): xp_assert_equal(kron(a, b, xp=xp), k) +class TestNumToNan: + def test_bool(self, xp: ModuleType) -> None: + a = xp.asarray([True]) + xp_assert_equal(nan_to_num(a), a) + + def test_scalar_pos_inf(self, xp: ModuleType, infinity: float) -> None: + a = xp.inf + xp_assert_equal(nan_to_num(a, xp=xp), xp.asarray(infinity)) + + def test_scalar_neg_inf(self, xp: ModuleType, infinity: float) -> None: + a = -xp.inf + xp_assert_equal(nan_to_num(a, xp=xp), -xp.asarray(infinity)) + + def test_scalar_nan(self, xp: ModuleType) -> None: + a = xp.nan + xp_assert_equal(nan_to_num(a, xp=xp), xp.asarray(0.0)) + + def test_real(self, xp: ModuleType, infinity: float) -> None: + a = xp.asarray([xp.inf, -xp.inf, xp.nan, -128, 128]) + xp_assert_equal( + nan_to_num(a), + xp.asarray( + [ + infinity, + -infinity, + 0.0, + -128, + 128, + ] + ), + ) + + def test_complex(self, xp: ModuleType, infinity: float) -> None: + a = xp.asarray( + [ + complex(xp.inf, xp.nan), + xp.nan, + complex(xp.nan, xp.inf), + ] + ) + xp_assert_equal( + nan_to_num(a), + xp.asarray( + [ + infinity + 0j, + 0 + 0j, + 0 + 1j * infinity + ] + ), + ) + + class TestNUnique: def test_simple(self, xp: ModuleType): a = xp.asarray([[1, 1], [0, 2], [2, 2]]) From cfaab8b86fd9e64d658b50a3e6947b1430368e32 Mon Sep 17 00:00:00 2001 From: "Patrick J. Roddy" Date: Fri, 22 Aug 2025 17:19:40 +0200 Subject: [PATCH 03/21] Make `xp` mandatory --- src/array_api_extra/_lib/_funcs.py | 4 ++-- tests/test_funcs.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/array_api_extra/_lib/_funcs.py b/src/array_api_extra/_lib/_funcs.py index bdc0dc57..f4a51569 100644 --- a/src/array_api_extra/_lib/_funcs.py +++ b/src/array_api_extra/_lib/_funcs.py @@ -741,9 +741,9 @@ def kron( def nan_to_num( x: Array, /, - *, fill_value: int | float | complex = 0.0, - xp: ModuleType | None = None, + *, + xp: ModuleType, ) -> Array: """See docstring in `array_api_extra._delegation.py`.""" xp = array_namespace(x) if xp is None else xp diff --git a/tests/test_funcs.py b/tests/test_funcs.py index 45692d0f..2c875056 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -946,7 +946,7 @@ def test_xp(self, xp: ModuleType): class TestNumToNan: def test_bool(self, xp: ModuleType) -> None: a = xp.asarray([True]) - xp_assert_equal(nan_to_num(a), a) + xp_assert_equal(nan_to_num(a, xp=xp), a) def test_scalar_pos_inf(self, xp: ModuleType, infinity: float) -> None: a = xp.inf @@ -963,7 +963,7 @@ def test_scalar_nan(self, xp: ModuleType) -> None: def test_real(self, xp: ModuleType, infinity: float) -> None: a = xp.asarray([xp.inf, -xp.inf, xp.nan, -128, 128]) xp_assert_equal( - nan_to_num(a), + nan_to_num(a, xp=xp), xp.asarray( [ infinity, From ef799ed7b346712617c8807ef29792d7d9765a05 Mon Sep 17 00:00:00 2001 From: "Patrick J. Roddy" Date: Fri, 22 Aug 2025 17:30:08 +0200 Subject: [PATCH 04/21] Fix linting --- src/array_api_extra/_delegation.py | 7 ++++--- src/array_api_extra/_lib/_funcs.py | 5 ++--- tests/conftest.py | 5 +++-- tests/test_funcs.py | 8 +------- 4 files changed, 10 insertions(+), 15 deletions(-) diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py index fb3ca234..6580d2c9 100644 --- a/src/array_api_extra/_delegation.py +++ b/src/array_api_extra/_delegation.py @@ -118,11 +118,10 @@ def nan_to_num( /, *, fill_value: int | float | complex = 0.0, - xp: ModuleType | None = None + xp: ModuleType | None = None, ) -> Array: """ - Replace NaN with zero and infinity with large finite numbers (default - behaviour). + Replace NaN with zero and infinity with large finite numbers (default behaviour). If `x` is inexact, NaN is replaced by zero or by the user defined value in `nan` keyword, infinity is replaced by the largest finite floating point @@ -141,6 +140,8 @@ def nan_to_num( fill_value : int, float, complex, optional Value to be used to fill NaN values. If no value is passed then NaN values will be replaced with 0.0. + xp : array_namespace, optional + The standard-compatible namespace for `x`. Default: infer. Returns ------- diff --git a/src/array_api_extra/_lib/_funcs.py b/src/array_api_extra/_lib/_funcs.py index f4a51569..dc5b1aed 100644 --- a/src/array_api_extra/_lib/_funcs.py +++ b/src/array_api_extra/_lib/_funcs.py @@ -738,7 +738,7 @@ def kron( return xp.reshape(result, res_shape) -def nan_to_num( +def nan_to_num( # numpydoc ignore=PR01,RT01 x: Array, /, fill_value: int | float | complex = 0.0, @@ -746,9 +746,8 @@ def nan_to_num( xp: ModuleType, ) -> Array: """See docstring in `array_api_extra._delegation.py`.""" - xp = array_namespace(x) if xp is None else xp - def perform_replacements( + def perform_replacements( # numpydoc ignore=PR01,RT01 x: Array, fill_value: int | float | complex, xp: ModuleType, diff --git a/tests/conftest.py b/tests/conftest.py index c87e6104..df703b97 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -233,9 +233,10 @@ def device( return xp.device("cpu") return get_device(xp.empty(0)) + @pytest.fixture def infinity(library: Backend) -> float: """Retrieve the positive infinity value for the given backend.""" if library in (Backend.TORCH, Backend.TORCH_GPU): - return 3.4028235e+38 - return 1.7976931348623157e+308 + return 3.4028235e38 + return 1.7976931348623157e308 diff --git a/tests/test_funcs.py b/tests/test_funcs.py index 2c875056..34fbdeec 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -985,13 +985,7 @@ def test_complex(self, xp: ModuleType, infinity: float) -> None: ) xp_assert_equal( nan_to_num(a), - xp.asarray( - [ - infinity + 0j, - 0 + 0j, - 0 + 1j * infinity - ] - ), + xp.asarray([infinity + 0j, 0 + 0j, 0 + 1j * infinity]), ) From 9053822722701dadcba25e437d5b8804accbe48d Mon Sep 17 00:00:00 2001 From: "Patrick J. Roddy" Date: Fri, 22 Aug 2025 16:35:27 +0100 Subject: [PATCH 05/21] Change to `.asarray` Co-authored-by: Lucas Colley --- src/array_api_extra/_delegation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py index 6580d2c9..b32db6c3 100644 --- a/src/array_api_extra/_delegation.py +++ b/src/array_api_extra/_delegation.py @@ -166,7 +166,7 @@ def nan_to_num( >>> xpx.nan_to_num(x) array([ 1.79769313e+308, -1.79769313e+308, 0.00000000e+000, # may vary -1.28000000e+002, 1.28000000e+002]) - >>> y = xp.array([complex(xp.inf, xp.nan), xp.nan, complex(xp.nan, xp.inf)]) + >>> y = xp.asarray([complex(xp.inf, xp.nan), xp.nan, complex(xp.nan, xp.inf)]) array([ 1.79769313e+308, -1.79769313e+308, 0.00000000e+000, # may vary -1.28000000e+002, 1.28000000e+002]) >>> xpx.nan_to_num(y) From 8cce9f2619be6f9946860b53c4654d0b2d2af6ad Mon Sep 17 00:00:00 2001 From: "Patrick J. Roddy" Date: Fri, 22 Aug 2025 16:35:38 +0100 Subject: [PATCH 06/21] Change to `.asarray` Co-authored-by: Lucas Colley --- src/array_api_extra/_delegation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py index b32db6c3..baf83921 100644 --- a/src/array_api_extra/_delegation.py +++ b/src/array_api_extra/_delegation.py @@ -162,7 +162,7 @@ def nan_to_num( -1.7976931348623157e+308 >>> xpx.nan_to_num(xp.nan) 0.0 - >>> x = xp.array([xp.inf, -xp.inf, xp.nan, -128, 128]) + >>> x = xp.asarray([xp.inf, -xp.inf, xp.nan, -128, 128]) >>> xpx.nan_to_num(x) array([ 1.79769313e+308, -1.79769313e+308, 0.00000000e+000, # may vary -1.28000000e+002, 1.28000000e+002]) From b5f7e0265add72b3105d32b4b31557f40683e008 Mon Sep 17 00:00:00 2001 From: "Patrick J. Roddy" Date: Fri, 22 Aug 2025 17:44:04 +0200 Subject: [PATCH 07/21] Some code reviewes --- src/array_api_extra/_delegation.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py index baf83921..69fb56f7 100644 --- a/src/array_api_extra/_delegation.py +++ b/src/array_api_extra/_delegation.py @@ -124,9 +124,9 @@ def nan_to_num( Replace NaN with zero and infinity with large finite numbers (default behaviour). If `x` is inexact, NaN is replaced by zero or by the user defined value in - `nan` keyword, infinity is replaced by the largest finite floating point - values representable by ``x.dtype`` and -infinity is replaced by the most - negative finite floating point values representable by ``x.dtype``. + `fill_value` keyword, infinity is replaced by the largest finite floating + point values representable by ``x.dtype`` and -infinity is replaced by the + most negative finite floating point values representable by ``x.dtype``. For complex dtypes, the above is applied to each of the real and imaginary components of `x` separately. @@ -135,9 +135,9 @@ def nan_to_num( Parameters ---------- - x : array, float, complex + x : array | float | complex Input data. - fill_value : int, float, complex, optional + fill_value : int | float | complex, optional Value to be used to fill NaN values. If no value is passed then NaN values will be replaced with 0.0. xp : array_namespace, optional From 0d343f0d17cd789948fe371d7d0175601a0b3501 Mon Sep 17 00:00:00 2001 From: "Patrick J. Roddy" Date: Fri, 22 Aug 2025 17:46:42 +0200 Subject: [PATCH 08/21] Remove ambiguous statement --- src/array_api_extra/_delegation.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py index 69fb56f7..0be12ab4 100644 --- a/src/array_api_extra/_delegation.py +++ b/src/array_api_extra/_delegation.py @@ -131,8 +131,6 @@ def nan_to_num( For complex dtypes, the above is applied to each of the real and imaginary components of `x` separately. - If `x` is not inexact, then no replacements are made. - Parameters ---------- x : array | float | complex From 7fd404b80053be695088a1b207a4c50b09b30d46 Mon Sep 17 00:00:00 2001 From: "Patrick J. Roddy" Date: Tue, 26 Aug 2025 15:45:42 +0100 Subject: [PATCH 09/21] Rename class --- tests/test_funcs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_funcs.py b/tests/test_funcs.py index 34fbdeec..2554b31c 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -943,7 +943,7 @@ def test_xp(self, xp: ModuleType): xp_assert_equal(kron(a, b, xp=xp), k) -class TestNumToNan: +class TestNanToNum: def test_bool(self, xp: ModuleType) -> None: a = xp.asarray([True]) xp_assert_equal(nan_to_num(a, xp=xp), a) From 5921690267ff1ecf1ff8a3c851a75a559278ef5c Mon Sep 17 00:00:00 2001 From: "Patrick J. Roddy" Date: Tue, 26 Aug 2025 15:45:48 +0100 Subject: [PATCH 10/21] Fill value test --- tests/test_funcs.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/test_funcs.py b/tests/test_funcs.py index 2554b31c..70509efc 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -988,6 +988,13 @@ def test_complex(self, xp: ModuleType, infinity: float) -> None: xp.asarray([infinity + 0j, 0 + 0j, 0 + 1j * infinity]), ) + def test_fill_value(self, xp: ModuleType) -> None: + a = xp.asarray([1, 2, np.nan, 4]) + xp_assert_equal( + nan_to_num(a, fill_value=3, xp=xp), + xp.asarray([1.0, 2.0, 3.0, 4.0]), + ) + class TestNUnique: def test_simple(self, xp: ModuleType): From ce501f5be71f418be03f0e5e0e1678c07ea927b8 Mon Sep 17 00:00:00 2001 From: "Patrick J. Roddy" Date: Tue, 26 Aug 2025 15:58:59 +0100 Subject: [PATCH 11/21] Add empty array test --- tests/test_funcs.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/test_funcs.py b/tests/test_funcs.py index 70509efc..7a860f43 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -995,6 +995,11 @@ def test_fill_value(self, xp: ModuleType) -> None: xp.asarray([1.0, 2.0, 3.0, 4.0]), ) + def test_empty_array(self, xp: ModuleType) -> None: + a = xp.asarray([], dtype=xp.float32) + xp_assert_equal(nan_to_num(a, xp=xp), a) + assert xp.isdtype(nan_to_num(a, xp=xp).dtype, xp.float32) + class TestNUnique: def test_simple(self, xp: ModuleType): From 5c14ce65e1a750175e593112ca8cd3ab7d12f6fd Mon Sep 17 00:00:00 2001 From: "Patrick J. Roddy" Date: Tue, 26 Aug 2025 17:06:31 +0100 Subject: [PATCH 12/21] Add testing comment --- tests/test_funcs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_funcs.py b/tests/test_funcs.py index 7a860f43..dcbfbd3a 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -996,7 +996,7 @@ def test_fill_value(self, xp: ModuleType) -> None: ) def test_empty_array(self, xp: ModuleType) -> None: - a = xp.asarray([], dtype=xp.float32) + a = xp.asarray([], dtype=xp.float32) # forced dtype due to torch xp_assert_equal(nan_to_num(a, xp=xp), a) assert xp.isdtype(nan_to_num(a, xp=xp).dtype, xp.float32) From 6c1cf076253c6278e1fdd560f4091f6f3823c61d Mon Sep 17 00:00:00 2001 From: "Patrick J. Roddy" Date: Tue, 26 Aug 2025 17:29:56 +0100 Subject: [PATCH 13/21] Handle case of complex fill val --- src/array_api_extra/_lib/_funcs.py | 6 +++--- tests/test_funcs.py | 20 +++++++++++++++++--- 2 files changed, 20 insertions(+), 6 deletions(-) diff --git a/src/array_api_extra/_lib/_funcs.py b/src/array_api_extra/_lib/_funcs.py index dc5b1aed..6af54e8a 100644 --- a/src/array_api_extra/_lib/_funcs.py +++ b/src/array_api_extra/_lib/_funcs.py @@ -762,14 +762,14 @@ def perform_replacements( # numpydoc ignore=PR01,RT01 x = xp.where(idx_posinf, finfo.max, x) return xp.where(idx_neginf, finfo.min, x) - if xp.isdtype(x.dtype, "complex floating"): + if isinstance(fill_value, complex) or xp.isdtype(x.dtype, "complex floating"): return perform_replacements( xp.real(x), - fill_value, + fill_value.real, xp, ) + 1j * perform_replacements( xp.imag(x), - fill_value, + fill_value.imag, xp, ) diff --git a/tests/test_funcs.py b/tests/test_funcs.py index dcbfbd3a..bba217ef 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -988,11 +988,25 @@ def test_complex(self, xp: ModuleType, infinity: float) -> None: xp.asarray([infinity + 0j, 0 + 0j, 0 + 1j * infinity]), ) - def test_fill_value(self, xp: ModuleType) -> None: + @pytest.mark.parametrize("fill_value", [3, 3.0, 3 + 0j]) + @pytest.mark.parametrize( + "output", + [ + [1.0, 2.0, 3.0, 4.0], + [1.0, 2.0, 3.0, 4.0], + [1.0 + 0.j, 2.0 + 0.j, 3.0 + 0.j, 4.0 + 0.j] + ], + ) + def test_fill_value( + self, + xp: ModuleType, + fill_value: float, + output: float, + ) -> None: a = xp.asarray([1, 2, np.nan, 4]) xp_assert_equal( - nan_to_num(a, fill_value=3, xp=xp), - xp.asarray([1.0, 2.0, 3.0, 4.0]), + nan_to_num(a, fill_value=fill_value, xp=xp), + xp.asarray(output), ) def test_empty_array(self, xp: ModuleType) -> None: From 213c9541ee15562f1cc4272a4fcd1e76bea49a1c Mon Sep 17 00:00:00 2001 From: "Patrick J. Roddy" Date: Wed, 27 Aug 2025 12:05:36 +0100 Subject: [PATCH 14/21] Be more careful about complex fill values --- src/array_api_extra/_delegation.py | 8 ++++++-- src/array_api_extra/_lib/_funcs.py | 10 +++++----- tests/test_funcs.py | 30 ++++++++++++++++++++++-------- 3 files changed, 33 insertions(+), 15 deletions(-) diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py index 0be12ab4..2316c7e1 100644 --- a/src/array_api_extra/_delegation.py +++ b/src/array_api_extra/_delegation.py @@ -117,7 +117,7 @@ def nan_to_num( x: Array | float | complex, /, *, - fill_value: int | float | complex = 0.0, + fill_value: int | float = 0.0, xp: ModuleType | None = None, ) -> Array: """ @@ -135,7 +135,7 @@ def nan_to_num( ---------- x : array | float | complex Input data. - fill_value : int | float | complex, optional + fill_value : int | float, optional Value to be used to fill NaN values. If no value is passed then NaN values will be replaced with 0.0. xp : array_namespace, optional @@ -172,6 +172,10 @@ def nan_to_num( 0.00000000e+000 +0.00000000e+000j, 0.00000000e+000 +1.79769313e+308j]) """ + if isinstance(fill_value, complex): + msg = "Cannot cast scalar from complex dtype to float dtype." + raise TypeError(msg) + xp = array_namespace(x) if xp is None else xp # for scalars we want to output an array diff --git a/src/array_api_extra/_lib/_funcs.py b/src/array_api_extra/_lib/_funcs.py index 6af54e8a..cbcbe0ff 100644 --- a/src/array_api_extra/_lib/_funcs.py +++ b/src/array_api_extra/_lib/_funcs.py @@ -741,7 +741,7 @@ def kron( def nan_to_num( # numpydoc ignore=PR01,RT01 x: Array, /, - fill_value: int | float | complex = 0.0, + fill_value: int | float = 0.0, *, xp: ModuleType, ) -> Array: @@ -749,7 +749,7 @@ def nan_to_num( # numpydoc ignore=PR01,RT01 def perform_replacements( # numpydoc ignore=PR01,RT01 x: Array, - fill_value: int | float | complex, + fill_value: int | float, xp: ModuleType, ) -> Array: """Internal function to perform the replacements.""" @@ -762,14 +762,14 @@ def perform_replacements( # numpydoc ignore=PR01,RT01 x = xp.where(idx_posinf, finfo.max, x) return xp.where(idx_neginf, finfo.min, x) - if isinstance(fill_value, complex) or xp.isdtype(x.dtype, "complex floating"): + if xp.isdtype(x.dtype, "complex floating"): return perform_replacements( xp.real(x), - fill_value.real, + fill_value, xp, ) + 1j * perform_replacements( xp.imag(x), - fill_value.imag, + fill_value, xp, ) diff --git a/tests/test_funcs.py b/tests/test_funcs.py index bba217ef..6d503542 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -988,25 +988,39 @@ def test_complex(self, xp: ModuleType, infinity: float) -> None: xp.asarray([infinity + 0j, 0 + 0j, 0 + 1j * infinity]), ) - @pytest.mark.parametrize("fill_value", [3, 3.0, 3 + 0j]) @pytest.mark.parametrize( - "output", + "in_vals,fill_value,out_vals", [ - [1.0, 2.0, 3.0, 4.0], - [1.0, 2.0, 3.0, 4.0], - [1.0 + 0.j, 2.0 + 0.j, 3.0 + 0.j, 4.0 + 0.j] + ([1, 2, np.nan, 4], 3, [1.0, 2.0, 3.0, 4.0]), + ([1, 2, np.nan, 4], 3.0, [1.0, 2.0, 3.0, 4.0]), + ( + [1 + 1j, 2 + 2j, np.nan + 0j, 4 + 4j], + 3, + [1.0 + 1.0j, 2.0 + 2.0j, 3.0 + 0.0j, 4.0 + 4.0j], + ), + ( + [1 + 1j, 2 + 2j, 0 + 1j * np.nan, 4 + 4j], + 3.0, + [1.0 + 1.0j, 2.0 + 2.0j, 3.0 + 3.0j, 4.0 + 4.0j], + ), + ( + [1 + 1j, 2 + 2j, np.nan + 1j * np.nan, 4 + 4j], + 3.0, + [1.0 + 1.0j, 2.0 + 2.0j, 3.0 + 3.0j, 4.0 + 4.0j], + ), ], ) def test_fill_value( self, xp: ModuleType, + in_vals: Array, fill_value: float, - output: float, + out_vals: float, ) -> None: - a = xp.asarray([1, 2, np.nan, 4]) + a = xp.asarray(in_vals) xp_assert_equal( nan_to_num(a, fill_value=fill_value, xp=xp), - xp.asarray(output), + xp.asarray(out_vals), ) def test_empty_array(self, xp: ModuleType) -> None: From 51b28dbca75c42b1cd5305cebcab7e2bfbd9fc6f Mon Sep 17 00:00:00 2001 From: "Patrick J. Roddy" Date: Wed, 27 Aug 2025 12:09:53 +0100 Subject: [PATCH 15/21] Use complex function instead --- tests/test_funcs.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_funcs.py b/tests/test_funcs.py index 6d503542..5d74c241 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -994,17 +994,17 @@ def test_complex(self, xp: ModuleType, infinity: float) -> None: ([1, 2, np.nan, 4], 3, [1.0, 2.0, 3.0, 4.0]), ([1, 2, np.nan, 4], 3.0, [1.0, 2.0, 3.0, 4.0]), ( - [1 + 1j, 2 + 2j, np.nan + 0j, 4 + 4j], + [1 + 1j, 2 + 2j, complex(np.nan, 0), 4 + 4j], 3, [1.0 + 1.0j, 2.0 + 2.0j, 3.0 + 0.0j, 4.0 + 4.0j], ), ( - [1 + 1j, 2 + 2j, 0 + 1j * np.nan, 4 + 4j], + [1 + 1j, 2 + 2j, complex(0, np.nan), 4 + 4j], 3.0, - [1.0 + 1.0j, 2.0 + 2.0j, 3.0 + 3.0j, 4.0 + 4.0j], + [1.0 + 1.0j, 2.0 + 2.0j, 0.0 + 3.0j, 4.0 + 4.0j], ), ( - [1 + 1j, 2 + 2j, np.nan + 1j * np.nan, 4 + 4j], + [1 + 1j, 2 + 2j, complex(np.nan, np.nan), 4 + 4j], 3.0, [1.0 + 1.0j, 2.0 + 2.0j, 3.0 + 3.0j, 4.0 + 4.0j], ), From 1286ae712366aab944c4b2d794e2c3c028a23409 Mon Sep 17 00:00:00 2001 From: "Patrick J. Roddy" Date: Wed, 27 Aug 2025 12:14:23 +0100 Subject: [PATCH 16/21] Use the complex function --- tests/test_funcs.py | 46 +++++++++++++++++++++++++++++++++++++-------- 1 file changed, 38 insertions(+), 8 deletions(-) diff --git a/tests/test_funcs.py b/tests/test_funcs.py index 5d74c241..21fe68d4 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -985,7 +985,7 @@ def test_complex(self, xp: ModuleType, infinity: float) -> None: ) xp_assert_equal( nan_to_num(a), - xp.asarray([infinity + 0j, 0 + 0j, 0 + 1j * infinity]), + xp.asarray([complex(infinity, 0), complex(0, 0), complex(0, infinity)]), ) @pytest.mark.parametrize( @@ -994,23 +994,53 @@ def test_complex(self, xp: ModuleType, infinity: float) -> None: ([1, 2, np.nan, 4], 3, [1.0, 2.0, 3.0, 4.0]), ([1, 2, np.nan, 4], 3.0, [1.0, 2.0, 3.0, 4.0]), ( - [1 + 1j, 2 + 2j, complex(np.nan, 0), 4 + 4j], + [ + complex(1, 1), + complex(2, 2), + complex(np.nan, 0), + complex(4, 4), + ], 3, - [1.0 + 1.0j, 2.0 + 2.0j, 3.0 + 0.0j, 4.0 + 4.0j], + [ + complex(1.0, 1.0), + complex(2.0, 2.0), + complex(3.0, 0.0), + complex(4.0, 4.0), + ], ), ( - [1 + 1j, 2 + 2j, complex(0, np.nan), 4 + 4j], + [ + complex(1, 1), + complex(2, 2), + complex(0, np.nan), + complex(4, 4), + ], 3.0, - [1.0 + 1.0j, 2.0 + 2.0j, 0.0 + 3.0j, 4.0 + 4.0j], + [ + complex(1.0, 1.0), + complex(2.0, 2.0), + complex(0.0, 3.0), + complex(4.0, 4.0), + ], ), ( - [1 + 1j, 2 + 2j, complex(np.nan, np.nan), 4 + 4j], + [ + complex(1, 1), + complex(2, 2), + complex(np.nan, np.nan), + complex(4, 4), + ], 3.0, - [1.0 + 1.0j, 2.0 + 2.0j, 3.0 + 3.0j, 4.0 + 4.0j], + [ + complex(1.0, 1.0), + complex(2.0, 2.0), + complex(3.0, 3.0), + complex(4.0, 4.0), + ], ), ], ) - def test_fill_value( + def test_fill_value_success( self, xp: ModuleType, in_vals: Array, From 0e2ae837ea6d379ab266fc0fd498847914e4619b Mon Sep 17 00:00:00 2001 From: "Patrick J. Roddy" Date: Wed, 27 Aug 2025 12:22:55 +0100 Subject: [PATCH 17/21] Add known failure test --- src/array_api_extra/_delegation.py | 2 +- tests/test_funcs.py | 26 ++++++++++++++++++++------ 2 files changed, 21 insertions(+), 7 deletions(-) diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py index 2316c7e1..88b517d0 100644 --- a/src/array_api_extra/_delegation.py +++ b/src/array_api_extra/_delegation.py @@ -173,7 +173,7 @@ def nan_to_num( 0.00000000e+000 +1.79769313e+308j]) """ if isinstance(fill_value, complex): - msg = "Cannot cast scalar from complex dtype to float dtype." + msg = "Complex fill values are not supported." raise TypeError(msg) xp = array_namespace(x) if xp is None else xp diff --git a/tests/test_funcs.py b/tests/test_funcs.py index 21fe68d4..acf32a22 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -988,6 +988,11 @@ def test_complex(self, xp: ModuleType, infinity: float) -> None: xp.asarray([complex(infinity, 0), complex(0, 0), complex(0, infinity)]), ) + def test_empty_array(self, xp: ModuleType) -> None: + a = xp.asarray([], dtype=xp.float32) # forced dtype due to torch + xp_assert_equal(nan_to_num(a, xp=xp), a) + assert xp.isdtype(nan_to_num(a, xp=xp).dtype, xp.float32) + @pytest.mark.parametrize( "in_vals,fill_value,out_vals", [ @@ -1044,8 +1049,8 @@ def test_fill_value_success( self, xp: ModuleType, in_vals: Array, - fill_value: float, - out_vals: float, + fill_value: int | float, + out_vals: Array, ) -> None: a = xp.asarray(in_vals) xp_assert_equal( @@ -1053,10 +1058,19 @@ def test_fill_value_success( xp.asarray(out_vals), ) - def test_empty_array(self, xp: ModuleType) -> None: - a = xp.asarray([], dtype=xp.float32) # forced dtype due to torch - xp_assert_equal(nan_to_num(a, xp=xp), a) - assert xp.isdtype(nan_to_num(a, xp=xp).dtype, xp.float32) + def test_fill_value_failure(self, xp: ModuleType) -> None: + a = xp.asarray( + [ + complex(1, 1), + complex(xp.nan, xp.nan), + complex(3, 3), + ] + ) + with pytest.raises( + TypeError, + match="Complex fill values are not supported", + ): + nan_to_num(a, fill_value=complex(2, 2), xp=xp) class TestNUnique: From cb76ad9ba8e4bfa24804b55fd7a7c719a1037eb0 Mon Sep 17 00:00:00 2001 From: "Patrick J. Roddy" Date: Wed, 27 Aug 2025 12:29:25 +0100 Subject: [PATCH 18/21] Fix linting --- tests/test_funcs.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/test_funcs.py b/tests/test_funcs.py index acf32a22..8304b848 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -994,7 +994,7 @@ def test_empty_array(self, xp: ModuleType) -> None: assert xp.isdtype(nan_to_num(a, xp=xp).dtype, xp.float32) @pytest.mark.parametrize( - "in_vals,fill_value,out_vals", + ("in_vals", "fill_value", "out_vals"), [ ([1, 2, np.nan, 4], 3, [1.0, 2.0, 3.0, 4.0]), ([1, 2, np.nan, 4], 3.0, [1.0, 2.0, 3.0, 4.0]), @@ -1070,7 +1070,11 @@ def test_fill_value_failure(self, xp: ModuleType) -> None: TypeError, match="Complex fill values are not supported", ): - nan_to_num(a, fill_value=complex(2, 2), xp=xp) + _ = nan_to_num( + a, + fill_value=complex(2, 2), # type: ignore[arg-type] # pyright: ignore[reportArgumentType] + xp=xp, + ) class TestNUnique: From ad86943360a58c04e23b46c8de0be7b6e635b6c0 Mon Sep 17 00:00:00 2001 From: "Patrick J. Roddy" Date: Wed, 27 Aug 2025 18:05:34 +0100 Subject: [PATCH 19/21] Update src/array_api_extra/_delegation.py Co-authored-by: Lucas Colley --- src/array_api_extra/_delegation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py index 88b517d0..dd73adef 100644 --- a/src/array_api_extra/_delegation.py +++ b/src/array_api_extra/_delegation.py @@ -123,7 +123,7 @@ def nan_to_num( """ Replace NaN with zero and infinity with large finite numbers (default behaviour). - If `x` is inexact, NaN is replaced by zero or by the user defined value in + If `x` is inexact, NaN is replaced by zero or by the user defined value in the `fill_value` keyword, infinity is replaced by the largest finite floating point values representable by ``x.dtype`` and -infinity is replaced by the most negative finite floating point values representable by ``x.dtype``. From e449a6b5e0a2d3f656cc05938c08c9fbb0f01bcb Mon Sep 17 00:00:00 2001 From: "Patrick J. Roddy" Date: Wed, 27 Aug 2025 18:05:43 +0100 Subject: [PATCH 20/21] Update src/array_api_extra/_delegation.py Co-authored-by: Lucas Colley --- src/array_api_extra/_delegation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py index dd73adef..6a75913a 100644 --- a/src/array_api_extra/_delegation.py +++ b/src/array_api_extra/_delegation.py @@ -125,7 +125,7 @@ def nan_to_num( If `x` is inexact, NaN is replaced by zero or by the user defined value in the `fill_value` keyword, infinity is replaced by the largest finite floating - point values representable by ``x.dtype`` and -infinity is replaced by the + point value representable by ``x.dtype``, and -infinity is replaced by the most negative finite floating point values representable by ``x.dtype``. For complex dtypes, the above is applied to each of the real and From d51fa05b4f9c833bcc323b964df953b0b7669398 Mon Sep 17 00:00:00 2001 From: "Patrick J. Roddy" Date: Wed, 27 Aug 2025 18:05:51 +0100 Subject: [PATCH 21/21] Update src/array_api_extra/_delegation.py Co-authored-by: Lucas Colley --- src/array_api_extra/_delegation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py index 6a75913a..2c061e36 100644 --- a/src/array_api_extra/_delegation.py +++ b/src/array_api_extra/_delegation.py @@ -126,7 +126,7 @@ def nan_to_num( If `x` is inexact, NaN is replaced by zero or by the user defined value in the `fill_value` keyword, infinity is replaced by the largest finite floating point value representable by ``x.dtype``, and -infinity is replaced by the - most negative finite floating point values representable by ``x.dtype``. + most negative finite floating point value representable by ``x.dtype``. For complex dtypes, the above is applied to each of the real and imaginary components of `x` separately.