From 387293f9be99b9acfba6d5adfe04df373bf45295 Mon Sep 17 00:00:00 2001
From: Matthew Johnson
Date: Fri, 10 Oct 2025 02:59:17 +0000
Subject: [PATCH] [mutable-arrays] fix batching bugs in ref primitives

A few of these are drive-bys and may not have test coverage...

Co-authored-by: Georg Stefan Schmid
---
 jax/_src/state/primitives.py | 48 +++++++++++++++++++++++++-----------
 tests/mutable_array_test.py  | 33 +++++++++++++++++++++++++
 2 files changed, 67 insertions(+), 14 deletions(-)

diff --git a/jax/_src/state/primitives.py b/jax/_src/state/primitives.py
index 9754b155c706..5c50723f20b2 100644
--- a/jax/_src/state/primitives.py
+++ b/jax/_src/state/primitives.py
@@ -203,7 +203,7 @@ def ref_swap(
     value: Array,
     _function_name: str = "ref_swap",
 ) -> Array:
-  """Set an array value inplace while returning the existing value.
+  """Update an array value in place while returning the previous value.
 
   This is equivalent to ``ref[idx], prev = value, ref[idx]`` while returning
   ``prev``, for a NumPy-style indexer ``idx``.
@@ -246,7 +246,6 @@ def ref_swap(
 
   .. _Ref guide: https://docs.jax.dev/en/latest/array_refs.html
   """
-  "Sets a ref's value as `ref[idx], prev = value, ref[idx]` and returns `prev`."
   if hasattr(ref, 'dtype'):
     value = _maybe_implicit_cast(ref.dtype, value)
   ref, transforms = get_ref_and_transforms(ref, idx, _function_name)
@@ -673,6 +672,7 @@ def _array_ref_batched(axis_data, vals_in, dims_in, memory_space):
   val, = vals_in
   dim, = dims_in
   if dim is None:
+    # Defensively batch the ref, since it could later be hit with a batched val
     val2 = batching.broadcast(val, axis_data.size, 0,
                               axis_data.explicit_mesh_axis)
     return core.ref_p.bind(val2, memory_space=memory_space), 0
@@ -680,6 +680,12 @@
     return core.ref_p.bind(val, memory_space=memory_space), dim
 batching.fancy_primitive_batchers[core.ref_p] = _array_ref_batched
 
+def _freeze_batched(axis_data, vals_in, dims_in):
+  ref, = vals_in
+  dim, = dims_in
+  return core.freeze_p.bind(ref), dim
+batching.fancy_primitive_batchers[core.freeze_p] = _freeze_batched
+
 def _state_partial_eval_custom(saveable, unks_in, inst_in, eqn):
   del saveable  # ignored, always full remat state ops on known inputs
   ref_unk, *_ = unks_in
@@ -816,6 +822,7 @@ def _get_vmap(batched_args, batched_dims, *, tree):
                        for i_dim in flat_idx_dims)
   if len(indexers) > 1:
     raise NotImplementedError("Batching with multiple indexers not supported.")
+  # TODO(sharadmv): handle vmap of multiple indexers
   new_indexers = tuple(_batch_indexer(indexer, dims, axis_size,
                                       ref.shape, ref_dim, idx_is_batched)
                        for indexer, dims in zip(indexers, indexers_dims))
@@ -858,9 +865,7 @@
   return out, out_bdim
 batching.primitive_batchers[get_p] = _get_vmap
 
-def _swap_vmap(batched_args, batched_dims, *, tree):
-  axis_size, = {x.shape[d] for x, d in zip(batched_args, batched_dims)
-                if d is not batching.not_mapped}
+def _swap_vmap(axis_data, batched_args, batched_dims, *, tree):
   ref, val, *flat_idxs = batched_args
   ref_dim, val_dim, *flat_idx_dims = batched_dims
   indexers = tree_util.tree_unflatten(tree, flat_idxs)
@@ -877,11 +882,14 @@
                     "Move the array reference to be an argument to the vmapped "
                     "function?")
   if not indexers:
+    if ref_is_batched and not val_is_batched:
+      val = batching.broadcast(val, axis_data.size, ref_dim,
+                               axis_data.explicit_mesh_axis)
     return swap_p.bind(ref, val, *flat_idxs, tree=tree), ref_dim
   if len(indexers) > 1:
     raise NotImplementedError("Batching with multiple indexers not supported.")
supported.") # TODO(sharadmv): handle vmap of multiple indexers - new_indexers = tuple(_batch_indexer(indexer, dims, axis_size, + new_indexers = tuple(_batch_indexer(indexer, dims, axis_data.size, ref.shape, ref_dim, idx_is_batched) for indexer, dims in zip(indexers, indexers_dims)) flat_indexers, tree = tree_util.tree_flatten(new_indexers) @@ -905,7 +913,8 @@ def _swap_vmap(batched_args, batched_dims, *, tree): if not val_is_batched: if ref_is_batched or idx_is_batched: - val = batching.broadcast(val, axis_size, batched_dim_in_result, None) + val = batching.broadcast(val, axis_data.size, batched_dim_in_result, + axis_data.explicit_mesh_axis) else: val = batching.moveaxis(val, val_dim, batched_dim_in_result) @@ -937,11 +946,9 @@ def _swap_vmap(batched_args, batched_dims, *, tree): out = out.transpose(transpose_order_inversed) return out, batched_dim_in_result -batching.primitive_batchers[swap_p] = _swap_vmap +batching.fancy_primitive_batchers[swap_p] = _swap_vmap -def _addupdate_vmap(batched_args, batched_dims, *, tree): - axis_size, = {x.shape[d] for x, d in zip(batched_args, batched_dims) - if d is not batching.not_mapped} +def _addupdate_vmap(axis_data, batched_args, batched_dims, *, tree): ref, val, *flat_idxs = batched_args ref_dim, val_dim, *flat_idx_dims = batched_dims indexers = tree_util.tree_unflatten(tree, flat_idxs) @@ -951,10 +958,22 @@ def _addupdate_vmap(batched_args, batched_dims, *, tree): val_is_batched = val_dim is not batching.not_mapped idx_is_batched = any(i_dim is not batching.not_mapped for i_dim in flat_idx_dims) + + if not ref_is_batched: + raise Exception("performing an addupdate operation with vmapped value on " + f"an unbatched array reference of type {core.typeof(ref)}. " + "Move the array reference to be an argument to the vmapped " + "function?") + if not indexers: + if ref_is_batched and not val_is_batched: + val = batching.broadcast(val, axis_data.size, ref_dim, + axis_data.explicit_mesh_axis) + return addupdate_p.bind(ref, val, *flat_idxs, tree=tree), [] if len(indexers) > 1: raise NotImplementedError("Batching with multiple indexers not supported.") + # TODO(sharadmv): handle vmap of multiple indexers - new_indexers = tuple(_batch_indexer(indexer, dims, axis_size, + new_indexers = tuple(_batch_indexer(indexer, dims, axis_data.size, ref.shape, ref_dim, idx_is_batched) for indexer, dims in zip(indexers, indexers_dims)) flat_indexers, tree = tree_util.tree_flatten(new_indexers) @@ -978,7 +997,8 @@ def _addupdate_vmap(batched_args, batched_dims, *, tree): if not val_is_batched: if ref_is_batched or idx_is_batched: - val = batching.broadcast(val, axis_size, batched_dim_in_result, None) + val = batching.broadcast(val, axis_data.size, batched_dim_in_result, + axis_data.explicit_mesh_axis) else: val = batching.moveaxis(val, val_dim, batched_dim_in_result) @@ -999,7 +1019,7 @@ def _addupdate_vmap(batched_args, batched_dims, *, tree): val = val.transpose(transpose_order) return addupdate_p.bind(ref, val, *flat_indexers, tree=tree), [] -batching.primitive_batchers[addupdate_p] = _addupdate_vmap +batching.fancy_primitive_batchers[addupdate_p] = _addupdate_vmap # Currently, JAX doesn't have a primitive that does an equal-rank broadcast. # We could use `jnp.broadcast_to` but that lowers to squeezing, diff --git a/tests/mutable_array_test.py b/tests/mutable_array_test.py index 0d027c0e413d..ca209250ca65 100644 --- a/tests/mutable_array_test.py +++ b/tests/mutable_array_test.py @@ -917,6 +917,39 @@ def body(_, xy): _, = f_vjp.with_refs(grad_accum)(1.) 
     self.assertAllClose(grad_accum[...], jnp.arange(5.))
+
+  def test_vmap_with_vjp3(self):
+    # https://github.com/jax-ml/jax/issues/32479
+    def grad_via_ref(f):
+      def wrapper(*args):
+        grad_accum = jax.tree.map(lambda x: jax.new_ref(jnp.zeros_like(x)), args)
+        out, f_vjp = vjp3(f, *args)
+        f_vjp.with_refs(*grad_accum)(jnp.ones_like(out))
+        return jax.tree.map(lambda x: jax.freeze(x), grad_accum)
+      return wrapper
+
+    def issue_vmap1():
+      def f(x):
+        return x + 1
+      x = jnp.ones((4,))
+      # g = grad_via_ref(jax.vmap(f))  # good
+      g = jax.vmap(grad_via_ref(f))  # bad
+      g(x)  # crash
+
+    def issue_vmap1_minimized():
+      def f(x):
+        x.addupdate(1.0)  # bad (assumes non-empty list of indexers)
+      jax.vmap(f)(jax.new_ref(jnp.zeros((4,))))  # crash
+
+    def issue_vmap2():
+      def f(x):
+        x[...] = 1.0  # bad (mismatched shapes)
+      jax.vmap(f)(jax.new_ref(jnp.zeros((4,))))  # crash
+
+    # don't crash
+    issue_vmap1()
+    issue_vmap1_minimized()
+    issue_vmap2()
 
 
 @jtu.with_config(jax_mutable_array_checks=True)
 class MutableArrayErrorsTest(jtu.JaxTestCase):
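
A note on the registration change above: moving swap_p and addupdate_p from
batching.primitive_batchers to batching.fancy_primitive_batchers means the
rules receive axis_data (the mapped axis size and the explicit mesh axis)
directly, instead of inferring the axis size from whichever operands happen to
be batched, which is what the deleted axis_size set-comprehensions did. Below
is a minimal sketch of the fancy-batcher shape, assuming the internal
jax._src.interpreters.batching module as used in the patch; my_prim_p is a
hypothetical primitive for illustration only, not code from this patch.

from jax._src import core
from jax._src.interpreters import batching

my_prim_p = core.Primitive('my_prim')  # hypothetical primitive, illustration only

def _my_prim_batched(axis_data, batched_args, batched_dims):
  # "Fancy" rules receive axis_data up front, so an unbatched operand can be
  # broadcast into the batch without recovering the axis size from its peers.
  x, = batched_args
  d, = batched_dims
  if d is batching.not_mapped:
    x = batching.broadcast(x, axis_data.size, 0, axis_data.explicit_mesh_axis)
    d = 0
  return my_prim_p.bind(x), d

batching.fancy_primitive_batchers[my_prim_p] = _my_prim_batched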
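
And a condensed version of the two minimized crashes the new test guards
against, using the same jax.new_ref / ref-mutation API exercised in the test
above. This is a usage sketch, not part of the patch, and should only run
cleanly with the fixes applied.

import jax
import jax.numpy as jnp

def write_one(ref):
  ref[...] = 1.0  # full-ref store of an unbatched value (the issue_vmap2 case)

def accumulate(ref):
  ref.addupdate(1.0)  # no indexers (the issue_vmap1_minimized case)

refs = jax.new_ref(jnp.zeros((4,)))
jax.vmap(write_one)(refs)   # used to crash with mismatched shapes
jax.vmap(accumulate)(refs)  # used to assume a non-empty indexer list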