From 640ecfe653120e485fec3f971b23a7f0798a6ac5 Mon Sep 17 00:00:00 2001 From: Daniel Suo Date: Tue, 24 Mar 2026 04:13:21 -0700 Subject: [PATCH] [pmap] In-line definitions of `jax.device_put_sharded` and `jax.device_put_replicated`. Both `jax.device_put_sharded` and `jax.device_put_replicated` were deprecated in JAX v0.8.1 in November 2025. We in-line their definitions using public JAX APIs taking the `jax_pmap_shmap_merge=True` branch, which was made the default in JAX v0.8.0 in October 2025. Please see the below for more information: - JAX CHANGELOG: https://docs.jax.dev/en/latest/changelog.html - Migrating from `jax.pmap`: https://docs.jax.dev/en/latest/migrate_pmap.html PiperOrigin-RevId: 888582438 --- flax/jax_utils.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/flax/jax_utils.py b/flax/jax_utils.py index c62ecec55..1e4c3f1a1 100644 --- a/flax/jax_utils.py +++ b/flax/jax_utils.py @@ -42,7 +42,15 @@ def replicate(tree, devices=None): A new pytree containing the replicated arrays. """ devices = devices or _pmap_device_order() - return jax.device_put_replicated(tree, devices) + mesh = jax.sharding.Mesh(np.array(devices), ('_device_put_sharded',)) + sharding = jax.NamedSharding(mesh, jax.P('_device_put_sharded')) + + def _replicate(x): + if isinstance(x, jax.Array): + return jax.device_put(jnp.stack([x] * len(devices)), sharding) + return jax.device_put(np.stack([x] * len(devices)), sharding) + + return jax.tree_util.tree_map(_replicate, tree) def unreplicate(tree): @@ -153,7 +161,11 @@ def prefetch_to_device(iterator, size, devices=None): devices = _pmap_device_order() if devices is None else devices def _prefetch(xs): - return jax.device_put_sharded(list(xs), devices) + mesh = jax.sharding.Mesh(np.array(devices), ('_device_put_sharded',)) + sharding = jax.NamedSharding(mesh, jax.P('_device_put_sharded')) + if isinstance(xs, jax.Array): + return jax.device_put(jnp.stack(list(xs)), sharding) + return jax.device_put(np.stack(list(xs)), sharding) def enqueue(n): # Enqueues *up to* `n` elements from the iterator. for data in itertools.islice(iterator, n):