Skip to content

Commit e94e404

Browse files
committed
Merge branch 'main' into at
2 parents 8a686f8 + 6699efb commit e94e404

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

src/array_api_extra/_funcs.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from collections.abc import Callable # pylint: disable=import-error
88
from typing import ClassVar, Literal
99

10-
from ._lib import _utils
10+
from ._lib import _compat, _utils
1111
from ._lib._compat import (
1212
array_namespace,
1313
is_jax_array,
@@ -212,7 +212,7 @@ def create_diagonal(
212212
err_msg = "`x` must be 1-dimensional."
213213
raise ValueError(err_msg)
214214
n = x.shape[0] + abs(offset)
215-
diag = xp.zeros(n**2, dtype=x.dtype, device=x.device)
215+
diag = xp.zeros(n**2, dtype=x.dtype, device=_compat.device(x))
216216
i = offset if offset >= 0 else abs(offset) * n
217217
diag[i : min(n * (n - offset), diag.shape[0]) : n + 1] = x
218218
return xp.reshape(diag, (n, n))
@@ -552,7 +552,7 @@ def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array:
552552
y = xp.pi * xp.where(
553553
xp.astype(x, xp.bool),
554554
x,
555-
xp.asarray(xp.finfo(x.dtype).eps, dtype=x.dtype, device=x.device),
555+
xp.asarray(xp.finfo(x.dtype).eps, dtype=x.dtype, device=_compat.device(x)),
556556
)
557557
return xp.sin(y) / y
558558

0 commit comments

Comments
 (0)