This is not only a question, but also an opportunity to discuss ways to do SPMD efficiently in jax (especially for TPU backends).
My personal motivation is to understand why `pmap` with 8 replicas on 8 devices is consistently slower than `jit` with 1 replica on 1 device, even in the best case of no inter-replica communication (e.g. no `pmean`, `psum`, etc.).

So, imagine a simple training loop similar to the MNIST example, where `params` and `batch` have the leading dimension equal to the number of devices:
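Something along these lines (a minimal sketch rather than the exact code from the example; the model, loss, shapes, and learning rate are placeholders):

```python
import numpy as np
import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()

# Replicated params and a per-device batch: leading dim == number of devices.
params = np.zeros((n_dev, 784, 10), dtype=np.float32)
images = np.random.rand(n_dev, 128, 784).astype(np.float32)
labels = np.random.randint(0, 10, size=(n_dev, 128))

@jax.pmap
def train_step(w, x, y):
    def loss_fn(w):
        logits = x @ w
        onehot = jax.nn.one_hot(y, 10)
        return -jnp.mean(jnp.sum(onehot * jax.nn.log_softmax(logits), axis=-1))
    grads = jax.grad(loss_fn)(w)
    return w - 0.1 * grads  # no psum/pmean: the replicas are fully independent

for _ in range(100):
    # images/labels are plain np.arrays, so every step pays a host->device transfer.
    params = train_step(params, images, labels)
```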
Now, pushing the `batch` onto the devices with something like the function below (from the Flax library) recovers the single-device performance:
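I don't remember the exact helper, but it is along these lines (a sketch in the spirit of the Flax utilities; `push_to_devices` is a placeholder name, and I'm using `jax.device_put_sharded` to place each per-device slice on its device up front):

```python
import jax

def push_to_devices(tree):
    # Split the leading (device) axis of every array in the pytree and place
    # each slice on its device, so pmap receives ShardedDeviceArrays instead
    # of plain np.arrays.
    devices = jax.local_devices()
    return jax.tree_map(
        lambda x: jax.device_put_sharded(list(x), devices), tree)

# Done once, outside the loop, so the timed steps pay no host->device transfer.
images, labels = push_to_devices((images, labels))
```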
Now, of course this is wrong, but it makes it clear that the process is I/O bound and the bottleneck is indeed going from `np.array` to `jax.ShardedDeviceArray`.
So, the first question is: why, and is this expected? I'm not the first to experience this behavior (see e.g. #6631, #2459, #6626, #8281), which might indicate either that we are doing something wrong or that there is a problem with sharded arrays.
Where are the sharded arrays stored?
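To make that question concrete, this is how I have been inspecting placement (a sketch, assuming the `device_buffers` attribute that `ShardedDeviceArray` exposes; `out` is the result of the pmapped step sketched above):

```python
out = train_step(params, images, labels)
print(type(out))                                     # a ShardedDeviceArray
print([buf.device() for buf in out.device_buffers])  # one buffer per local device
```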
I can follow up with an issue, but I think this is the best place to discuss it.
More abstractly, what I would like to do is to put each piece of the data only on the device that will eventually use it. Apparently this is what is implemented in the class `GlobalDeviceArray`: citing from the doc, "A GlobalDeviceArray (GDA) can be thought of as a view into a single logical array sharded across processes [...]. Each process can only directly access the shards of the global array data stored on its local devices".

This requires working with meshes, partitions, `pjit`... and honestly I'm now completely lost.
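For concreteness, this is roughly the shape of the `pjit` approach as I understand it from the docs (a minimal sketch; `Mesh`, `PartitionSpec`, and the `in_axis_resources`/`out_axis_resources` arguments live under `jax.experimental` and may move or change between versions):

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental.maps import Mesh
from jax.experimental.pjit import pjit, PartitionSpec

def train_step_single(w, x, y):
    # Same toy step as above, written for a single logical replica:
    # no pmap, and the inputs keep their *global* shapes (no leading device axis).
    def loss_fn(w):
        logits = x @ w
        onehot = jax.nn.one_hot(y, 10)
        return -jnp.mean(jnp.sum(onehot * jax.nn.log_softmax(logits), axis=-1))
    return w - 0.1 * jax.grad(loss_fn)(w)

# One-dimensional mesh over all local devices; the axis is named 'data'.
mesh = Mesh(np.asarray(jax.local_devices()), ('data',))

# Replicate the params (None), shard the batch along its leading axis ('data').
p_train_step = pjit(
    train_step_single,
    in_axis_resources=(None, PartitionSpec('data'), PartitionSpec('data')),
    out_axis_resources=None)

params = np.zeros((784, 10), dtype=np.float32)
images = np.random.rand(1024, 784).astype(np.float32)
labels = np.random.randint(0, 10, size=(1024,))

with mesh:
    params = p_train_step(params, images, labels)
```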
Can you guys shed some light on the state of SPMD in jax and how to achieve it efficiently?