Skip to content

Commit 8025b8a

Browse files
committed
Adding basic elastic training
1 parent 8d1ed78 commit 8025b8a

File tree

1 file changed

+15
-6
lines changed

1 file changed

+15
-6
lines changed

axlearn/common/launch_trainer_main.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,24 @@
66

77
from axlearn.common import launch, launch_trainer, measurement
88
from axlearn.common.config import config_for_function
9+
from pathwaysutils.elastic import manager
910

1011

1112
def main(_):
12-
measurement.initialize(flags.FLAGS)
13-
launch.setup()
14-
trainer_config = launch_trainer.get_trainer_config()
15-
trainer_config.set(recorder=config_for_function(lambda: measurement.global_recorder))
16-
measurement.start_monitoring()
17-
launch_trainer.run_trainer(trainer_config)
13+
elastic_manager = manager.Manager()
14+
while True:
15+
try:
16+
measurement.initialize(flags.FLAGS)
17+
launch.setup()
18+
trainer_config = launch_trainer.get_trainer_config()
19+
trainer_config.set(recorder=config_for_function(lambda: measurement.global_recorder))
20+
measurement.start_monitoring()
21+
launch_trainer.run_trainer(trainer_config)
22+
break
23+
except jax.errors.JaxRuntimeError as error:
24+
if not elastic_manager._is_error_due_to_slice_down(error):
25+
raise
26+
elastic_manager.wait_for_slices(timeout)
1827

1928

2029
if __name__ == "__main__":

0 commit comments

Comments
 (0)