Skip to content

Commit 7934ae1

Browse files
committed
fix older JAX
1 parent 4f86a99 commit 7934ae1

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

array_api_compat/common/_helpers.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -800,8 +800,11 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
800800
raise ValueError(f"Unsupported device {device!r}")
801801
elif is_jax_array(x):
802802
if not hasattr(x, "__array_namespace__"):
803-
# In JAX v0.4.31 and older, this import adds to_device method to x.
803+
# In JAX v0.4.31 and older, this import adds to_device method to x...
804804
import jax.experimental.array_api # noqa: F401
805+
# ... but only on eager JAX. It won't work inside jax.jit.
806+
if not hasattr(x, "to_device"):
807+
return x
805808
return x.to_device(device, stream=stream)
806809
elif is_pydata_sparse_array(x) and device == _device(x):
807810
# Perform trivial check to return the same array if

0 commit comments

Comments
 (0)