File tree Expand file tree Collapse file tree 1 file changed +18
-2
lines changed Expand file tree Collapse file tree 1 file changed +18
-2
lines changed Original file line number Diff line number Diff line change 11"""Acquire helpers from array-api-compat."""
2+
23# Allow packages that vendor both `array-api-extra` and
34# `array-api-compat` to override the import location
5+ from __future__ import annotations
6+
7+ from typing import TYPE_CHECKING
8+
9+ if TYPE_CHECKING :
10+ from ._typing import Array , Device
411
512try :
613 from ...._array_api_compat_vendor import (
714 array_namespace ,
8- device ,
915 is_array_api_strict_namespace ,
1016 is_cupy_namespace ,
1117 is_dask_namespace ,
1723 is_writeable_array ,
1824 size ,
1925 )
26+ from ...._array_api_compat_vendor import (
27+ device as _compat_device ,
28+ )
2029except ImportError :
2130 from array_api_compat import (
2231 array_namespace ,
23- device ,
2432 is_array_api_strict_namespace ,
2533 is_cupy_namespace ,
2634 is_dask_namespace ,
3240 is_writeable_array ,
3341 size ,
3442 )
43+ from array_api_compat import (
44+ device as _compat_device ,
45+ )
46+
47+
48+ def device (x : Array ) -> Device | None : # numpydoc ignore=GL08
49+ return None if is_jax_array (x ) else _compat_device (x )
50+
3551
3652__all__ = [
3753 "array_namespace" ,
You can’t perform that action at this time.
0 commit comments