How to do distributed training? #9375
Unanswered
ayaka14732 asked this question in Q&A

I have been searching hard to find a tutorial on distributed training in JAX (e.g. with 100 v2-8 Cloud TPUs). It seems that Ray can achieve this goal (mesh-transformer-jax, swarm-jax), but I don't quite understand how to make it work.
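For context, the multi-host pieces of JAX look roughly like the sketch below, assuming a reasonably recent JAX release. This is only an illustration, not a verified recipe for 100 separate v2-8 hosts: the coordinator address, process count, and process id are placeholders that whatever launcher you use (Ray, a script over gcloud ssh, ...) would have to fill in per process.

```python
import jax

# One Python process per TPU VM. jax.distributed.initialize() connects the
# processes; on a single TPU pod slice it can often be called with no
# arguments, but independent VMs need explicit coordinator details.
jax.distributed.initialize(
    coordinator_address="10.0.0.1:8476",  # placeholder: address of process 0
    num_processes=100,                    # placeholder: one process per v2-8
    process_id=0,                         # placeholder: unique id per process
)

print(jax.process_index(), jax.process_count())      # which host this is
print(jax.local_device_count(), jax.device_count())  # e.g. 8 local cores, 800 global
```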
Replies: 2 comments
-
Check out https://github.com/sholtodouglas/scalingExperiments for data and tensor parallelism. Pipeline parallelism through k8s/Ray is coming soon-ish, according to the repo owner.
-
Check out Distributed training with JAX & Flax: https://www.machinelearningnuggets.com/distributed-training-with-jax-and-flax/
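Both links above center on data parallelism. As a rough illustration of the pattern they describe (a minimal, hand-written sketch in plain JAX, not code taken from either resource; the toy linear model and hand-rolled SGD are placeholders), a pmapped train step looks like this:

```python
from functools import partial
import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    # Toy linear model standing in for a real network.
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

@partial(jax.pmap, axis_name="devices")
def train_step(params, x, y):
    lr = 1e-3
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # Average gradients across all devices on all hosts.
    grads = jax.lax.pmean(grads, axis_name="devices")
    params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return params, loss

n = jax.local_device_count()
params = {"w": jnp.zeros((8, 1)), "b": jnp.zeros((1,))}
# Replicate parameters once per local device; each process feeds only its
# local shard of the global batch, with a leading [devices] axis.
params = jax.tree_util.tree_map(lambda p: jnp.stack([p] * n), params)
x = jnp.ones((n, 32, 8))
y = jnp.ones((n, 32, 1))
params, loss = train_step(params, x, y)
```

On a multi-host setup the pmean spans every participating device, so each process only sees its own slice of the batch but ends up with identical parameters.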
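The first reply also mentions tensor parallelism. A hedged sketch of what that looks like with the jax.sharding API in recent JAX versions (the shapes and the column-wise partitioning below are arbitrary choices for illustration):

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Arrange the local devices along a 1-D "model" axis.
mesh = Mesh(np.array(jax.devices()), axis_names=("model",))

w = jnp.ones((512, 512))
x = jnp.ones((8, 512))

# Shard the weight matrix column-wise across the mesh; keep activations replicated.
w = jax.device_put(w, NamedSharding(mesh, P(None, "model")))
x = jax.device_put(x, NamedSharding(mesh, P()))

@jax.jit
def forward(x, w):
    # XLA's SPMD partitioner inserts the collectives the sharded matmul needs.
    return x @ w

y = forward(x, w)
print(y.shape, y.sharding)
```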