Skip to content

Commit 0ea829f

Browse files
committed
fix older JAX
1 parent 4f86a99 commit 0ea829f

File tree

2 files changed

+9
-3
lines changed

2 files changed

+9
-3
lines changed

array_api_compat/common/_helpers.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -800,8 +800,11 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
800800
raise ValueError(f"Unsupported device {device!r}")
801801
elif is_jax_array(x):
802802
if not hasattr(x, "__array_namespace__"):
803-
# In JAX v0.4.31 and older, this import adds to_device method to x.
803+
# In JAX v0.4.31 and older, this import adds to_device method to x...
804804
import jax.experimental.array_api # noqa: F401
805+
# ... but only on eager JAX. It won't work inside jax.jit.
806+
if not hasattr(x, "to_device"):
807+
return x
805808
return x.to_device(device, stream=stream)
806809
elif is_pydata_sparse_array(x) and device == _device(x):
807810
# Perform trivial check to return the same array if

tests/test_jax.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,17 @@
77

88
def test_device_jit():
99
# Test work around to https://github.com/jax-ml/jax/issues/26000
10-
@jax.jit
10+
# Also test missing to_device() method in JAX < 0.4.31
11+
# when inside jax.jit, even after importing jax.experimental.array_api
12+
1113
def f(x):
1214
return jnp.zeros(1, device=device(x))
1315

14-
@jax.jit
1516
def g(x):
1617
return to_device(jnp.zeros(1), device(x))
1718

1819
x = jnp.ones(1)
1920
assert_equal(f(x), jnp.asarray(0))
2021
assert_equal(g(x), jnp.asarray(0))
22+
assert_equal(jax.jit(f)(x), jnp.asarray(0))
23+
assert_equal(jax.jit(g)(x), jnp.asarray(0))

0 commit comments

Comments
 (0)