Fault-tolerant training with jax.distributed
#16619
Unanswered
superMDguy asked this question in Q&A
The docs say that JAX will "ensure that all processes shut down if any process dies". I'm working on setting up a training run with preemptible TPUs, and it's likely that nodes will go offline during training. It would be great if JAX could pause operations when a node dies, so that I can automatically bring up a replacement node from the latest checkpoint instead of having to restart all of the running instances.
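For context, the restart-everything path described above can be sketched as follows. On every (re)start, a process looks for the newest checkpoint and resumes from it, so a preemption costs at most the steps since the last save. This is a minimal single-process sketch under stated assumptions: `save_checkpoint`, `latest_checkpoint`, `restore_or_init`, and `train` are hypothetical helpers, the training state is a plain dict written with `pickle`, and `fail_at` only simulates a preemption. A real multi-host JAX job would call `jax.distributed.initialize()` in each process first and would more likely use a dedicated checkpointing library such as Orbax than `pickle`.

```python
import os
import pickle
import re
import tempfile

# Hypothetical naming scheme: one file per saved step, e.g. ckpt_300.pkl
CKPT_PATTERN = re.compile(r"^ckpt_(\d+)\.pkl$")


def save_checkpoint(ckpt_dir, step, state):
    """Write a checkpoint to a temp file, then rename it into place,
    so a preemption mid-write never leaves a half-written ckpt_*.pkl."""
    os.makedirs(ckpt_dir, exist_ok=True)
    final_path = os.path.join(ckpt_dir, f"ckpt_{step}.pkl")
    fd, tmp_path = tempfile.mkstemp(dir=ckpt_dir, suffix=".tmp")
    with os.fdopen(fd, "wb") as f:
        pickle.dump({"step": step, "state": state}, f)
    os.replace(tmp_path, final_path)  # atomic rename on the same filesystem
    return final_path


def latest_checkpoint(ckpt_dir):
    """Return the path of the highest-step checkpoint, or None if there is none."""
    if not os.path.isdir(ckpt_dir):
        return None
    steps = [int(m.group(1)) for name in os.listdir(ckpt_dir)
             if (m := CKPT_PATTERN.match(name))]
    if not steps:
        return None
    return os.path.join(ckpt_dir, f"ckpt_{max(steps)}.pkl")


def restore_or_init(ckpt_dir, init_state):
    """Resume from the newest checkpoint if one exists, else start at step 0."""
    path = latest_checkpoint(ckpt_dir)
    if path is None:
        return 0, init_state
    with open(path, "rb") as f:
        data = pickle.load(f)
    return data["step"], data["state"]


def train(ckpt_dir, total_steps, save_every, fail_at=None):
    """Toy loop: the 'update' adds 1.0 per step; fail_at simulates preemption."""
    step, state = restore_or_init(ckpt_dir, {"w": 0.0})
    while step < total_steps:
        if fail_at is not None and step == fail_at:
            raise RuntimeError(f"simulated preemption at step {step}")
        state = {"w": state["w"] + 1.0}  # stand-in for a real update step
        step += 1
        if step % save_every == 0:
            save_checkpoint(ckpt_dir, step, state)
    return step, state
```

Running `train(d, 10, 3, fail_at=5)` dies after saving `ckpt_3.pkl`; calling `train(d, 10, 3)` again resumes at step 3 and finishes. The cost of this approach is exactly what the question points out: every process has to go through this restore path, not just the replacement node.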