Skip to content

Commit 9f4cb70

Browse files
committed
Adding basic elastic training
Pull pathwaysutils from GitHub. Added guards to only use fast-resume if the proxy backend is used. Added the changes to the jobset for elastic training. Temporary changes to the configuration to decrease batch size. Added a stop_trace to cancel any ongoing traces. Take a checkpoint every 20,000,000 steps.
1 parent 253d77f commit 9f4cb70

File tree

4 files changed

+49
-9
lines changed

4 files changed

+49
-9
lines changed

axlearn/cloud/gcp/pathways_utils.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,8 @@ def _build_pathways_head_sidecar_containers(self) -> list[Nested[Any]]:
295295
f"--server_port={_PATHWAYS_PROXY_PORT}",
296296
f"--gcs_scratch_location={staging_location}",
297297
"--temporary_flags_for_debugging=temporary_flag_for_debugging_pipe_break_on_missing_keepalive=true",
298+
# This should be made configurable
299+
f"--num_elastic_slices={cfg.accelerator.num_replicas}",
298300
]
299301
cmd_args.extend(xla_flags_from_options(self._xla_options).split())
300302

@@ -566,14 +568,19 @@ def _build_pathways_worker_job(
566568
annotations.update(
567569
{"alpha.jobset.sigs.k8s.io/exclusive-topology": "cloud.google.com/gke-nodepool"}
568570
)
571+
# Default value for suspend and resume.
572+
# References:
573+
# https://github.com/google/pathways-job/blob/4417de7aa23d3c2316e400a3a327512834374475/internal/controller/pathwaysjob_controller.go#L651
574+
# backoffLimit = system.vms_per_slice * 4
575+
576+
# This backoffLimit is just for verifying elastic fast-resume
577+
large_number = 1000
578+
backoffLimit = system.vms_per_slice * 4 * large_number
569579

570580
spec = dict(
571581
parallelism=system.vms_per_slice,
572582
completions=system.vms_per_slice,
573-
# Default value for suspend and resume.
574-
# References:
575-
# https://github.com/google/pathways-job/blob/4417de7aa23d3c2316e400a3a327512834374475/internal/controller/pathwaysjob_controller.go#L651
576-
backoffLimit=system.vms_per_slice * 4,
583+
backoffLimit=backoffLimit,
577584
template=self._build_pathways_worker_pod(pathways_worker_replicated_job_index),
578585
)
579586
worker_job = dict(

axlearn/common/launch_trainer.py

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -148,9 +148,41 @@ def _run_trainer_impl(trainer_config: SpmdTrainer.Config) -> Any:
148148
f,
149149
)
150150

151-
trainer: SpmdTrainer = trainer_config.instantiate(parent=None)
152-
prng_key = jax.random.PRNGKey(seed=FLAGS.trainer_prng_seed)
153-
return trainer.run(prng_key)
151+
if FLAGS.jax_backend == "proxy":
152+
# pylint: disable-next=import-error,import-outside-toplevel
153+
from pathwaysutils.elastic import manager
154+
elastic_manager = manager.Manager()
155+
while True:
156+
try:
157+
trainer: SpmdTrainer = trainer_config.instantiate(parent=None)
158+
prng_key = jax.random.PRNGKey(seed=FLAGS.trainer_prng_seed)
159+
output = trainer.run(prng_key)
160+
break
161+
except jax.errors.JaxRuntimeError as error:
162+
if not elastic_manager.is_error_due_to_slice_down(error):
163+
raise
164+
try:
165+
logging.info("Trying to clean up ongoing traces")
166+
jax.profiler.stop_trace()
167+
logging.info("Successfully cleaned up ongoing traces")
168+
except (RuntimeError, ValueError) as e:
169+
logging.info("No ongoing traces to clean up")
170+
except Exception as e:
171+
logging.exception("Error trying to clean up ongoing traces")
172+
raise
173+
174+
jax.clear_caches()
175+
for array in jax.live_arrays():
176+
array.delete()
177+
178+
ten_minutes = 10 * 60
179+
elastic_manager.wait_for_slices(timeout=ten_minutes)
180+
else:
181+
trainer: SpmdTrainer = trainer_config.instantiate(parent=None)
182+
prng_key = jax.random.PRNGKey(seed=FLAGS.trainer_prng_seed)
183+
output = trainer.run(prng_key)
184+
185+
return output
154186

155187

156188
def run_trainer(trainer_config: SpmdTrainer.Config) -> Any:

axlearn/experiments/text/gpt/fuji.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -384,6 +384,7 @@ def get_trainer_kwargs(
384384
max_sequence_length=max_sequence_length,
385385
train_batch_size=len(jax.devices()),
386386
max_step=max_step,
387+
save_every_n_steps=100,
387388
mesh_shape=mesh_shape_from_axes(data=-1, fsdp=8),
388389
mesh_rules=(
389390
# Step time:
@@ -848,7 +849,7 @@ def get_trainer_kwargs(
848849
max_sequence_length=max_sequence_length,
849850
train_batch_size=train_batch_size, # number of devices times 4 chips per device times 4096 samples per chip # train_batch_size,
850851
max_step=10_000, # max_step,
851-
save_every_n_steps=1000,
852+
save_every_n_steps=20_000_000,
852853
mesh_shape=mesh_shape_from_axes(data=-1, fsdp=64, model=4),
853854
mesh_rules=(
854855
(

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ tpu = [
114114
pathways-tpu = [
115115
"axlearn[gcp]",
116116
"jax==0.5.3", # must be >=0.4.19 for compat with v5p.
117-
"pathwaysutils==0.1.1",
117+
"pathwaysutils @ git+https://github.com/AI-Hypercomputer/pathways-utils",
118118
]
119119
# Vertex AI tensorboard. TODO(markblee): Merge with `gcp`.
120120
vertexai_tensorboard = [

0 commit comments

Comments
 (0)