File tree Expand file tree Collapse file tree 2 files changed +22
-17
lines changed Expand file tree Collapse file tree 2 files changed +22
-17
lines changed Original file line number Diff line number Diff line change @@ -149,6 +149,21 @@ def run_trainer(trainer_config: SpmdTrainer.Config) -> Any:
149
149
150
150
trainer : SpmdTrainer = trainer_config .instantiate (parent = None )
151
151
prng_key = jax .random .PRNGKey (seed = FLAGS .trainer_prng_seed )
152
- output = trainer .run (prng_key )
152
+
153
+ if FLAGS .jax_backend == "proxy" :
154
+ # pylint: disable-next=import-error,import-outside-toplevel
155
+ from pathwaysutils .elastic import manager
156
+ elastic_manager = manager .Manager ()
157
+ while True :
158
+ try :
159
+ output = trainer .run (prng_key )
160
+ except jax .errors .JaxRuntimeError as error :
161
+ if not elastic_manager ._is_error_due_to_slice_down (error ):
162
+ raise
163
+ ten_minutes = 10 * 60
164
+ elastic_manager .wait_for_slices (timeout = ten_minutes )
165
+ else :
166
+ output = trainer .run (prng_key )
167
+
153
168
measurement .record_event (measurement .Event .END_JOB )
154
169
return output
Original file line number Diff line number Diff line change 6
6
7
7
from axlearn .common import launch , launch_trainer , measurement
8
8
from axlearn .common .config import config_for_function
9
- from pathwaysutils .elastic import manager
10
9
11
10
12
11
def main (_ ):
13
- elastic_manager = manager .Manager ()
14
- while True :
15
- try :
16
- measurement .initialize (flags .FLAGS )
17
- launch .setup ()
18
- trainer_config = launch_trainer .get_trainer_config ()
19
- trainer_config .set (recorder = config_for_function (lambda : measurement .global_recorder ))
20
- measurement .start_monitoring ()
21
- launch_trainer .run_trainer (trainer_config )
22
- break
23
- except jax .errors .JaxRuntimeError as error :
24
- if not elastic_manager ._is_error_due_to_slice_down (error ):
25
- raise
26
- ten_minutes = 10 * 60
27
- elastic_manager .wait_for_slices (timeout = ten_minutes )
12
+ measurement .initialize (flags .FLAGS )
13
+ launch .setup ()
14
+ trainer_config = launch_trainer .get_trainer_config ()
15
+ trainer_config .set (recorder = config_for_function (lambda : measurement .global_recorder ))
16
+ measurement .start_monitoring ()
17
+ launch_trainer .run_trainer (trainer_config )
28
18
29
19
30
20
if __name__ == "__main__" :
You can’t perform that action at this time.
0 commit comments