File tree Expand file tree Collapse file tree 1 file changed +15
-6
lines changed Expand file tree Collapse file tree 1 file changed +15
-6
lines changed Original file line number Diff line number Diff line change 6
6
7
7
from axlearn .common import launch , launch_trainer , measurement
8
8
from axlearn .common .config import config_for_function
9
+ from pathwaysutils .elastic import manager
9
10
10
11
11
12
def main(_):
    """Launch the trainer, restarting it after recoverable slice-down errors.

    Each attempt initializes measurement from the absl flags, runs launch setup,
    builds the trainer config (wiring in the global measurement recorder), starts
    monitoring, and runs the trainer. If the attempt raises a
    ``jax.errors.JaxRuntimeError`` that the elastic manager attributes to a slice
    going down, the function waits for slices to come back and then repeats the
    whole sequence. Any other error propagates to the caller.

    Args:
        _: Unused positional argument (required by the entry-point signature).

    Raises:
        jax.errors.JaxRuntimeError: If the runtime error is not classified by the
            elastic manager as a slice-down event.
    """
    # Function-scope imports keep this block self-contained: no module-level
    # `import jax` is visible in this file, yet the handler below needs it.
    import logging

    import jax

    # The original code passed an undefined name `timeout` to `wait_for_slices`,
    # which would raise NameError exactly when recovery was attempted.
    # NOTE(review): `None` is assumed to mean "wait without an upper bound" --
    # confirm against the installed pathwaysutils version, or replace with a
    # finite number of seconds if the job should give up eventually.
    slice_wait_timeout = None

    elastic_manager = manager.Manager()
    while True:
        try:
            measurement.initialize(flags.FLAGS)
            launch.setup()
            trainer_config = launch_trainer.get_trainer_config()
            trainer_config.set(
                recorder=config_for_function(lambda: measurement.global_recorder)
            )
            measurement.start_monitoring()
            launch_trainer.run_trainer(trainer_config)
        except jax.errors.JaxRuntimeError as error:
            # NOTE(review): this relies on a private Manager method; switch to a
            # public equivalent if the library exposes one.
            if not elastic_manager._is_error_due_to_slice_down(error):  # pylint: disable=protected-access
                raise
            # Surface the retry so an otherwise silent loop can be diagnosed.
            logging.warning(
                "Trainer interrupted by a slice-down error; waiting for slices: %s",
                error,
            )
            # Pass by keyword so the value cannot bind to a different positional
            # parameter of `wait_for_slices`.
            elastic_manager.wait_for_slices(timeout=slice_wait_timeout)
        else:
            # The trainer finished without a runtime error: stop retrying.
            return
18
27
19
28
20
29
if __name__ == "__main__" :
You can’t perform that action at this time.
0 commit comments