Skip to content

vmap of ref_{set,addupdate} is broken (missing indexers, mismatched shapes) #32479

@gspschmid

Description

@gspschmid

Description

import jax
import jax.numpy as jnp
from jax._src.api import vjp3

def grad_via_ref(f):
  def wrapper(*args):
    grad_accum = jax.tree.map(lambda x: jax.new_ref(jnp.zeros_like(x)), args)
    out, f_vjp = vjp3(f, *args)
    f_vjp.with_refs(*grad_accum)(jnp.ones_like(out))
    return jax.tree.map(lambda x: jax.freeze(x), grad_accum)
  return wrapper

def issue_vmap1():
  """Repro: ``jax.vmap`` applied outside ``grad_via_ref`` crashes.

  Composing in the opposite order, ``grad_via_ref(jax.vmap(f))``, is
  reported to work.
  """
  def add_one(v):
    return v + 1
  batch = jnp.ones((4,))
  # Works:   batched_grad = grad_via_ref(jax.vmap(add_one))
  # Crashes: vmap outside the ref-based gradient wrapper.
  batched_grad = jax.vmap(grad_via_ref(add_one))
  batched_grad(batch)

def issue_vmap1_minimized():
  """Repro: whole-ref ``addupdate`` under ``jax.vmap`` crashes.

  The addupdate batching rule reads ``indexers[0]``, so it assumes at least
  one indexer. A plain ``ref.addupdate(value)`` supplies none.
  """
  def bump(ref):
    ref.addupdate(1.0)
  batched_bump = jax.vmap(bump)
  batched_bump(jax.new_ref(jnp.zeros((4,))))

def issue_vmap2():
  """Repro: full-ref assignment ``ref[...] = scalar`` under ``jax.vmap`` crashes.

  The reported error shows the scalar value (shape ``()``) reaching ``swap``
  against a ref of shape ``(4,)``. The scalar apparently is not broadcast to
  the ref's shape before the shape check.
  """
  def fill_ones(ref):
    ref[...] = 1.0
  batched_fill = jax.vmap(fill_ones)
  batched_fill(jax.new_ref(jnp.zeros((4,))))

def run_repro(repro):
  """Run one repro callable, printing any exception instead of propagating it.

  Every repro in this script is expected to crash. Catching each failure lets
  all repros run in a single invocation, rather than stopping at the first
  one and never reaching the rest.

  Args:
    repro: zero-argument callable to execute.

  Returns:
    True if ``repro`` completed without raising, False otherwise.
  """
  import traceback
  print(f"=== {getattr(repro, '__name__', repr(repro))} ===")
  try:
    repro()
  except Exception:
    traceback.print_exc()
    return False
  print("passed")
  return True


if __name__ == "__main__":
  ok = run_repro(issue_vmap1)
  # Reported failure of issue_vmap1_minimized. issue_vmap1 is its unminimized
  # form and presumably fails the same way:
  #   ...
  #     File "/opt/repos/jax/jax/_src/interpreters/batching.py", line 589, in process_primitive
  #       val_out, dim_out = primitive_batchers[p](vals_in, dims_in, **params)
  #                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  #     File "/opt/repos/jax/jax/_src/state/primitives.py", line 965, in _addupdate_vmap
  #       is_int_indexing, _, _ = indexing.unpack_ndindexer(indexers[0])
  #                                                         ~~~~~~~~^^^
  ok &= run_repro(issue_vmap1_minimized)

  # Reported failure of issue_vmap2:
  #   ...
  #     File "/opt/repos/jax/jax/_src/interpreters/partial_eval.py", line 2186, in default_process_primitive
  #       out_avals, effs = primitive.abstract_eval(*aval_qdds, **params)
  #                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  #     File "/opt/repos/jax/jax/_src/state/primitives.py", line 453, in _swap_abstract_eval
  #       raise ValueError("Invalid shape for `swap`. "
  #   ValueError: Invalid shape for `swap`. Ref shape: (4,). Expected shape: (4,). Value shape: (). Transforms: ().
  ok &= run_repro(issue_vmap2)

  # Keep a non-zero exit status when any repro fails, like the original
  # uncaught exception did.
  raise SystemExit(0 if ok else 1)

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.8.0.dev20251009+188b796b8
jaxlib: 0.8.0.dev20251009
numpy:  2.0.2
python: 3.12.3 (main, Aug 14 2025, 17:47:21) [GCC 13.3.0]
device info: ...
process_count: 1
platform: ...

Metadata

Metadata

Assignees

Labels

bug (Something isn't working)

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions