From 2187ece611bce349093c6bc5443bf938b577c359 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Gauthier-Clerc?= Date: Thu, 2 Oct 2025 16:14:31 +0200 Subject: [PATCH] ENH: atleast_nd delegation --- src/array_api_extra/__init__.py | 11 +++++-- src/array_api_extra/_delegation.py | 49 ++++++++++++++++++++++++++++++ src/array_api_extra/_lib/_funcs.py | 39 ++---------------------- tests/test_funcs.py | 15 +++++++++ 4 files changed, 76 insertions(+), 38 deletions(-) diff --git a/src/array_api_extra/__init__.py b/src/array_api_extra/__init__.py index 54fd4ba2..dd37953e 100644 --- a/src/array_api_extra/__init__.py +++ b/src/array_api_extra/__init__.py @@ -1,10 +1,17 @@ """Extra array functions built on top of the array API standard.""" -from ._delegation import expand_dims, isclose, nan_to_num, one_hot, pad, sinc +from ._delegation import ( + atleast_nd, + expand_dims, + isclose, + nan_to_num, + one_hot, + pad, + sinc, +) from ._lib._at import at from ._lib._funcs import ( apply_where, - atleast_nd, broadcast_shapes, cov, create_diagonal, diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py index 26d8d0cd..d9e45838 100644 --- a/src/array_api_extra/_delegation.py +++ b/src/array_api_extra/_delegation.py @@ -109,6 +109,55 @@ def expand_dims( return _funcs.expand_dims(a, axis=axis, xp=xp) +def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType | None = None) -> Array: + """ + Recursively expand the dimension of an array to at least `ndim`. + + Parameters + ---------- + x : array + Input array. + ndim : int + The minimum number of dimensions for the result. + xp : array_namespace, optional + The standard-compatible namespace for `x`. Default: infer. + + Returns + ------- + array + An array with ``res.ndim`` >= `ndim`. + If ``x.ndim`` >= `ndim`, `x` is returned. + If ``x.ndim`` < `ndim`, `x` is expanded by prepending new axes + until ``res.ndim`` equals `ndim`. + + Examples + -------- + >>> import array_api_strict as xp + >>> import array_api_extra as xpx + >>> x = xp.asarray([1]) + >>> xpx.atleast_nd(x, ndim=3, xp=xp) + Array([[[1]]], dtype=array_api_strict.int64) + + >>> x = xp.asarray([[[1, 2], + ... [3, 4]]]) + >>> xpx.atleast_nd(x, ndim=1, xp=xp) is x + True + """ + if xp is None: + xp = array_namespace(x) + + if 1 <= ndim <= 3 and ( + is_numpy_namespace(xp) + or is_jax_namespace(xp) + or is_dask_namespace(xp) + or is_cupy_namespace(xp) + or is_torch_namespace(xp) + ): + return getattr(xp, f"atleast_{ndim}d")(x) + + return _funcs.atleast_nd(x, ndim=ndim, xp=xp) + + def isclose( a: Array | complex, b: Array | complex, diff --git a/src/array_api_extra/_lib/_funcs.py b/src/array_api_extra/_lib/_funcs.py index b25e6e3e..88703ecc 100644 --- a/src/array_api_extra/_lib/_funcs.py +++ b/src/array_api_extra/_lib/_funcs.py @@ -175,42 +175,9 @@ def _apply_where( # numpydoc ignore=PR01,RT01 return at(out, cond).set(temp1) -def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType | None = None) -> Array: - """ - Recursively expand the dimension of an array to at least `ndim`. - - Parameters - ---------- - x : array - Input array. - ndim : int - The minimum number of dimensions for the result. - xp : array_namespace, optional - The standard-compatible namespace for `x`. Default: infer. - - Returns - ------- - array - An array with ``res.ndim`` >= `ndim`. - If ``x.ndim`` >= `ndim`, `x` is returned. - If ``x.ndim`` < `ndim`, `x` is expanded by prepending new axes - until ``res.ndim`` equals `ndim`. - - Examples - -------- - >>> import array_api_strict as xp - >>> import array_api_extra as xpx - >>> x = xp.asarray([1]) - >>> xpx.atleast_nd(x, ndim=3, xp=xp) - Array([[[1]]], dtype=array_api_strict.int64) - - >>> x = xp.asarray([[[1, 2], - ... [3, 4]]]) - >>> xpx.atleast_nd(x, ndim=1, xp=xp) is x - True - """ - if xp is None: - xp = array_namespace(x) +def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType) -> Array: + # numpydoc ignore=PR01,RT01 + """See docstring in array_api_extra._delegation.""" if x.ndim < ndim: x = xp.expand_dims(x, axis=0) diff --git a/tests/test_funcs.py b/tests/test_funcs.py index 90813ecb..7da01591 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -316,6 +316,21 @@ def test_2D(self, xp: ModuleType): y = atleast_nd(x, ndim=5) xp_assert_equal(y, 3 * xp.ones((1, 1, 1, 1, 1))) + def test_3D(self, xp: ModuleType): + x = xp.asarray([[[3.0], [2.0]]]) + + y = atleast_nd(x, ndim=0) + xp_assert_equal(y, x) + + y = atleast_nd(x, ndim=2) + xp_assert_equal(y, x) + + y = atleast_nd(x, ndim=3) + xp_assert_equal(y, x) + + y = atleast_nd(x, ndim=5) + xp_assert_equal(y, xp.asarray([[[[[3.0], [2.0]]]]])) + def test_5D(self, xp: ModuleType): x = xp.ones((1, 1, 1, 1, 1))