-
Notifications
You must be signed in to change notification settings - Fork 3.2k
Open
Labels
bug (Something isn't working)
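Description

Mutable array refs (`jax.new_ref`) crash under `jax.vmap`. The crashes occur in the batching/abstract-eval rules for `addupdate` and `swap`, as shown by the repros and tracebacks below.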
import jax
import jax.numpy as jnp
from jax._src.api import vjp3
def grad_via_ref(f):
def wrapper(*args):
grad_accum = jax.tree.map(lambda x: jax.new_ref(jnp.zeros_like(x)), args)
out, f_vjp = vjp3(f, *args)
f_vjp.with_refs(*grad_accum)(jnp.ones_like(out))
return jax.tree.map(lambda x: jax.freeze(x), grad_accum)
return wrapper
def issue_vmap1():
    """Repro: vmapping a ref-based gradient crashes.

    Composing in the other order works:
    ``grad_via_ref(jax.vmap(add_one))`` (vmap inside, gradient outside).
    Putting ``grad_via_ref`` *inside* ``jax.vmap`` instead crashes when called.
    """
    def add_one(v):
        return v + 1

    inputs = jnp.ones((4,))
    # Works:   batched_grad = grad_via_ref(jax.vmap(add_one))
    batched_grad = jax.vmap(grad_via_ref(add_one))  # bad
    batched_grad(inputs)  # crash
def issue_vmap1_minimized():
    """Minimal repro: ``addupdate`` without indexers on a batched ref crashes.

    According to the recorded traceback, the batching rule
    ``_addupdate_vmap`` unconditionally reads ``indexers[0]``. That fails when
    ``addupdate`` is called with no indexers.
    """
    def bump(ref):
        # bad: no indexers are passed, but the vmap rule assumes at least one.
        ref.addupdate(1.0)

    batched = jax.vmap(bump)
    batched(jax.new_ref(jnp.zeros((4,))))  # crash
def issue_vmap2():
    """Repro: a full-slice store into a batched ref fails abstract evaluation.

    Inside ``jax.vmap`` over a ``(4,)`` ref, each example sees a scalar ref, so
    storing the scalar ``1.0`` is well-typed per example. The reported error
    shows the batched ``swap`` receiving ref shape ``(4,)`` but value shape
    ``()``. The value apparently is not broadcast to the batched shape.
    """
    def store_one(ref):
        # bad (mismatched shapes after batching)
        ref[...] = 1.0

    batched = jax.vmap(store_one)
    batched(jax.new_ref(jnp.zeros((4,))))  # crash
if __name__ == "__main__":
    import traceback

    # Run every repro independently. Otherwise the first crash would make the
    # remaining repros unreachable in a single run.
    #
    # Recorded failure for issue_vmap1 / issue_vmap1_minimized:
    # ...
    # File "/opt/repos/jax/jax/_src/interpreters/batching.py", line 589, in process_primitive
    # val_out, dim_out = primitive_batchers[p](vals_in, dims_in, **params)
    # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    # File "/opt/repos/jax/jax/_src/state/primitives.py", line 965, in _addupdate_vmap
    # is_int_indexing, _, _ = indexing.unpack_ndindexer(indexers[0])
    # ~~~~~~~~^^^
    #
    # Recorded failure for issue_vmap2:
    # ...
    # File "/opt/repos/jax/jax/_src/interpreters/partial_eval.py", line 2186, in default_process_primitive
    # out_avals, effs = primitive.abstract_eval(*aval_qdds, **params)
    # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    # File "/opt/repos/jax/jax/_src/state/primitives.py", line 453, in _swap_abstract_eval
    # raise ValueError("Invalid shape for `swap`. "
    # ValueError: Invalid shape for `swap`. Ref shape: (4,). Expected shape: (4,). Value shape: (). Transforms: ().
    failures = 0
    for repro in (issue_vmap1, issue_vmap1_minimized, issue_vmap2):
        print(f"=== {repro.__name__} ===")
        try:
            repro()
        except Exception:
            failures += 1
            traceback.print_exc()
        else:
            print("no error raised (bug may be fixed)")
    # Preserve a non-zero exit status when any repro still fails.
    raise SystemExit(1 if failures else 0)
System info (Python version, jaxlib version, accelerator, etc.)
jax: 0.8.0.dev20251009+188b796b8
jaxlib: 0.8.0.dev20251009
numpy: 2.0.2
python: 3.12.3 (main, Aug 14 2025, 17:47:21) [GCC 13.3.0]
device info: ...
process_count: 1
platform: ...
Metadata
Assignees
Labels
bug (Something isn't working)