Skip to content

Commit 31ce85e

Browse files
committed
Hack around JAX device attribute
1 parent c04816a commit 31ce85e

File tree

1 file changed

+22
-2
lines changed

1 file changed

+22
-2
lines changed

src/array_api_extra/_lib/_utils/_compat.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,17 @@
11
"""Acquire helpers from array-api-compat."""
2+
23
# Allow packages that vendor both `array-api-extra` and
34
# `array-api-compat` to override the import location
5+
from __future__ import annotations
6+
7+
from typing import TYPE_CHECKING
8+
9+
if TYPE_CHECKING:
10+
from ._typing import Array, Device
411

512
try:
613
from ...._array_api_compat_vendor import (
714
array_namespace,
8-
device,
915
is_array_api_strict_namespace,
1016
is_cupy_namespace,
1117
is_dask_namespace,
@@ -17,10 +23,12 @@
1723
is_writeable_array,
1824
size,
1925
)
26+
from ...._array_api_compat_vendor import (
27+
device as _compat_device,
28+
)
2029
except ImportError:
2130
from array_api_compat import (
2231
array_namespace,
23-
device,
2432
is_array_api_strict_namespace,
2533
is_cupy_namespace,
2634
is_dask_namespace,
@@ -32,6 +40,18 @@
3240
is_writeable_array,
3341
size,
3442
)
43+
from array_api_compat import (
44+
device as _compat_device,
45+
)
46+
47+
48+
def device(x: Array) -> Device | None: # numpydoc ignore=GL08
49+
try:
50+
return _compat_device(x)
51+
except AttributeError:
52+
assert is_jax_array(x)
53+
return None
54+
3555

3656
__all__ = [
3757
"array_namespace",

0 commit comments

Comments
 (0)