In my HPC setup, I send a SIGUSR1 signal to the Python process running my JAX code 60 seconds before the job gets preempted. This lets me save a checkpoint right before the job is cancelled.
For a single process, multi-GPU setup I have something like this:
```python
import signal
import sys

import jax


def sigterm_handler(signum, frame):
    print("Python: Received SIGUSR1. Exiting gracefully...")
    # Save checkpoint
    sys.exit(0)


signal.signal(signal.SIGUSR1, sigterm_handler)

# my code that trains a neural network
```
Which works like a charm.
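A variant of this handler that may be easier to extend to several processes is to have the handler only record the request and let the training loop write the checkpoint at the next step boundary. Calling `sys.exit(0)` inside the handler raises `SystemExit` wherever the main thread happens to be, which could be in the middle of a step. The sketch below assumes that; `train`, `train_step` and `save_checkpoint` are hypothetical names standing in for the real training code.

```python
import signal
import threading

# Set by the signal handler, polled by the training loop between steps.
stop_requested = threading.Event()


def handle_sigusr1(signum, frame):
    # Only record the request; the checkpoint is written at the next step boundary.
    print("Python: Received SIGUSR1. Will checkpoint after the current step.")
    stop_requested.set()


signal.signal(signal.SIGUSR1, handle_sigusr1)


def train(num_steps, train_step, save_checkpoint):
    """Run training; return the step at which it stopped (num_steps if it finished).

    `train_step` and `save_checkpoint` are hypothetical callables standing in
    for the real training and checkpointing code.
    """
    for step in range(num_steps):
        train_step(step)
        if stop_requested.is_set():
            save_checkpoint(step)
            return step
    return num_steps
```

The caller can then call `sys.exit(0)` once `train` returns before `num_steps`.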
However, I'm having some trouble figuring out how to do this when I'm dealing with a multi-process setup.
I somehow need to send a signal to the gRPC service launched by `jax.distributed.initialize()` to sync the processes and exit, which would then let me save a checkpoint.
What is the recommended way of interrupting a multi-process job in JAX? Is there a nice way to sync all processes and exit cleanly? I was looking around `jax.experimental.multihost_utils`, but I'm not sure which tool I'm looking for here.
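One possible approach, sketched under assumptions: each process keeps a local flag set by its own SIGUSR1 handler. Depending on how the scheduler delivers the signal, only some processes may receive it. Every few steps, all processes then exchange their flags with `multihost_utils.process_allgather`, so they all agree to stop at the same step, save the checkpoint together, and meet at a `multihost_utils.sync_global_devices` barrier before exiting. This assumes `jax.distributed.initialize()` has already been called in every process (not shown, since it needs a coordinator). `train`, `train_step`, `save_checkpoint` and `check_every` are hypothetical names. Newer JAX releases also appear to provide `multihost_utils.reached_preemption_sync_point(step_id)`, which builds on the distributed runtime's own preemption handling (by default triggered by SIGTERM rather than SIGUSR1). Check the docs for your JAX version to see whether it fits this setup.

```python
import signal
import threading

import numpy as np
from jax.experimental import multihost_utils

# Per-process flag; set only in processes that actually received SIGUSR1.
stop_requested = threading.Event()


def handle_sigusr1(signum, frame):
    stop_requested.set()


signal.signal(signal.SIGUSR1, handle_sigusr1)


def any_process_requested_stop() -> bool:
    """Collective: all processes must call this at the same training step."""
    local_flag = np.array(int(stop_requested.is_set()), dtype=np.int32)
    # One entry per process in a multi-process run; with a single process
    # this is just the local flag.
    all_flags = multihost_utils.process_allgather(local_flag)
    return bool(np.any(all_flags))


def train(num_steps, train_step, save_checkpoint, check_every=10):
    """Return the step at which training stopped (num_steps if it finished).

    `train_step`, `save_checkpoint` and `check_every` are hypothetical names.
    """
    for step in range(num_steps):
        train_step(step)
        # `step % check_every` is identical on every process, so all processes
        # enter the collective together and reach the same decision.
        if step % check_every == 0 and any_process_requested_stop():
            save_checkpoint(step)
            # Barrier so no process exits while others are still checkpointing.
            multihost_utils.sync_global_devices("preemption_checkpoint_done")
            return step
    return num_steps
```

Each check is a cross-process collective, so running it every `check_every` steps trades a little latency for less overhead. With a 60-second warning, `check_every` times the step time plus the checkpoint time should stay under 60 seconds. If a process skips a check that the others run, the allgather will hang.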