Skip to content

Commit 8da2691

Browse files
committed
Use the public API for is_error_due_to_slice_down from pathwaysutils
1 parent 1af5203 commit 8da2691

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

axlearn/common/launch_trainer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ def run_trainer(trainer_config: SpmdTrainer.Config) -> Any:
147147
f,
148148
)
149149

150-
if FLAGS.jax_backend == "proxy":
150+
if False and FLAGS.jax_backend == "proxy":
151151
# pylint: disable-next=import-error,import-outside-toplevel
152152
from pathwaysutils.elastic import manager
153153
elastic_manager = manager.Manager()
@@ -158,7 +158,7 @@ def run_trainer(trainer_config: SpmdTrainer.Config) -> Any:
158158
output = trainer.run(prng_key)
159159
break
160160
except jax.errors.JaxRuntimeError as error:
161-
if not elastic_manager._is_error_due_to_slice_down(error):
161+
if not elastic_manager.is_error_due_to_slice_down(error):
162162
raise
163163
ten_minutes = 10 * 60
164164
elastic_manager.wait_for_slices(timeout=ten_minutes)

0 commit comments

Comments
 (0)