The code below implements 1D tensor parallelism across multiple devices on a TPU v3-8.

```python
import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Host array to shard across the device mesh.
arr = jnp.arange(32 * 4).reshape(32, 4)

# Build an (n_devices, 1) mesh and shard `arr` along its first axis.
n_devices = jax.device_count()
mesh_shape = [n_devices, 1]
axis_names = ('a1', 'a2')
partition_spec = ('a1', 'a2')

devices = mesh_utils.create_device_mesh(mesh_shape)
mesh = Mesh(devices, axis_names=axis_names)
arr = jax.device_put(arr, NamedSharding(mesh, P(*partition_spec)))
jax.debug.visualize_array_sharding(arr)
```

However, when attempting to run the same code on a TPU v4-32 across 4 hosts, it didn't work as expected. I encountered the following error:

I wonder if the problem is due to
See the docs here if you want to do fully data parallel input loading and how to create an Array for that: https://jax.readthedocs.io/en/latest/_autosummary/jax.make_array_from_single_device_arrays.html
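For that route, here is a minimal sketch of assembling a global jax.Array from per-host shards with jax.make_array_from_single_device_arrays. It assumes the same (32, 4) array and the ('a1', 'a2') mesh layout from the question:

```python
import jax
import numpy as np
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# The full array exists as a host (NumPy) copy on every process.
global_shape = (32, 4)
global_data = np.arange(32 * 4).reshape(global_shape)

# Same mesh layout as the question: (n_devices, 1), sharded over 'a1'.
devices = mesh_utils.create_device_mesh((jax.device_count(), 1))
mesh = Mesh(devices, axis_names=('a1', 'a2'))
sharding = NamedSharding(mesh, P('a1', 'a2'))

# Each process places only the slices owned by its own local devices.
local_shards = [
    jax.device_put(global_data[index], device)
    for device, index in sharding.addressable_devices_indices_map(global_shape).items()
]

# Stitch the per-device pieces into one global, sharded jax.Array.
arr = jax.make_array_from_single_device_arrays(global_shape, sharding, local_shards)
```

Because each host only touches the slices belonging to its addressable devices, this pattern works on multi-host topologies where a plain device_put of the whole host-local array does not.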
Hi!
So the problem here is that device_put cannot transfer across hosts (we know about this and we are looking into improving the situation here). On a single host it works out, as you know, but it will fail on multiple hosts.
A better thing to use here is jax.make_array_from_callback, because the input on every host is the same, i.e. arr = jnp.arange(32*4).reshape(32, 4). make_array_from_callback will carve out the shards that each device needs on that host. Here is how the code could look:
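A minimal sketch of that, assuming the same (32, 4) input and the ('a1', 'a2') mesh from the question:

```python
import jax
import numpy as np
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# The same input is available as a host (NumPy) copy on every process.
global_shape = (32, 4)
global_data = np.arange(32 * 4).reshape(global_shape)

# Same mesh and partition spec as the question.
n_devices = jax.device_count()
devices = mesh_utils.create_device_mesh((n_devices, 1))
mesh = Mesh(devices, axis_names=('a1', 'a2'))
sharding = NamedSharding(mesh, P('a1', 'a2'))

# The callback receives the index (a tuple of slices) of one shard and
# returns just that slice; it is only invoked for the shards needed by
# this host's local devices.
def cb(index):
    return global_data[index]

arr = jax.make_array_from_callback(global_shape, sharding, cb)

# Each process holds only its addressable shards of the global array.
print(arr.shape, [s.data.shape for s in arr.addressable_shards])
```

The resulting arr is a global jax.Array sharded over every device in the mesh, even though no host ever transfers data to another host's devices.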