Skip to content

Commit 7a1fc85

Browse files
committed
Adding basic elastic training
Pulling pathwaysutils from GitHub. Added guards to only use fast-resume if the proxy backend is used. Added the changes to the jobset for elastic training. Temporary changes to the configuration to decrease batch size. Adding a stop_trace to cancel any ongoing traces. Taking a checkpoint every 100 steps.
1 parent 924db3c commit 7a1fc85

File tree

4 files changed

+48
-8
lines changed

4 files changed

+48
-8
lines changed

axlearn/cloud/gcp/pathways_utils.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,8 @@ def _build_pathways_head_sidecar_containers(self) -> list[Nested[Any]]:
320320
f"--resource_manager_address=localhost:{_PATHWAYS_RESOURCE_MANAGER_PORT}",
321321
f"--server_port={_PATHWAYS_PROXY_PORT}",
322322
f"--gcs_scratch_location={staging_location}",
323+
# This should be made configurable
324+
f"--num_elastic_slices={cfg.accelerator.num_replicas}",
323325
]
324326
cmd_args.extend(xla_flags_from_options(self._xla_options).split())
325327

@@ -588,14 +590,19 @@ def _build_pathways_worker_job(
588590
annotations.update(
589591
{"alpha.jobset.sigs.k8s.io/exclusive-topology": "cloud.google.com/gke-nodepool"}
590592
)
593+
# Default value for suspend and resume.
594+
# References:
595+
# https://github.com/google/pathways-job/blob/4417de7aa23d3c2316e400a3a327512834374475/internal/controller/pathwaysjob_controller.go#L651
596+
# backoffLimit = system.vms_per_slice * 4
597+
598+
# This backoffLimit is just for verifying elastic fast-resume
599+
large_number = 1000
600+
backoffLimit = system.vms_per_slice * 4 * large_number
591601

592602
spec = dict(
593603
parallelism=system.vms_per_slice,
594604
completions=system.vms_per_slice,
595-
# Default value for suspend and resume.
596-
# References:
597-
# https://github.com/google/pathways-job/blob/4417de7aa23d3c2316e400a3a327512834374475/internal/controller/pathwaysjob_controller.go#L651
598-
backoffLimit=system.vms_per_slice * 4,
605+
backoffLimit=backoffLimit,
599606
template=self._build_pathways_worker_pod(pathways_worker_replicated_job_index),
600607
)
601608
worker_job = dict(

axlearn/common/launch_trainer.py

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -148,9 +148,41 @@ def _run_trainer_impl(trainer_config: SpmdTrainer.Config) -> Any:
148148
f,
149149
)
150150

151-
trainer: SpmdTrainer = trainer_config.instantiate(parent=None)
152-
prng_key = jax.random.PRNGKey(seed=FLAGS.trainer_prng_seed)
153-
return trainer.run(prng_key)
151+
if FLAGS.jax_backend == "proxy":
152+
# pylint: disable-next=import-error,import-outside-toplevel
153+
from pathwaysutils.elastic import manager
154+
elastic_manager = manager.Manager()
155+
while True:
156+
try:
157+
trainer: SpmdTrainer = trainer_config.instantiate(parent=None)
158+
prng_key = jax.random.PRNGKey(seed=FLAGS.trainer_prng_seed)
159+
output = trainer.run(prng_key)
160+
break
161+
except jax.errors.JaxRuntimeError as error:
162+
if not elastic_manager.is_error_due_to_slice_down(error):
163+
raise
164+
try:
165+
logging.info("Trying to clean up ongoing traces")
166+
jax.profiler.stop_trace()
167+
logging.info("Successfully cleaned up ongoing traces")
168+
except (RuntimeError, ValueError) as e:
169+
logging.info("No ongoing traces to clean up")
170+
except Exception as e:
171+
logging.exception("Error trying to clean up ongoing traces")
172+
raise
173+
174+
jax.clear_caches()
175+
for array in jax.live_arrays():
176+
array.delete()
177+
178+
ten_minutes = 10 * 60
179+
elastic_manager.wait_for_slices(timeout=ten_minutes)
180+
else:
181+
trainer: SpmdTrainer = trainer_config.instantiate(parent=None)
182+
prng_key = jax.random.PRNGKey(seed=FLAGS.trainer_prng_seed)
183+
output = trainer.run(prng_key)
184+
185+
return output
154186

155187

156188
def run_trainer(trainer_config: SpmdTrainer.Config) -> Any:

axlearn/experiments/text/gpt/fuji.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -384,6 +384,7 @@ def get_trainer_kwargs(
384384
max_sequence_length=max_sequence_length,
385385
train_batch_size=len(jax.devices()),
386386
max_step=max_step,
387+
save_every_n_steps=100,
387388
mesh_shape=mesh_shape_from_axes(data=-1, fsdp=8),
388389
mesh_rules=(
389390
# Step time:

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ tpu = [
114114
pathways-tpu = [
115115
"axlearn[gcp]",
116116
"jax==0.5.3", # must be >=0.4.19 for compat with v5p.
117-
"pathwaysutils==0.1.1",
117+
"pathwaysutils @ git+https://github.com/AI-Hypercomputer/pathways-utils",
118118
]
119119
# Vertex AI tensorboard. TODO(markblee): Merge with `gcp`.
120120
vertexai_tensorboard = [

0 commit comments

Comments (0)