diff --git a/src/array_api_extra/__init__.py b/src/array_api_extra/__init__.py index 0b4bd5e2..ee8a593f 100644 --- a/src/array_api_extra/__init__.py +++ b/src/array_api_extra/__init__.py @@ -1,6 +1,6 @@ """Extra array functions built on top of the array API standard.""" -from ._delegation import isclose, nan_to_num, one_hot, pad +from ._delegation import expand_dims, isclose, nan_to_num, one_hot, pad from ._lib._at import at from ._lib._funcs import ( apply_where, @@ -9,7 +9,6 @@ cov, create_diagonal, default_dtype, - expand_dims, kron, nunique, setdiff1d, diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py index 2c061e36..a2d8cf8d 100644 --- a/src/array_api_extra/_delegation.py +++ b/src/array_api_extra/_delegation.py @@ -18,7 +18,95 @@ from ._lib._utils._helpers import asarrays from ._lib._utils._typing import Array, DType -__all__ = ["isclose", "nan_to_num", "one_hot", "pad"] +__all__ = ["expand_dims", "isclose", "nan_to_num", "one_hot", "pad"] + + +def expand_dims( + a: Array, /, *, axis: int | tuple[int, ...] = (0,), xp: ModuleType | None = None +) -> Array: + """ + Expand the shape of an array. + + Insert (a) new axis/axes that will appear at the position(s) specified by + `axis` in the expanded array shape. + + This is ``xp.expand_dims`` for `axis` an int *or a tuple of ints*. + Roughly equivalent to ``numpy.expand_dims`` for NumPy arrays. + + Parameters + ---------- + a : array + Array to have its shape expanded. + axis : int or tuple of ints, optional + Position(s) in the expanded axes where the new axis (or axes) is/are placed. + If multiple positions are provided, they should be unique (note that a position + given by a positive index could also be referred to by a negative index - + that will also result in an error). + Default: ``(0,)``. + xp : array_namespace, optional + The standard-compatible namespace for `a`. Default: infer. + + Returns + ------- + array + `a` with an expanded shape. + + Examples + -------- + >>> import array_api_strict as xp + >>> import array_api_extra as xpx + >>> x = xp.asarray([1, 2]) + >>> x.shape + (2,) + + The following is equivalent to ``x[xp.newaxis, :]`` or ``x[xp.newaxis]``: + + >>> y = xpx.expand_dims(x, axis=0, xp=xp) + >>> y + Array([[1, 2]], dtype=array_api_strict.int64) + >>> y.shape + (1, 2) + + The following is equivalent to ``x[:, xp.newaxis]``: + + >>> y = xpx.expand_dims(x, axis=1, xp=xp) + >>> y + Array([[1], + [2]], dtype=array_api_strict.int64) + >>> y.shape + (2, 1) + + ``axis`` may also be a tuple: + + >>> y = xpx.expand_dims(x, axis=(0, 1), xp=xp) + >>> y + Array([[[1, 2]]], dtype=array_api_strict.int64) + + >>> y = xpx.expand_dims(x, axis=(2, 0), xp=xp) + >>> y + Array([[[1], + [2]]], dtype=array_api_strict.int64) + """ + if xp is None: + xp = array_namespace(a) + + if not isinstance(axis, tuple): + axis = (axis,) + ndim = a.ndim + len(axis) + if axis != () and (min(axis) < -ndim or max(axis) >= ndim): + err_msg = ( + f"a provided axis position is out of bounds for array of dimension {a.ndim}" + ) + raise IndexError(err_msg) + axis = tuple(dim % ndim for dim in axis) + if len(set(axis)) != len(axis): + err_msg = "Duplicate dimensions specified in `axis`." + raise ValueError(err_msg) + + if is_numpy_namespace(xp) or is_dask_namespace(xp) or is_jax_namespace(xp): + return xp.expand_dims(a, axis=axis) + + return _funcs.expand_dims(a, axis=axis, xp=xp) def isclose( diff --git a/src/array_api_extra/_lib/_funcs.py b/src/array_api_extra/_lib/_funcs.py index f61affe5..53b7c7e0 100644 --- a/src/array_api_extra/_lib/_funcs.py +++ b/src/array_api_extra/_lib/_funcs.py @@ -493,87 +493,9 @@ def default_dtype( raise ValueError(msg) from e -def expand_dims( - a: Array, /, *, axis: int | tuple[int, ...] = (0,), xp: ModuleType | None = None -) -> Array: - """ - Expand the shape of an array. - - Insert (a) new axis/axes that will appear at the position(s) specified by - `axis` in the expanded array shape. - - This is ``xp.expand_dims`` for `axis` an int *or a tuple of ints*. - Roughly equivalent to ``numpy.expand_dims`` for NumPy arrays. - - Parameters - ---------- - a : array - Array to have its shape expanded. - axis : int or tuple of ints, optional - Position(s) in the expanded axes where the new axis (or axes) is/are placed. - If multiple positions are provided, they should be unique (note that a position - given by a positive index could also be referred to by a negative index - - that will also result in an error). - Default: ``(0,)``. - xp : array_namespace, optional - The standard-compatible namespace for `a`. Default: infer. - - Returns - ------- - array - `a` with an expanded shape. - - Examples - -------- - >>> import array_api_strict as xp - >>> import array_api_extra as xpx - >>> x = xp.asarray([1, 2]) - >>> x.shape - (2,) - - The following is equivalent to ``x[xp.newaxis, :]`` or ``x[xp.newaxis]``: - - >>> y = xpx.expand_dims(x, axis=0, xp=xp) - >>> y - Array([[1, 2]], dtype=array_api_strict.int64) - >>> y.shape - (1, 2) - - The following is equivalent to ``x[:, xp.newaxis]``: - - >>> y = xpx.expand_dims(x, axis=1, xp=xp) - >>> y - Array([[1], - [2]], dtype=array_api_strict.int64) - >>> y.shape - (2, 1) - - ``axis`` may also be a tuple: - - >>> y = xpx.expand_dims(x, axis=(0, 1), xp=xp) - >>> y - Array([[[1, 2]]], dtype=array_api_strict.int64) - - >>> y = xpx.expand_dims(x, axis=(2, 0), xp=xp) - >>> y - Array([[[1], - [2]]], dtype=array_api_strict.int64) - """ - if xp is None: - xp = array_namespace(a) - - if not isinstance(axis, tuple): - axis = (axis,) - ndim = a.ndim + len(axis) - if axis != () and (min(axis) < -ndim or max(axis) >= ndim): - err_msg = ( - f"a provided axis position is out of bounds for array of dimension {a.ndim}" - ) - raise IndexError(err_msg) - axis = tuple(dim % ndim for dim in axis) - if len(set(axis)) != len(axis): - err_msg = "Duplicate dimensions specified in `axis`." - raise ValueError(err_msg) +def expand_dims(a: Array, /, *, axis: tuple[int, ...] = (0,), xp: ModuleType) -> Array: + # numpydoc ignore=PR01,RT01 + """See docstring in array_api_extra._delegation.""" for i in sorted(axis): a = xp.expand_dims(a, axis=i) return a