1212from ._at import at
1313from ._utils import _compat , _helpers
1414from ._utils ._compat import array_namespace , is_jax_array
15- from ._utils ._helpers import asarrays
15+ from ._utils ._helpers import asarrays , ndindex
1616from ._utils ._typing import Array
1717
1818__all__ = [
2929
3030
3131def atleast_nd (x : Array , / , * , ndim : int , xp : ModuleType | None = None ) -> Array :
32- """
33- Recursively expand the dimension of an array to at least `ndim`.
32+ """Recursively expand the dimension of an array to at least `ndim`.
3433
3534 Parameters
3635 ----------
@@ -72,8 +71,7 @@ def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType | None = None) -> Array
7271
7372
7473def cov (m : Array , / , * , xp : ModuleType | None = None ) -> Array :
75- """
76- Estimate a covariance matrix.
74+ """Estimate a covariance matrix.
7775
7876 Covariance indicates the level to which two variables vary together.
7977 If we examine N-dimensional samples, :math:`X = [x_1, x_2, ... x_N]^T`,
@@ -166,13 +164,12 @@ def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array:
166164def create_diagonal (
167165 x : Array , / , * , offset : int = 0 , xp : ModuleType | None = None
168166) -> Array :
169- """
170- Construct a diagonal array.
167+ """Construct a diagonal array.
171168
172169 Parameters
173170 ----------
174171 x : array
175- A 1-D array.
172+ An array having shape (*broadcast_dims, k) .
176173 offset : int, optional
177174 Offset from the leading diagonal (default is ``0``).
178175 Use positive ints for diagonals above the leading diagonal,
@@ -183,7 +180,8 @@ def create_diagonal(
183180 Returns
184181 -------
185182 array
186- A 2-D array with `x` on the diagonal (offset by `offset`).
183+ An array having shape (*broadcast_dims, k+abs(offset), k+abs(offset)) with `x`
184+ on the diagonal (offset by `offset`).
187185
188186 Examples
189187 --------
@@ -206,25 +204,27 @@ def create_diagonal(
206204 if xp is None :
207205 xp = array_namespace (x )
208206
209- if x .ndim != 1 :
210- err_msg = "`x` must be 1-dimensional."
207+ if x .ndim == 0 :
208+ err_msg = "`x` must be at least 1-dimensional."
211209 raise ValueError (err_msg )
212- n = x .shape [0 ] + abs (offset )
213- diag = xp .zeros (n ** 2 , dtype = x .dtype , device = _compat .device (x ))
214-
215- start = offset if offset >= 0 else abs (offset ) * n
216- stop = min (n * (n - offset ), diag .shape [0 ])
217- step = n + 1
218- diag = at (diag )[start :stop :step ].set (x )
219-
220- return xp .reshape (diag , (n , n ))
210+ pre = x .shape [:- 1 ]
211+ n = x .shape [- 1 ] + abs (offset )
212+ diag = xp .zeros ((* pre , n ** 2 ), dtype = x .dtype , device = _compat .device (x ))
213+
214+ target_slice = slice (
215+ offset if offset >= 0 else abs (offset ) * n ,
216+ min (n * (n - offset ), diag .shape [- 1 ]),
217+ n + 1 ,
218+ )
219+ for index in ndindex (* pre ):
220+ diag = at (diag )[(* index , target_slice )].set (x [* index , :])
221+ return xp .reshape (diag , (* pre , n , n ))
221222
222223
223224def expand_dims (
224225 a : Array , / , * , axis : int | tuple [int , ...] = (0 ,), xp : ModuleType | None = None
225226) -> Array :
226- """
227- Expand the shape of an array.
227+ """Expand the shape of an array.
228228
229229 Insert (a) new axis/axes that will appear at the position(s) specified by
230230 `axis` in the expanded array shape.
0 commit comments