1111
1212from ._at import at
1313from ._utils import _compat , _helpers
14+ from ._utils ._compat import array_namespace , is_jax_array
15+ from ._utils ._helpers import asarrays , ndindex
1416from ._utils ._compat import (
1517 array_namespace ,
1618 is_dask_namespace ,
@@ -384,7 +386,7 @@ def create_diagonal(
384386 Parameters
385387 ----------
386388 x : array
387- A 1-D array.
389+ An array having shape ``(*batch_dims, k)`` .
388390 offset : int, optional
389391 Offset from the leading diagonal (default is ``0``).
390392 Use positive ints for diagonals above the leading diagonal,
@@ -395,7 +397,8 @@ def create_diagonal(
395397 Returns
396398 -------
397399 array
398- A 2-D array with `x` on the diagonal (offset by `offset`).
400+ An array having shape ``(*batch_dims, k+abs(offset), k+abs(offset))`` with `x`
401+ on the diagonal (offset by `offset`).
399402
400403 Examples
401404 --------
@@ -418,18 +421,21 @@ def create_diagonal(
418421 if xp is None :
419422 xp = array_namespace (x )
420423
421- if x .ndim != 1 :
422- err_msg = "`x` must be 1-dimensional."
424+ if x .ndim == 0 :
425+ err_msg = "`x` must be at least 1-dimensional."
423426 raise ValueError (err_msg )
424- n = x .shape [0 ] + abs (offset )
425- diag = xp .zeros (n ** 2 , dtype = x .dtype , device = _compat .device (x ))
426-
427- start = offset if offset >= 0 else abs (offset ) * n
428- stop = min (n * (n - offset ), diag .shape [0 ])
429- step = n + 1
430- diag = at (diag )[start :stop :step ].set (x )
431-
432- return xp .reshape (diag , (n , n ))
427+ batch_dims = x .shape [:- 1 ]
428+ n = x .shape [- 1 ] + abs (offset )
429+ diag = xp .zeros ((* batch_dims , n ** 2 ), dtype = x .dtype , device = _compat .device (x ))
430+
431+ target_slice = slice (
432+ offset if offset >= 0 else abs (offset ) * n ,
433+ min (n * (n - offset ), diag .shape [- 1 ]),
434+ n + 1 ,
435+ )
436+ for index in ndindex (* batch_dims ):
437+ diag = at (diag )[(* index , target_slice )].set (x [(* index , slice (None ))])
438+ return xp .reshape (diag , (* batch_dims , n , n ))
433439
434440
435441def expand_dims (
0 commit comments