99from types import ModuleType
1010from typing import cast
1111
12+ import numpy as np
13+
1214from ._at import at
1315from ._utils import _compat , _helpers
1416from ._utils ._compat import array_namespace , is_jax_array
@@ -172,7 +174,7 @@ def create_diagonal(
172174 Parameters
173175 ----------
174176 x : array
175- A 1-D array.
177+ An array having shape (*broadcast_dims, k) .
176178 offset : int, optional
177179 Offset from the leading diagonal (default is ``0``).
178180 Use positive ints for diagonals above the leading diagonal,
@@ -183,7 +185,8 @@ def create_diagonal(
183185 Returns
184186 -------
185187 array
186- A 2-D array with `x` on the diagonal (offset by `offset`).
188+ An array having shape (*broadcast_dims, k+abs(offset), k+abs(offset)) with `x`
189+ on the diagonal (offset by `offset`).
187190
188191 Examples
189192 --------
@@ -206,18 +209,20 @@ def create_diagonal(
206209 if xp is None :
207210 xp = array_namespace (x )
208211
209- if x .ndim != 1 :
210- err_msg = "`x` must be 1-dimensional."
212+ if x .ndim == 0 :
213+ err_msg = "`x` must be at least 1-dimensional."
211214 raise ValueError (err_msg )
212- n = x .shape [0 ] + abs (offset )
213- diag = xp .zeros (n ** 2 , dtype = x .dtype , device = _compat .device (x ))
214-
215- start = offset if offset >= 0 else abs (offset ) * n
216- stop = min (n * (n - offset ), diag .shape [0 ])
217- step = n + 1
218- diag = at (diag )[start :stop :step ].set (x )
219-
220- return xp .reshape (diag , (n , n ))
215+ pre = x .shape [:- 1 ]
216+ n = x .shape [- 1 ] + abs (offset )
217+ diag = xp .zeros ((* pre , n ** 2 ), dtype = x .dtype , device = _compat .device (x ))
218+
219+ target_slice = slice (offset if offset >= 0 else abs (offset ) * n ,
220+ min (n * (n - offset ), diag .shape [0 ]),
221+ n + 1 )
222+ for index in np .ndindex (* pre ):
223+ indexed_x = x [* index , :]
224+ diag = at (diag )[(* index , target_slice )].set (indexed_x )
225+ return xp .reshape (diag , (* pre , n , n ))
221226
222227
223228def expand_dims (
0 commit comments