Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions src/array_api_extra/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
"""Extra array functions built on top of the array API standard."""

from ._delegation import expand_dims, isclose, nan_to_num, one_hot, pad, sinc
from ._delegation import (
atleast_nd,
expand_dims,
isclose,
nan_to_num,
one_hot,
pad,
sinc,
)
from ._lib._at import at
from ._lib._funcs import (
apply_where,
atleast_nd,
broadcast_shapes,
cov,
create_diagonal,
Expand Down
49 changes: 49 additions & 0 deletions src/array_api_extra/_delegation.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,55 @@ def expand_dims(
return _funcs.expand_dims(a, axis=axis, xp=xp)


def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType | None = None) -> Array:
"""
Recursively expand the dimension of an array to at least `ndim`.

Parameters
----------
x : array
Input array.
ndim : int
The minimum number of dimensions for the result.
xp : array_namespace, optional
The standard-compatible namespace for `x`. Default: infer.

Returns
-------
array
An array with ``res.ndim`` >= `ndim`.
If ``x.ndim`` >= `ndim`, `x` is returned.
If ``x.ndim`` < `ndim`, `x` is expanded by prepending new axes
until ``res.ndim`` equals `ndim`.

Examples
--------
>>> import array_api_strict as xp
>>> import array_api_extra as xpx
>>> x = xp.asarray([1])
>>> xpx.atleast_nd(x, ndim=3, xp=xp)
Array([[[1]]], dtype=array_api_strict.int64)

>>> x = xp.asarray([[[1, 2],
... [3, 4]]])
>>> xpx.atleast_nd(x, ndim=1, xp=xp) is x
True
"""
if xp is None:
xp = array_namespace(x)

if 1 <= ndim <= 3 and (
is_numpy_namespace(xp)
or is_jax_namespace(xp)
or is_dask_namespace(xp)
or is_cupy_namespace(xp)
or is_torch_namespace(xp)
):
return getattr(xp, f"atleast_{ndim}d")(x)

return _funcs.atleast_nd(x, ndim=ndim, xp=xp)


def isclose(
a: Array | complex,
b: Array | complex,
Expand Down
39 changes: 3 additions & 36 deletions src/array_api_extra/_lib/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,42 +175,9 @@ def _apply_where( # numpydoc ignore=PR01,RT01
return at(out, cond).set(temp1)


def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType | None = None) -> Array:
"""
Recursively expand the dimension of an array to at least `ndim`.
Parameters
----------
x : array
Input array.
ndim : int
The minimum number of dimensions for the result.
xp : array_namespace, optional
The standard-compatible namespace for `x`. Default: infer.
Returns
-------
array
An array with ``res.ndim`` >= `ndim`.
If ``x.ndim`` >= `ndim`, `x` is returned.
If ``x.ndim`` < `ndim`, `x` is expanded by prepending new axes
until ``res.ndim`` equals `ndim`.
Examples
--------
>>> import array_api_strict as xp
>>> import array_api_extra as xpx
>>> x = xp.asarray([1])
>>> xpx.atleast_nd(x, ndim=3, xp=xp)
Array([[[1]]], dtype=array_api_strict.int64)
>>> x = xp.asarray([[[1, 2],
... [3, 4]]])
>>> xpx.atleast_nd(x, ndim=1, xp=xp) is x
True
"""
if xp is None:
xp = array_namespace(x)
def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType) -> Array:
# numpydoc ignore=PR01,RT01
"""See docstring in array_api_extra._delegation."""

if x.ndim < ndim:
x = xp.expand_dims(x, axis=0)
Expand Down
15 changes: 15 additions & 0 deletions tests/test_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,21 @@ def test_2D(self, xp: ModuleType):
y = atleast_nd(x, ndim=5)
xp_assert_equal(y, 3 * xp.ones((1, 1, 1, 1, 1)))

def test_3D(self, xp: ModuleType):
x = xp.asarray([[[3.0], [2.0]]])

y = atleast_nd(x, ndim=0)
xp_assert_equal(y, x)

y = atleast_nd(x, ndim=2)
xp_assert_equal(y, x)

y = atleast_nd(x, ndim=3)
xp_assert_equal(y, x)

y = atleast_nd(x, ndim=5)
xp_assert_equal(y, xp.asarray([[[[[3.0], [2.0]]]]]))

def test_5D(self, xp: ModuleType):
x = xp.ones((1, 1, 1, 1, 1))

Expand Down