Skip to content

Commit d850d97

Browse files
committed
Moved elastic logic from launch_trainer_main to launch_trainer.
Added guards to only use fast-resume if the proxy backend is used.
1 parent d568ad1 commit d850d97

File tree

2 files changed

+22
-17
lines changed

2 files changed

+22
-17
lines changed

axlearn/common/launch_trainer.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,21 @@ def run_trainer(trainer_config: SpmdTrainer.Config) -> Any:
149149

150150
trainer: SpmdTrainer = trainer_config.instantiate(parent=None)
151151
prng_key = jax.random.PRNGKey(seed=FLAGS.trainer_prng_seed)
152-
output = trainer.run(prng_key)
152+
153+
if FLAGS.jax_backend == "proxy":
154+
# pylint: disable-next=import-error,import-outside-toplevel
155+
from pathwaysutils.elastic import manager
156+
elastic_manager = manager.Manager()
157+
while True:
158+
try:
159+
output = trainer.run(prng_key)
160+
except jax.errors.JaxRuntimeError as error:
161+
if not elastic_manager._is_error_due_to_slice_down(error):
162+
raise
163+
ten_minutes = 10 * 60
164+
elastic_manager.wait_for_slices(timeout=ten_minutes)
165+
else:
166+
output = trainer.run(prng_key)
167+
153168
measurement.record_event(measurement.Event.END_JOB)
154169
return output

axlearn/common/launch_trainer_main.py

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,25 +6,15 @@
66

77
from axlearn.common import launch, launch_trainer, measurement
88
from axlearn.common.config import config_for_function
9-
from pathwaysutils.elastic import manager
109

1110

1211
def main(_):
13-
elastic_manager = manager.Manager()
14-
while True:
15-
try:
16-
measurement.initialize(flags.FLAGS)
17-
launch.setup()
18-
trainer_config = launch_trainer.get_trainer_config()
19-
trainer_config.set(recorder=config_for_function(lambda: measurement.global_recorder))
20-
measurement.start_monitoring()
21-
launch_trainer.run_trainer(trainer_config)
22-
break
23-
except jax.errors.JaxRuntimeError as error:
24-
if not elastic_manager._is_error_due_to_slice_down(error):
25-
raise
26-
ten_minutes = 10 * 60
27-
elastic_manager.wait_for_slices(timeout=ten_minutes)
12+
measurement.initialize(flags.FLAGS)
13+
launch.setup()
14+
trainer_config = launch_trainer.get_trainer_config()
15+
trainer_config.set(recorder=config_for_function(lambda: measurement.global_recorder))
16+
measurement.start_monitoring()
17+
launch_trainer.run_trainer(trainer_config)
2818

2919

3020
if __name__ == "__main__":

0 commit comments

Comments
 (0)