Skip to content

Commit 708bce7

Browse files
authored
ENH: atleast_nd delegation (#454)
1 parent 355887e commit 708bce7

File tree

4 files changed

+76
-38
lines changed

4 files changed

+76
-38
lines changed

src/array_api_extra/__init__.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,17 @@
11
"""Extra array functions built on top of the array API standard."""
22

3-
from ._delegation import expand_dims, isclose, nan_to_num, one_hot, pad, sinc
3+
from ._delegation import (
4+
atleast_nd,
5+
expand_dims,
6+
isclose,
7+
nan_to_num,
8+
one_hot,
9+
pad,
10+
sinc,
11+
)
412
from ._lib._at import at
513
from ._lib._funcs import (
614
apply_where,
7-
atleast_nd,
815
broadcast_shapes,
916
cov,
1017
create_diagonal,

src/array_api_extra/_delegation.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,55 @@ def expand_dims(
109109
return _funcs.expand_dims(a, axis=axis, xp=xp)
110110

111111

112+
def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType | None = None) -> Array:
113+
"""
114+
Recursively expand the dimension of an array to at least `ndim`.
115+
116+
Parameters
117+
----------
118+
x : array
119+
Input array.
120+
ndim : int
121+
The minimum number of dimensions for the result.
122+
xp : array_namespace, optional
123+
The standard-compatible namespace for `x`. Default: infer.
124+
125+
Returns
126+
-------
127+
array
128+
An array with ``res.ndim`` >= `ndim`.
129+
If ``x.ndim`` >= `ndim`, `x` is returned.
130+
If ``x.ndim`` < `ndim`, `x` is expanded by prepending new axes
131+
until ``res.ndim`` equals `ndim`.
132+
133+
Examples
134+
--------
135+
>>> import array_api_strict as xp
136+
>>> import array_api_extra as xpx
137+
>>> x = xp.asarray([1])
138+
>>> xpx.atleast_nd(x, ndim=3, xp=xp)
139+
Array([[[1]]], dtype=array_api_strict.int64)
140+
141+
>>> x = xp.asarray([[[1, 2],
142+
... [3, 4]]])
143+
>>> xpx.atleast_nd(x, ndim=1, xp=xp) is x
144+
True
145+
"""
146+
if xp is None:
147+
xp = array_namespace(x)
148+
149+
if 1 <= ndim <= 3 and (
150+
is_numpy_namespace(xp)
151+
or is_jax_namespace(xp)
152+
or is_dask_namespace(xp)
153+
or is_cupy_namespace(xp)
154+
or is_torch_namespace(xp)
155+
):
156+
return getattr(xp, f"atleast_{ndim}d")(x)
157+
158+
return _funcs.atleast_nd(x, ndim=ndim, xp=xp)
159+
160+
112161
def isclose(
113162
a: Array | complex,
114163
b: Array | complex,

src/array_api_extra/_lib/_funcs.py

Lines changed: 3 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -175,42 +175,9 @@ def _apply_where( # numpydoc ignore=PR01,RT01
175175
return at(out, cond).set(temp1)
176176

177177

178-
def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType | None = None) -> Array:
179-
"""
180-
Recursively expand the dimension of an array to at least `ndim`.
181-
182-
Parameters
183-
----------
184-
x : array
185-
Input array.
186-
ndim : int
187-
The minimum number of dimensions for the result.
188-
xp : array_namespace, optional
189-
The standard-compatible namespace for `x`. Default: infer.
190-
191-
Returns
192-
-------
193-
array
194-
An array with ``res.ndim`` >= `ndim`.
195-
If ``x.ndim`` >= `ndim`, `x` is returned.
196-
If ``x.ndim`` < `ndim`, `x` is expanded by prepending new axes
197-
until ``res.ndim`` equals `ndim`.
198-
199-
Examples
200-
--------
201-
>>> import array_api_strict as xp
202-
>>> import array_api_extra as xpx
203-
>>> x = xp.asarray([1])
204-
>>> xpx.atleast_nd(x, ndim=3, xp=xp)
205-
Array([[[1]]], dtype=array_api_strict.int64)
206-
207-
>>> x = xp.asarray([[[1, 2],
208-
... [3, 4]]])
209-
>>> xpx.atleast_nd(x, ndim=1, xp=xp) is x
210-
True
211-
"""
212-
if xp is None:
213-
xp = array_namespace(x)
178+
def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType) -> Array:
179+
# numpydoc ignore=PR01,RT01
180+
"""See docstring in array_api_extra._delegation."""
214181

215182
if x.ndim < ndim:
216183
x = xp.expand_dims(x, axis=0)

tests/test_funcs.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,21 @@ def test_2D(self, xp: ModuleType):
316316
y = atleast_nd(x, ndim=5)
317317
xp_assert_equal(y, 3 * xp.ones((1, 1, 1, 1, 1)))
318318

319+
def test_3D(self, xp: ModuleType):
320+
x = xp.asarray([[[3.0], [2.0]]])
321+
322+
y = atleast_nd(x, ndim=0)
323+
xp_assert_equal(y, x)
324+
325+
y = atleast_nd(x, ndim=2)
326+
xp_assert_equal(y, x)
327+
328+
y = atleast_nd(x, ndim=3)
329+
xp_assert_equal(y, x)
330+
331+
y = atleast_nd(x, ndim=5)
332+
xp_assert_equal(y, xp.asarray([[[[[3.0], [2.0]]]]]))
333+
319334
def test_5D(self, xp: ModuleType):
320335
x = xp.ones((1, 1, 1, 1, 1))
321336

0 commit comments

Comments
 (0)