48 changes: 34 additions & 14 deletions jax/_src/state/primitives.py
@@ -203,7 +203,7 @@ def ref_swap(
     value: Array,
     _function_name: str = "ref_swap",
 ) -> Array:
-  """Set an array value inplace while returning the existing value.
+  """Update an array value inplace while returning the previous value.

   This is equivalent to ``ref[idx], prev = value, ref[idx]`` while returning
   ``prev``, for a NumPy-style indexer ``idx``.
@@ -246,7 +246,6 @@ def ref_swap(

   .. _Ref guide: https://docs.jax.dev/en/latest/array_refs.html
   """
-  "Sets a ref's value as `ref[idx], prev = value, ref[idx]` and returns `prev`."
   if hasattr(ref, 'dtype'):
     value = _maybe_implicit_cast(ref.dtype, value)
   ref, transforms = get_ref_and_transforms(ref, idx, _function_name)
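
A minimal sketch of the swap semantics the revised docstring describes, written against the `jax.new_ref` API that this PR's tests use (the exact public alias for `ref_swap` may differ):

```python
import jax
import jax.numpy as jnp

ref = jax.new_ref(jnp.arange(4.0))
prev = ref[1]   # read the previous value at `idx`
ref[1] = 42.0   # write the new value in place
# ref_swap fuses these two steps, i.e. `ref[idx], prev = value, ref[idx]`,
# and returns `prev`.
```
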
@@ -673,13 +672,20 @@ def _array_ref_batched(axis_data, vals_in, dims_in, memory_space):
   val, = vals_in
   dim, = dims_in
   if dim is None:
+    # We defensively batch the ref, b/c it could later be hit with a batched val
     val2 = batching.broadcast(val, axis_data.size, 0,
                               axis_data.explicit_mesh_axis)
     return core.ref_p.bind(val2, memory_space=memory_space), 0
   else:
     return core.ref_p.bind(val, memory_space=memory_space), dim
 batching.fancy_primitive_batchers[core.ref_p] = _array_ref_batched

+def _freeze_batched(axis_data, vals_in, dims_in):
+  ref, = vals_in
+  dim, = dims_in
+  return core.freeze_p.bind(ref), dim
+batching.fancy_primitive_batchers[core.freeze_p] = _freeze_batched
+
 def _state_partial_eval_custom(saveable, unks_in, inst_in, eqn):
   del saveable  # ignored, always full remat state ops on known inputs
   ref_unk, *_ = unks_in
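
Taken together, the two rules above mean a ref can be created, mutated, and frozen entirely inside a vmapped function: `ref_p` defensively batches an unbatched initial value, and `freeze_p` now has a batching rule at all, one that simply forwards the batch dimension. A hedged sketch, assuming the `jax.new_ref`/`jax.freeze` API exercised by this PR's tests:

```python
import jax
import jax.numpy as jnp

def f(x):
  r = jax.new_ref(x)      # ref_p under vmap; batched defensively
  r[...] = r[...] + 1.0
  return jax.freeze(r)    # freeze_p forwards the batch dim unchanged

print(jax.vmap(f)(jnp.zeros(4)))  # expected: [1. 1. 1. 1.]
```
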
@@ -816,6 +822,7 @@ def _get_vmap(batched_args, batched_dims, *, tree):
                        for i_dim in flat_idx_dims)
   if len(indexers) > 1:
     raise NotImplementedError("Batching with multiple indexers not supported.")
+
   # TODO(sharadmv): handle vmap of multiple indexers
   new_indexers = tuple(_batch_indexer(indexer, dims, axis_size,
                                       ref.shape, ref_dim, idx_is_batched)
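
`_get_vmap` is only touched cosmetically in this PR, but for context, here is an untested sketch of the single-indexer batching it implements: each vmap lane reads its own row of a batched ref at its own index.

```python
import jax
import jax.numpy as jnp

def f(ref, i):
  return ref[i]           # get_p with a single, batched indexer

r = jax.new_ref(jnp.arange(24.0).reshape(3, 8))
print(jax.vmap(f)(r, jnp.array([0, 3, 5])))  # expected: [0. 11. 21.]
```
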
@@ -858,9 +865,7 @@ def _get_vmap(batched_args, batched_dims, *, tree):
   return out, out_bdim
 batching.primitive_batchers[get_p] = _get_vmap

-def _swap_vmap(batched_args, batched_dims, *, tree):
-  axis_size, = {x.shape[d] for x, d in zip(batched_args, batched_dims)
-                if d is not batching.not_mapped}
+def _swap_vmap(axis_data, batched_args, batched_dims, *, tree):
   ref, val, *flat_idxs = batched_args
   ref_dim, val_dim, *flat_idx_dims = batched_dims
   indexers = tree_util.tree_unflatten(tree, flat_idxs)
@@ -877,11 +882,14 @@ def _swap_vmap(batched_args, batched_dims, *, tree):
                      "Move the array reference to be an argument to the vmapped "
                      "function?")
   if not indexers:
+    if ref_is_batched and not val_is_batched:
+      val = batching.broadcast(val, axis_data.size, ref_dim,
+                               axis_data.explicit_mesh_axis)
     return swap_p.bind(ref, val, *flat_idxs, tree=tree), ref_dim
   if len(indexers) > 1:
     raise NotImplementedError("Batching with multiple indexers not supported.")
   # TODO(sharadmv): handle vmap of multiple indexers
-  new_indexers = tuple(_batch_indexer(indexer, dims, axis_size,
+  new_indexers = tuple(_batch_indexer(indexer, dims, axis_data.size,
                                       ref.shape, ref_dim, idx_is_batched)
                        for indexer, dims in zip(indexers, indexers_dims))
   flat_indexers, tree = tree_util.tree_flatten(new_indexers)
@@ -905,7 +913,8 @@ def _swap_vmap(batched_args, batched_dims, *, tree):

   if not val_is_batched:
     if ref_is_batched or idx_is_batched:
-      val = batching.broadcast(val, axis_size, batched_dim_in_result, None)
+      val = batching.broadcast(val, axis_data.size, batched_dim_in_result,
+                               axis_data.explicit_mesh_axis)
     else:
       val = batching.moveaxis(val, val_dim, batched_dim_in_result)

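The `_swap_vmap` changes above make an index-free `ref[...] = value` write broadcast an unbatched `value` to the ref's batch size instead of binding mismatched shapes, with the axis size and mesh axis now taken from `axis_data` rather than inferred from the batched arguments. A minimal repro, adapted from `issue_vmap2` in the test added below:

```python
import jax
import jax.numpy as jnp

def f(ref):
  ref[...] = 1.0          # swap_p with no indexers and an unbatched value

r = jax.new_ref(jnp.zeros(4))
jax.vmap(f)(r)            # previously crashed with mismatched shapes
print(r[...])             # expected: [1. 1. 1. 1.]
```
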
@@ -937,11 +946,9 @@ def _swap_vmap(batched_args, batched_dims, *, tree):
     out = out.transpose(transpose_order_inversed)

   return out, batched_dim_in_result
-batching.primitive_batchers[swap_p] = _swap_vmap
+batching.fancy_primitive_batchers[swap_p] = _swap_vmap

-def _addupdate_vmap(batched_args, batched_dims, *, tree):
-  axis_size, = {x.shape[d] for x, d in zip(batched_args, batched_dims)
-                if d is not batching.not_mapped}
+def _addupdate_vmap(axis_data, batched_args, batched_dims, *, tree):
   ref, val, *flat_idxs = batched_args
   ref_dim, val_dim, *flat_idx_dims = batched_dims
   indexers = tree_util.tree_unflatten(tree, flat_idxs)
@@ -951,10 +958,22 @@ def _addupdate_vmap(batched_args, batched_dims, *, tree):
   val_is_batched = val_dim is not batching.not_mapped
   idx_is_batched = any(i_dim is not batching.not_mapped
                        for i_dim in flat_idx_dims)
+
+  if not ref_is_batched:
+    raise Exception("performing an addupdate operation with vmapped value on "
+                    f"an unbatched array reference of type {core.typeof(ref)}. "
+                    "Move the array reference to be an argument to the vmapped "
+                    "function?")
+  if not indexers:
+    if ref_is_batched and not val_is_batched:
+      val = batching.broadcast(val, axis_data.size, ref_dim,
+                               axis_data.explicit_mesh_axis)
+    return addupdate_p.bind(ref, val, *flat_idxs, tree=tree), []
   if len(indexers) > 1:
     raise NotImplementedError("Batching with multiple indexers not supported.")
+
   # TODO(sharadmv): handle vmap of multiple indexers
-  new_indexers = tuple(_batch_indexer(indexer, dims, axis_size,
+  new_indexers = tuple(_batch_indexer(indexer, dims, axis_data.size,
                                       ref.shape, ref_dim, idx_is_batched)
                        for indexer, dims in zip(indexers, indexers_dims))
   flat_indexers, tree = tree_util.tree_flatten(new_indexers)
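
`_addupdate_vmap` gets the same treatment, plus two guards the old rule lacked: it raises on an unbatched ref (mirroring `_swap_vmap`), and it no longer falls through to the indexer-batching path when `indexers` is empty. A minimal repro, adapted from `issue_vmap1_minimized` in the test added below:

```python
import jax
import jax.numpy as jnp

def f(ref):
  ref.addupdate(1.0)      # addupdate_p with an empty list of indexers

r = jax.new_ref(jnp.zeros(4))
jax.vmap(f)(r)            # previously crashed; the addend is now broadcast
print(r[...])             # expected: [1. 1. 1. 1.]
```
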
@@ -978,7 +997,8 @@ def _addupdate_vmap(batched_args, batched_dims, *, tree):

   if not val_is_batched:
     if ref_is_batched or idx_is_batched:
-      val = batching.broadcast(val, axis_size, batched_dim_in_result, None)
+      val = batching.broadcast(val, axis_data.size, batched_dim_in_result,
+                               axis_data.explicit_mesh_axis)
     else:
       val = batching.moveaxis(val, val_dim, batched_dim_in_result)

@@ -999,7 +1019,7 @@ def _addupdate_vmap(batched_args, batched_dims, *, tree):
     val = val.transpose(transpose_order)

   return addupdate_p.bind(ref, val, *flat_indexers, tree=tree), []
-batching.primitive_batchers[addupdate_p] = _addupdate_vmap
+batching.fancy_primitive_batchers[addupdate_p] = _addupdate_vmap

 # Currently, JAX doesn't have a primitive that does an equal-rank broadcast.
 # We could use `jnp.broadcast_to` but that lowers to squeezing,
33 changes: 33 additions & 0 deletions tests/mutable_array_test.py
@@ -917,6 +917,39 @@ def body(_, xy):
     _, = f_vjp.with_refs(grad_accum)(1.)
     self.assertAllClose(grad_accum[...], jnp.arange(5.))

+  def test_vmap_with_vjp3(self):
+    # https://github.com/jax-ml/jax/issues/32479
+    def grad_via_ref(f):
+      def wrapper(*args):
+        grad_accum = jax.tree.map(lambda x: jax.new_ref(jnp.zeros_like(x)), args)
+        out, f_vjp = vjp3(f, *args)
+        f_vjp.with_refs(*grad_accum)(jnp.ones_like(out))
+        return jax.tree.map(lambda x: jax.freeze(x), grad_accum)
+      return wrapper
+
+    def issue_vmap1():
+      def f(x):
+        return x + 1
+      x = jnp.ones((4,))
+      # g = grad_via_ref(jax.vmap(f))  # good
+      g = jax.vmap(grad_via_ref(f))  # bad
+      g(x)  # crash
+
+    def issue_vmap1_minimized():
+      def f(x):
+        x.addupdate(1.0)  # bad (assumes non-empty list of indexers)
+      jax.vmap(f)(jax.new_ref(jnp.zeros((4,))))  # crash
+
+    def issue_vmap2():
+      def f(x):
+        x[...] = 1.0  # bad (mismatched shapes)
+      jax.vmap(f)(jax.new_ref(jnp.zeros((4,))))  # crash
+
+    # don't crash
+    issue_vmap1()
+    issue_vmap1_minimized()
+    issue_vmap2()
+

 @jtu.with_config(jax_mutable_array_checks=True)
 class MutableArrayErrorsTest(jtu.JaxTestCase):
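
To run just this regression test from a source checkout, something like `pytest tests/mutable_array_test.py -k test_vmap_with_vjp3` should work.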