diff --git a/flax/jax_utils.py b/flax/jax_utils.py index c62ecec55..1e4c3f1a1 100644 --- a/flax/jax_utils.py +++ b/flax/jax_utils.py @@ -42,7 +42,15 @@ def replicate(tree, devices=None): A new pytree containing the replicated arrays. """ devices = devices or _pmap_device_order() - return jax.device_put_replicated(tree, devices) + mesh = jax.sharding.Mesh(np.array(devices), ('_device_put_sharded',)) + sharding = jax.NamedSharding(mesh, jax.P('_device_put_sharded')) + + def _replicate(x): + if isinstance(x, jax.Array): + return jax.device_put(jnp.stack([x] * len(devices)), sharding) + return jax.device_put(np.stack([x] * len(devices)), sharding) + + return jax.tree_util.tree_map(_replicate, tree) def unreplicate(tree): @@ -153,7 +161,11 @@ def prefetch_to_device(iterator, size, devices=None): devices = _pmap_device_order() if devices is None else devices def _prefetch(xs): - return jax.device_put_sharded(list(xs), devices) + mesh = jax.sharding.Mesh(np.array(devices), ('_device_put_sharded',)) + sharding = jax.NamedSharding(mesh, jax.P('_device_put_sharded')) + if isinstance(xs, jax.Array): + return jax.device_put(jnp.stack(list(xs)), sharding) + return jax.device_put(np.stack(list(xs)), sharding) def enqueue(n): # Enqueues *up to* `n` elements from the iterator. for data in itertools.islice(iterator, n):