Bug in handling of effects with branching and constvars #32446

@pfrommerd

Description

Similar to #32399, the following code produces a `JaxprInputEffect Read<1> is invalid.` error:

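```python
import jax
import jax.numpy as jnp


@jax.jit
def bar(w):
    x = jnp.zeros((1,)) + jnp.array([0])
    # w is a mutable array ref; only the true branch reads it via w[...]
    x = jax.lax.cond(x[0] < 1, lambda x: x + w[...], lambda x: x - 1, x)
    return x


@jax.jit
def foo(w):
    return bar(w)


# Raises the `JaxprInputEffect Read<1> is invalid.` error
foo(jax.array_ref(jnp.ones((1,))))
```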
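The error names an input-indexed effect (`Read<1>`), and the title points at constvars. As a hedged illustration only, here is a toy model in plain Python (not JAX internals; `Read`, `ToyJaxpr`, `check_effects`, and `hoist_consts` are all hypothetical names) of one way such an error can arise: if closed-over constants are turned into leading inputs without shifting the effect indices, a read effect ends up pointing at the wrong input. Whether this matches the actual cause of this issue is an assumption.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Read:
    """Toy input-indexed effect: 'this jaxpr reads the ref at invars[index]'."""
    index: int


@dataclass(frozen=True)
class ToyJaxpr:
    """Hypothetical, heavily simplified stand-in for a jaxpr (not JAX's class)."""
    constvars: tuple
    invars: tuple
    effects: frozenset


def check_effects(jaxpr: ToyJaxpr) -> None:
    """Raise if any Read effect does not point at a ref-typed input.

    In this toy, an input counts as a ref iff its name starts with "ref_".
    """
    for eff in jaxpr.effects:
        i = eff.index
        if not (0 <= i < len(jaxpr.invars)) or not jaxpr.invars[i].startswith("ref_"):
            raise ValueError(f"toy check: Read<{i}> is invalid")


def hoist_consts(jaxpr: ToyJaxpr, shift_effects: bool) -> ToyJaxpr:
    """Turn constvars into leading invars, optionally reindexing the effects."""
    n = len(jaxpr.constvars)
    effects = jaxpr.effects
    if shift_effects:
        # Every existing input moves n positions to the right.
        effects = frozenset(Read(e.index + n) for e in effects)
    return ToyJaxpr(constvars=(), invars=jaxpr.constvars + jaxpr.invars, effects=effects)


# A branch that closes over one constant and reads the ref passed as invars[0].
branch = ToyJaxpr(constvars=("c0",), invars=("ref_w",), effects=frozenset({Read(0)}))
check_effects(branch)  # valid before hoisting

good = hoist_consts(branch, shift_effects=True)
check_effects(good)  # Read<1> now points at "ref_w"

bad = hoist_consts(branch, shift_effects=False)
try:
    check_effects(bad)  # Read<0> now points at the hoisted constant "c0"
except ValueError as err:
    print(err)  # toy check: Read<0> is invalid
```

In the toy, the fix is simply to reindex effects whenever inputs are prepended; how JAX actually tracks and rewrites these indices across `jit` and `cond` is not shown here.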

System info (python version, jaxlib version, accelerator, etc.)

I am using a dev build that I verified includes the patches for #32399.

```
jax: 0.8.0.dev20251007
jaxlib: 0.8.0.dev20251007
numpy: 2.3.3
python: 3.13.5 (main, Jul 11 2025, 22:45:47) [Clang 20.1.4 ]
device info: cpu-1, 1 local devices
process_count: 1
platform: uname_result(system='Linux', node='asahi', release='6.16.5-asahi', version='#1-NixOS SMP Tue Jan 1 00:00:00 UTC 1980', machine='aarch64')
```

Metadata

Labels

bug (Something isn't working)