diff --git a/src/array_api_extra/__init__.py b/src/array_api_extra/__init__.py index 07f7d552..aa12b0c8 100644 --- a/src/array_api_extra/__init__.py +++ b/src/array_api_extra/__init__.py @@ -4,6 +4,7 @@ argpartition, atleast_nd, cov, + create_diagonal, expand_dims, isclose, isin, @@ -18,7 +19,6 @@ from ._lib._funcs import ( apply_where, broadcast_shapes, - create_diagonal, default_dtype, kron, nunique, diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py index 289d21e4..cfaf5c89 100644 --- a/src/array_api_extra/_delegation.py +++ b/src/array_api_extra/_delegation.py @@ -21,6 +21,7 @@ __all__ = [ "atleast_nd", "cov", + "create_diagonal", "expand_dims", "isclose", "nan_to_num", @@ -174,6 +175,67 @@ def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array: return _funcs.cov(m, xp=xp) +def create_diagonal( + x: Array, /, *, offset: int = 0, xp: ModuleType | None = None +) -> Array: + """ + Construct a diagonal array. + + Parameters + ---------- + x : array + An array having shape ``(*batch_dims, k)``. + offset : int, optional + Offset from the leading diagonal (default is ``0``). + Use positive ints for diagonals above the leading diagonal, + and negative ints for diagonals below the leading diagonal. + xp : array_namespace, optional + The standard-compatible namespace for `x`. Default: infer. + + Returns + ------- + array + An array having shape ``(*batch_dims, k+abs(offset), k+abs(offset))`` with `x` + on the diagonal (offset by `offset`). + + Examples + -------- + >>> import array_api_strict as xp + >>> import array_api_extra as xpx + >>> x = xp.asarray([2, 4, 8]) + + >>> xpx.create_diagonal(x, xp=xp) + Array([[2, 0, 0], + [0, 4, 0], + [0, 0, 8]], dtype=array_api_strict.int64) + + >>> xpx.create_diagonal(x, offset=-2, xp=xp) + Array([[0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [2, 0, 0, 0, 0], + [0, 4, 0, 0, 0], + [0, 0, 8, 0, 0]], dtype=array_api_strict.int64) + """ + if xp is None: + xp = array_namespace(x) + + if x.ndim == 0: + err_msg = "`x` must be at least 1-dimensional." + raise ValueError(err_msg) + + if is_torch_namespace(xp): + return xp.diag_embed(x, offset=offset, dim1=-2, dim2=-1) + + if (is_dask_namespace(xp) or is_cupy_namespace(xp)) and x.ndim < 2: + return xp.diag(x, k=offset) + + if (is_jax_namespace(xp) or is_numpy_namespace(xp)) and x.ndim < 3: + batch_dim, n = eager_shape(x)[:-1], eager_shape(x, -1)[0] + abs(offset) + return xp.reshape(xp.diag(x, k=offset), (*batch_dim, n, n)) + + return _funcs.create_diagonal(x, offset=offset, xp=xp) + + def expand_dims( a: Array, /, *, axis: int | tuple[int, ...] = (0,), xp: ModuleType | None = None ) -> Array: diff --git a/src/array_api_extra/_lib/_funcs.py b/src/array_api_extra/_lib/_funcs.py index 49840c0f..6e50ce95 100644 --- a/src/array_api_extra/_lib/_funcs.py +++ b/src/array_api_extra/_lib/_funcs.py @@ -295,53 +295,9 @@ def one_hot( def create_diagonal( - x: Array, /, *, offset: int = 0, xp: ModuleType | None = None -) -> Array: - """ - Construct a diagonal array. - - Parameters - ---------- - x : array - An array having shape ``(*batch_dims, k)``. - offset : int, optional - Offset from the leading diagonal (default is ``0``). - Use positive ints for diagonals above the leading diagonal, - and negative ints for diagonals below the leading diagonal. - xp : array_namespace, optional - The standard-compatible namespace for `x`. Default: infer. - - Returns - ------- - array - An array having shape ``(*batch_dims, k+abs(offset), k+abs(offset))`` with `x` - on the diagonal (offset by `offset`). - - Examples - -------- - >>> import array_api_strict as xp - >>> import array_api_extra as xpx - >>> x = xp.asarray([2, 4, 8]) - - >>> xpx.create_diagonal(x, xp=xp) - Array([[2, 0, 0], - [0, 4, 0], - [0, 0, 8]], dtype=array_api_strict.int64) - - >>> xpx.create_diagonal(x, offset=-2, xp=xp) - Array([[0, 0, 0, 0, 0], - [0, 0, 0, 0, 0], - [2, 0, 0, 0, 0], - [0, 4, 0, 0, 0], - [0, 0, 8, 0, 0]], dtype=array_api_strict.int64) - """ - if xp is None: - xp = array_namespace(x) - - if x.ndim == 0: - err_msg = "`x` must be at least 1-dimensional." - raise ValueError(err_msg) - + x: Array, /, *, offset: int = 0, xp: ModuleType +) -> Array: # numpydoc ignore=PR01,RT01 + """See docstring in array_api_extra._delegation.""" x_shape = eager_shape(x) batch_dims = x_shape[:-1] n = x_shape[-1] + abs(offset)