Replies: 3 comments 5 replies
-
Do you know the current best practices for implementing multi-host data parallelism with JAX?
-
Just change the code from:

```python
arr = jax.device_put(arr, NamedSharding(mesh, P(*partition_spec)))
```

to:

```python
arr = jax.make_array_from_callback(
    arr.shape,
    NamedSharding(mesh, P(*partition_spec)),
    lambda idx: arr[idx],
)
```

See #20041.
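For context, here is a minimal sketch of that pattern in a multi-host, purely data-parallel setting. It assumes `jax.distributed.initialize()` has already been called on every process; the mesh axis name and array shapes are illustrative rather than taken from this thread:

```python
# Hedged sketch: build a globally sharded jax.Array from a host-local copy of
# the data. Assumes jax.distributed.initialize() was called on every process.
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# One mesh axis spanning every device on every host (e.g. 2 hosts x 8 devices = 16).
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
sharding = NamedSharding(mesh, P("data"))

arr = np.arange(32 * 128, dtype=np.float32).reshape(32, 128)  # same array on every host

# Each host only materializes the shards belonging to its own addressable devices;
# `idx` is a tuple of slices into the global shape.
global_arr = jax.make_array_from_callback(arr.shape, sharding, lambda idx: arr[idx])
```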
-
If your input pipeline is fully data parallel, take a look at the docstring of: https://jax.readthedocs.io/en/latest/_autosummary/jax.make_array_from_single_device_arrays.html I am also planning to expose a helper function to do just this! Are you also asking how to shard the weights?
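A hedged sketch of the pattern that docstring describes, where each host loads only the data for its own local devices; the per-device batch size, feature dimension, and stand-in loader below are made up for illustration:

```python
# Hedged sketch of a fully data-parallel input pipeline using
# jax.make_array_from_single_device_arrays; shapes and the fake loader are illustrative.
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
sharding = NamedSharding(mesh, P("data"))

per_device_batch, feature_dim = 8, 128
global_batch = per_device_batch * jax.device_count()

# Each host produces one batch per local device (stand-in for a real data loader).
local_batches = [
    np.ones((per_device_batch, feature_dim), dtype=np.float32)
    for _ in jax.local_devices()
]
# Put each per-device batch on its local device...
local_arrays = [
    jax.device_put(b, d) for b, d in zip(local_batches, jax.local_devices())
]
# ...and stitch the addressable shards into one global jax.Array.
batch = jax.make_array_from_single_device_arrays(
    (global_batch, feature_dim), sharding, local_arrays
)
```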
-
I'm using TPUs and want to train across multiple nodes/hosts. The official docs suggest using xmap/pmap. However, I'm using the sharding API to shard across multiple local devices.
So how can we extend the sharding to accommodate a multi-node setup?
AIUI, we should be able to provide a sort of 3D sharding like (2, 8, 1) for 2 TPU hosts with 8 local devices each, DDP-style (a sketch of such a mesh follows below).
This would allow us to switch between n-way data parallelism and m-way model parallelism as outlined here.
But this doesn't seem to be the case?
Related: SO bounty thread I started here.
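A hedged sketch of how such a mesh could be expressed with the current sharding API, assuming `jax.distributed.initialize()` is called on every host; the 2x8 mesh shape, axis names, and array shapes are illustrative, not from the original post:

```python
# Hedged sketch: a 2D device mesh over 2 hosts x 8 local devices (16 devices total),
# giving 2-way data parallelism and 8-way model parallelism. Shapes are illustrative.
import jax
import numpy as np
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

jax.distributed.initialize()  # run on every host; arguments are auto-detected on Cloud TPU

devices = mesh_utils.create_device_mesh((2, 8))          # (data=2, model=8)
mesh = Mesh(devices, axis_names=("data", "model"))

batch_sharding = NamedSharding(mesh, P("data", None))    # shard the batch over 'data'
weight_sharding = NamedSharding(mesh, P(None, "model"))  # shard the weights over 'model'

batch = np.ones((32, 128), dtype=np.float32)    # same logical data on every host
weights = np.ones((128, 512), dtype=np.float32)

# On multi-host, assemble global arrays as in the replies above, not via plain device_put.
x = jax.make_array_from_callback(batch.shape, batch_sharding, lambda i: batch[i])
w = jax.make_array_from_callback(weights.shape, weight_sharding, lambda i: weights[i])

@jax.jit
def forward(x, w):
    return x @ w  # jit plus the input shardings drive the data/model-parallel layout

out = forward(x, w)
```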