Skip to content

Commit c9fd5ba

Browse files
committed
ENH: expand_dims
1 parent ca20f03 commit c9fd5ba

File tree

3 files changed

+93
-84
lines changed

3 files changed

+93
-84
lines changed

src/array_api_extra/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Extra array functions built on top of the array API standard."""
22

3-
from ._delegation import isclose, nan_to_num, one_hot, pad
3+
from ._delegation import expand_dims, isclose, nan_to_num, one_hot, pad
44
from ._lib._at import at
55
from ._lib._funcs import (
66
apply_where,
@@ -9,7 +9,6 @@
99
cov,
1010
create_diagonal,
1111
default_dtype,
12-
expand_dims,
1312
kron,
1413
nunique,
1514
setdiff1d,

src/array_api_extra/_delegation.py

Lines changed: 89 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,95 @@
1818
from ._lib._utils._helpers import asarrays
1919
from ._lib._utils._typing import Array, DType
2020

21-
__all__ = ["isclose", "nan_to_num", "one_hot", "pad"]
21+
__all__ = ["expand_dims", "isclose", "nan_to_num", "one_hot", "pad"]
22+
23+
24+
def expand_dims(
25+
a: Array, /, *, axis: int | tuple[int, ...] = (0,), xp: ModuleType | None = None
26+
) -> Array:
27+
"""
28+
Expand the shape of an array.
29+
30+
Insert (a) new axis/axes that will appear at the position(s) specified by
31+
`axis` in the expanded array shape.
32+
33+
This is ``xp.expand_dims`` for `axis` an int *or a tuple of ints*.
34+
Roughly equivalent to ``numpy.expand_dims`` for NumPy arrays.
35+
36+
Parameters
37+
----------
38+
a : array
39+
Array to have its shape expanded.
40+
axis : int or tuple of ints, optional
41+
Position(s) in the expanded axes where the new axis (or axes) is/are placed.
42+
If multiple positions are provided, they should be unique (note that a position
43+
given by a positive index could also be referred to by a negative index -
44+
that will also result in an error).
45+
Default: ``(0,)``.
46+
xp : array_namespace, optional
47+
The standard-compatible namespace for `a`. Default: infer.
48+
49+
Returns
50+
-------
51+
array
52+
`a` with an expanded shape.
53+
54+
Examples
55+
--------
56+
>>> import array_api_strict as xp
57+
>>> import array_api_extra as xpx
58+
>>> x = xp.asarray([1, 2])
59+
>>> x.shape
60+
(2,)
61+
62+
The following is equivalent to ``x[xp.newaxis, :]`` or ``x[xp.newaxis]``:
63+
64+
>>> y = xpx.expand_dims(x, axis=0, xp=xp)
65+
>>> y
66+
Array([[1, 2]], dtype=array_api_strict.int64)
67+
>>> y.shape
68+
(1, 2)
69+
70+
The following is equivalent to ``x[:, xp.newaxis]``:
71+
72+
>>> y = xpx.expand_dims(x, axis=1, xp=xp)
73+
>>> y
74+
Array([[1],
75+
[2]], dtype=array_api_strict.int64)
76+
>>> y.shape
77+
(2, 1)
78+
79+
``axis`` may also be a tuple:
80+
81+
>>> y = xpx.expand_dims(x, axis=(0, 1), xp=xp)
82+
>>> y
83+
Array([[[1, 2]]], dtype=array_api_strict.int64)
84+
85+
>>> y = xpx.expand_dims(x, axis=(2, 0), xp=xp)
86+
>>> y
87+
Array([[[1],
88+
[2]]], dtype=array_api_strict.int64)
89+
"""
90+
if xp is None:
91+
xp = array_namespace(a)
92+
93+
if not isinstance(axis, tuple):
94+
axis = (axis,)
95+
ndim = a.ndim + len(axis)
96+
if axis != () and (min(axis) < -ndim or max(axis) >= ndim):
97+
err_msg = (
98+
f"a provided axis position is out of bounds for array of dimension {a.ndim}"
99+
)
100+
raise IndexError(err_msg)
101+
axis = tuple(dim % ndim for dim in axis)
102+
if len(set(axis)) != len(axis):
103+
err_msg = "Duplicate dimensions specified in `axis`."
104+
raise ValueError(err_msg)
105+
106+
if is_numpy_namespace(xp) or is_dask_namespace(xp) or is_jax_namespace(xp):
107+
return xp.expand_dims(a, axis=axis)
108+
109+
return _funcs.expand_dims(a, axis=axis, xp=xp)
22110

23111

24112
def isclose(

src/array_api_extra/_lib/_funcs.py

Lines changed: 3 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -493,87 +493,9 @@ def default_dtype(
493493
raise ValueError(msg) from e
494494

495495

496-
def expand_dims(
497-
a: Array, /, *, axis: int | tuple[int, ...] = (0,), xp: ModuleType | None = None
498-
) -> Array:
499-
"""
500-
Expand the shape of an array.
501-
502-
Insert (a) new axis/axes that will appear at the position(s) specified by
503-
`axis` in the expanded array shape.
504-
505-
This is ``xp.expand_dims`` for `axis` an int *or a tuple of ints*.
506-
Roughly equivalent to ``numpy.expand_dims`` for NumPy arrays.
507-
508-
Parameters
509-
----------
510-
a : array
511-
Array to have its shape expanded.
512-
axis : int or tuple of ints, optional
513-
Position(s) in the expanded axes where the new axis (or axes) is/are placed.
514-
If multiple positions are provided, they should be unique (note that a position
515-
given by a positive index could also be referred to by a negative index -
516-
that will also result in an error).
517-
Default: ``(0,)``.
518-
xp : array_namespace, optional
519-
The standard-compatible namespace for `a`. Default: infer.
520-
521-
Returns
522-
-------
523-
array
524-
`a` with an expanded shape.
525-
526-
Examples
527-
--------
528-
>>> import array_api_strict as xp
529-
>>> import array_api_extra as xpx
530-
>>> x = xp.asarray([1, 2])
531-
>>> x.shape
532-
(2,)
533-
534-
The following is equivalent to ``x[xp.newaxis, :]`` or ``x[xp.newaxis]``:
535-
536-
>>> y = xpx.expand_dims(x, axis=0, xp=xp)
537-
>>> y
538-
Array([[1, 2]], dtype=array_api_strict.int64)
539-
>>> y.shape
540-
(1, 2)
541-
542-
The following is equivalent to ``x[:, xp.newaxis]``:
543-
544-
>>> y = xpx.expand_dims(x, axis=1, xp=xp)
545-
>>> y
546-
Array([[1],
547-
[2]], dtype=array_api_strict.int64)
548-
>>> y.shape
549-
(2, 1)
550-
551-
``axis`` may also be a tuple:
552-
553-
>>> y = xpx.expand_dims(x, axis=(0, 1), xp=xp)
554-
>>> y
555-
Array([[[1, 2]]], dtype=array_api_strict.int64)
556-
557-
>>> y = xpx.expand_dims(x, axis=(2, 0), xp=xp)
558-
>>> y
559-
Array([[[1],
560-
[2]]], dtype=array_api_strict.int64)
561-
"""
562-
if xp is None:
563-
xp = array_namespace(a)
564-
565-
if not isinstance(axis, tuple):
566-
axis = (axis,)
567-
ndim = a.ndim + len(axis)
568-
if axis != () and (min(axis) < -ndim or max(axis) >= ndim):
569-
err_msg = (
570-
f"a provided axis position is out of bounds for array of dimension {a.ndim}"
571-
)
572-
raise IndexError(err_msg)
573-
axis = tuple(dim % ndim for dim in axis)
574-
if len(set(axis)) != len(axis):
575-
err_msg = "Duplicate dimensions specified in `axis`."
576-
raise ValueError(err_msg)
496+
def expand_dims(a: Array, /, *, axis: tuple[int, ...] = (0,), xp: ModuleType) -> Array:
497+
# numpydoc ignore=PR01,RT01
498+
"""See docstring in array_api_extra._delegation."""
577499
for i in sorted(axis):
578500
a = xp.expand_dims(a, axis=i)
579501
return a

0 commit comments

Comments
 (0)