Open
Labels
bug
Description
Similar to #32399, the following code produces a `JaxprInputEffect Read<1> is invalid.` error:
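```python
import jax
import jax.numpy as jnp

@jax.jit
def bar(w):
    x = jnp.zeros((1,)) + jnp.array([0])
    # cond whose true branch reads from the array ref `w`
    x = jax.lax.cond(x[0] < 1, lambda x: x + w[...], lambda x: x - 1, x)
    return x

@jax.jit
def foo(w):
    # outer jit that passes the ref straight through to the inner jitted `bar`
    return bar(w)

foo(jax.array_ref(jnp.ones((1,))))
```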
System info (python version, jaxlib version, accelerator, etc.)
I am using a dev build that I verified includes the patches for #32399:
```
jax: 0.8.0.dev20251007
jaxlib: 0.8.0.dev20251007
numpy: 2.3.3
python: 3.13.5 (main, Jul 11 2025, 22:45:47) [Clang 20.1.4 ]
device info: cpu-1, 1 local devices"
process_count: 1
platform: uname_result(system='Linux', node='asahi', release='6.16.5-asahi', version='#1-NixOS SMP Tue Jan 1 00:00:00 UTC 1980', machine='aarch64')
```