device_put + jit vs pmap for data parallel training of neural networks #16282
Unanswered
YunfanZhang42 asked this question in Q&A
Replies: 1 comment
After some digging on previous issues, it seems that the difference is mostly in efficiency.
Hi,

I have a question regarding best practices for data parallel training of complicated neural networks. I am using JAX + Flax to train a ViT-based model in a data parallel/SPMD manner on TPUs. After reading the documentation, I see two ways of performing SPMD for neural network training:

1. Use `flax.jax_utils.replicate` to replicate the model state, use `pmean` in the `train_step` function to accumulate the gradients and batch statistics, and finally use `pmap` to parallelize the `train_step` function. This is what google-research/vision_transformer and google-research/scenic do, so I also implemented my code in this way (see the first sketch below).
2. Use `device_put` to shard the inputs and replicate the model parameters and optimizer states, write the `train_step` function as usual, and then `jit` the `train_step` (see the second sketch below). This method seems more elegant and does not require accumulating gradients and batch statistics manually, but I am not sure it is feasible/recommended for a complex neural network.

I am wondering what the preferred way to perform SPMD training of neural nets will be moving forward. Also, I am wondering if more documentation on this would make sense. Thanks for your help!
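For concreteness, here is a minimal sketch of the first approach (`replicate` + `pmean` + `pmap`). The toy MLP, optax Adam optimizer, and batch shapes are stand-ins I made up for illustration; they are not from the referenced repositories:

```python
import jax
import jax.numpy as jnp
import optax
from flax import linen as nn
from flax import jax_utils
from flax.training import train_state


class MLP(nn.Module):  # toy stand-in for the ViT in the question
    @nn.compact
    def __call__(self, x):
        x = nn.relu(nn.Dense(128)(x))
        return nn.Dense(10)(x)


def create_state(rng):
    model = MLP()
    params = model.init(rng, jnp.ones((1, 784)))["params"]
    return train_state.TrainState.create(
        apply_fn=model.apply, params=params, tx=optax.adam(1e-3))


def train_step(state, batch):
    def loss_fn(params):
        logits = state.apply_fn({"params": params}, batch["x"])
        return optax.softmax_cross_entropy_with_integer_labels(
            logits, batch["y"]).mean()

    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    # Manually average gradients (and any metrics/batch statistics) across devices.
    grads = jax.lax.pmean(grads, axis_name="batch")
    loss = jax.lax.pmean(loss, axis_name="batch")
    return state.apply_gradients(grads=grads), loss


p_train_step = jax.pmap(train_step, axis_name="batch")

# Replicate the train state onto every local device; inputs need an explicit
# leading device axis of shape [num_devices, per_device_batch, ...].
state = jax_utils.replicate(create_state(jax.random.PRNGKey(0)))
n_dev = jax.local_device_count()
batch = {"x": jnp.ones((n_dev, 32, 784)),
         "y": jnp.zeros((n_dev, 32), jnp.int32)}
state, loss = p_train_step(state, batch)
```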
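And a minimal sketch of the second approach (`device_put` + `jit`). It reuses the toy `create_state` helper and loss from the previous sketch, and the mesh axis name `"data"` and batch shapes are likewise assumptions for illustration:

```python
import numpy as np
import jax
import jax.numpy as jnp
import optax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
data_sharding = NamedSharding(mesh, P("data"))  # shard the leading batch axis
replicated = NamedSharding(mesh, P())           # fully replicated

# create_state is the toy helper from the previous sketch.
# Replicate params/optimizer state across the mesh.
state = jax.device_put(create_state(jax.random.PRNGKey(0)), replicated)


@jax.jit
def train_step(state, batch):
    # Written exactly as for a single device: no explicit pmean. The compiler
    # inserts the cross-device gradient reduction based on the input shardings.
    def loss_fn(params):
        logits = state.apply_fn({"params": params}, batch["x"])
        return optax.softmax_cross_entropy_with_integer_labels(
            logits, batch["y"]).mean()

    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    return state.apply_gradients(grads=grads), loss


global_batch = 32 * jax.device_count()
batch = {"x": jnp.ones((global_batch, 784)),
         "y": jnp.zeros((global_batch,), jnp.int32)}
batch = jax.device_put(batch, data_sharding)  # shard inputs along the batch axis
state, loss = train_step(state, batch)
```

Here the global batch is placed with `device_put` according to the `NamedSharding`, and `jit` compiles a single program for the whole mesh, so the same `train_step` also runs unchanged on one device.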