Skip to content

Commit 68fd521

Browse files
committed
fix bug with delegate create_diagonals.
1 parent 1132a27 commit 68fd521

File tree

2 files changed

+7
-5
lines changed

2 files changed

+7
-5
lines changed

src/array_api_extra/_delegation.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -226,13 +226,14 @@ def create_diagonal(
226226
if is_torch_namespace(xp):
227227
return xp.diag_embed(x, offset=offset, dim1=-2, dim2=-1)
228228

229-
if (is_dask_namespace(xp) or is_cupy_namespace(xp)) and x.ndim < 2:
229+
if (
230+
is_dask_namespace(xp)
231+
or is_cupy_namespace(xp)
232+
or is_numpy_namespace(xp)
233+
or is_jax_namespace(xp)
234+
) and (x.ndim < 2):
230235
return xp.diag(x, k=offset)
231236

232-
if (is_jax_namespace(xp) or is_numpy_namespace(xp)) and x.ndim < 3:
233-
batch_dim, n = eager_shape(x)[:-1], eager_shape(x, -1)[0] + abs(offset)
234-
return xp.reshape(xp.diag(x, k=offset), (*batch_dim, n, n))
235-
236237
return _funcs.create_diagonal(x, offset=offset, xp=xp)
237238

238239

tests/test_funcs.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -711,6 +711,7 @@ def test_0d_raises(self, xp: ModuleType):
711711
(0, 1),
712712
(1, 0),
713713
(0, 0),
714+
(2, 3),
714715
(4, 2, 1),
715716
(1, 1, 7),
716717
(0, 0, 1),

0 commit comments

Comments
 (0)