55
66import math
77import warnings
8- from collections .abc import Sequence
8+ from collections .abc import Generator , Sequence
99from types import ModuleType
1010from typing import cast
1111
@@ -163,6 +163,16 @@ def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array:
163163 return xp .squeeze (c , axis = axes )
164164
165165
166+ def ndindex (* x : int ) -> Generator [tuple [int , ...]]:
167+ if not x :
168+ yield ()
169+ return
170+ indices = list (ndindex (* x [1 :]))
171+ for i in range (x [0 ]):
172+ for j in indices :
173+ yield i , * j
174+
175+
166176def create_diagonal (
167177 x : Array , / , * , offset : int = 0 , xp : ModuleType | None = None
168178) -> Array :
@@ -172,7 +182,7 @@ def create_diagonal(
172182 Parameters
173183 ----------
174184 x : array
175- A 1-D array.
185+ An array having shape (*broadcast_dims, k) .
176186 offset : int, optional
177187 Offset from the leading diagonal (default is ``0``).
178188 Use positive ints for diagonals above the leading diagonal,
@@ -183,7 +193,8 @@ def create_diagonal(
183193 Returns
184194 -------
185195 array
186- A 2-D array with `x` on the diagonal (offset by `offset`).
196+ An array having shape (*broadcast_dims, k+abs(offset), k+abs(offset)) with `x`
197+ on the diagonal (offset by `offset`).
187198
188199 Examples
189200 --------
@@ -206,18 +217,21 @@ def create_diagonal(
206217 if xp is None :
207218 xp = array_namespace (x )
208219
209- if x .ndim != 1 :
210- err_msg = "`x` must be 1-dimensional."
220+ if x .ndim == 0 :
221+ err_msg = "`x` must be at least 1-dimensional."
211222 raise ValueError (err_msg )
212- n = x .shape [0 ] + abs (offset )
213- diag = xp .zeros (n ** 2 , dtype = x .dtype , device = _compat .device (x ))
214-
215- start = offset if offset >= 0 else abs (offset ) * n
216- stop = min (n * (n - offset ), diag .shape [0 ])
217- step = n + 1
218- diag = at (diag )[start :stop :step ].set (x )
219-
220- return xp .reshape (diag , (n , n ))
223+ pre = x .shape [:- 1 ]
224+ n = x .shape [- 1 ] + abs (offset )
225+ diag = xp .zeros ((* pre , n ** 2 ), dtype = x .dtype , device = _compat .device (x ))
226+
227+ target_slice = slice (
228+ offset if offset >= 0 else abs (offset ) * n ,
229+ min (n * (n - offset ), diag .shape [- 1 ]),
230+ n + 1 ,
231+ )
232+ for index in ndindex (* pre ):
233+ diag = at (diag )[(* index , target_slice )].set (x [* index , :])
234+ return xp .reshape (diag , (* pre , n , n ))
221235
222236
223237def expand_dims (
0 commit comments