Skip to content

Commit 10a9ec0

Browse files
Edwinhr716 authored and changlan committed
lws integration
runner for lws added added updating condition added spec for pathways in lws updated to variables refactored to match JobSet Pathways implementation added new runner for jetstream-pathways bug fixes, made num_replicas a parameter for an LWS object ran precommits added pathways utils tests added more unit tests removed mentions of jetstream, fixed bug in runner added tests for gke runner minor fixes cleaned up tests removed jetstream from runner name addressed comments test service and PR changes Update axlearn/cloud/gcp/job.py Co-authored-by: Meng (Ethan) Li <[email protected]> Update axlearn/cloud/gcp/job.py Co-authored-by: Meng (Ethan) Li <[email protected]> Update axlearn/cloud/gcp/job.py Co-authored-by: Meng (Ethan) Li <[email protected]> Update axlearn/cloud/gcp/job.py Co-authored-by: Meng (Ethan) Li <[email protected]> added flag added patch flatten logic, removed inner cleanup Leader on a separate CPU node changed to subgroup exclusive policy added image fixed node selector on head pod added backend platforms env variable added jax backend target variable add rm port fixed issue with proxy error addressed comments Removed service, will add on different PR added env variable to fix handshake issue Update node_pool_provisioner.py fixed pre-commit error refactored to add TPUJobBuilder, addressed comments addressed final comments added resource requirements added worker image get accelerator from inner added lws test, refactored other tests rebase added builder type changed builder name changed builder name to add _ addressed comments fix for list command remove propagation policy from list command fixed LWS runner test
1 parent 4e69203 commit 10a9ec0

13 files changed

+1279
-20
lines changed

axlearn/cloud/gcp/job.py

Lines changed: 107 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,13 @@
1919
from axlearn.cloud.common.utils import generate_job_name, subprocess_run
2020
from axlearn.cloud.gcp.config import default_env_id, default_project, default_zone
2121
from axlearn.cloud.gcp.jobset_utils import BaseReplicatedJob
22-
from axlearn.cloud.gcp.utils import custom_jobset_kwargs, delete_k8s_jobset
22+
from axlearn.cloud.gcp.lws_utils import BaseLeaderWorkerTemplate
23+
from axlearn.cloud.gcp.utils import (
24+
custom_jobset_kwargs,
25+
custom_leaderworkerset_kwargs,
26+
delete_k8s_jobset,
27+
delete_k8s_leaderworkerset,
28+
)
2329
from axlearn.common.config import REQUIRED, ConfigOr, Required, config_class, maybe_instantiate
2430
from axlearn.common.utils import Nested
2531

@@ -267,3 +273,103 @@ def docker_command(
267273
)
268274
logging.debug("Docker run command: %s", cmd)
269275
return cmd
276+
277+
278+
class GKELeaderWorkerSet(GCPJob):
    """A GCPJob that creates and deletes a k8s LeaderWorkerSet custom object on GKE."""

    @config_class
    class Config(GCPJob.Config):
        """Configures GKELeaderWorkerSet.

        Attributes:
            builder: Config of the builder whose output becomes the `leaderWorkerTemplate` spec.
            namespace: The namespace to use within the k8s cluster.
            annotations: LeaderWorkerSet annotations (a dict, or a config instantiating to one).
            num_replicas: Number of LWS replicas.
        """

        builder: Required[BaseLeaderWorkerTemplate.Config] = REQUIRED
        namespace: str = "default"
        annotations: Optional[ConfigOr[dict]] = None
        num_replicas: int = 1

    @classmethod
    def set_defaults(cls, fv):
        """Applies parent defaults, then fills in retry defaults only where the flag is unset."""
        super().set_defaults(fv)
        for flag_name, fallback in (("max_tries", 10), ("retry_interval", 60)):
            # A falsy current value (e.g. None) is replaced by the fallback.
            fv.set_default(flag_name, getattr(fv, flag_name) or fallback)

    @classmethod
    def define_flags(cls, fv: flags.FlagValues):
        """Defines the parent flags plus the `name` flag on `fv`."""
        super().define_flags(fv)
        flags.DEFINE_string(
            "name",
            None,
            "Name of the LeaderWorkerSet.",
            flag_values=fv,
            allow_override=True,
        )

    @classmethod
    def from_flags(cls, fv: flags.FlagValues, **kwargs):
        """Builds the config from flags and copies `num_replicas` from `fv`."""
        lws_cfg: GKELeaderWorkerSet.Config = super().from_flags(fv, **kwargs)
        lws_cfg.set(num_replicas=fv.num_replicas)
        return lws_cfg

    def __init__(self, cfg: Config, *, bundler: BaseDockerBundler):
        super().__init__(cfg)
        self._bundler = bundler
        # The builder produces the leader/worker template that is embedded in the
        # LeaderWorkerSet represented by this class. It is distinct from the bundler, which is
        # responsible for bundling any code assets required to run the job.
        builder_cfg = self.config.builder
        self._builder: BaseLeaderWorkerTemplate = builder_cfg.instantiate(bundler=bundler)

    def _delete(self):
        """Requests deletion of the LeaderWorkerSet named by the config."""
        lws_cfg: GKELeaderWorkerSet.Config = self.config
        # Issues a delete request for the LeaderWorkerSet and proactively deletes its
        # descendants. This is not fully blocking; after the call returns there can be a delay
        # before everything is deleted.
        delete_k8s_leaderworkerset(lws_cfg.name, namespace=lws_cfg.namespace)

    def _build_leaderworkerset(self) -> Nested[Any]:
        """Builds a config for a LeaderWorkerSet, which is a set for multi-host inference.

        Returns:
            A nested dict corresponding to a k8s LWS config, with `metadata` (name and
            annotations) and `spec` (replicas and the builder's leaderWorkerTemplate).
        """
        lws_cfg: GKELeaderWorkerSet.Config = self.config
        lws_annotations = maybe_instantiate(lws_cfg.annotations or {})

        metadata = {"name": lws_cfg.name, "annotations": lws_annotations}
        spec = {
            "replicas": lws_cfg.num_replicas,
            "leaderWorkerTemplate": self._builder(),
        }
        return {"metadata": metadata, "spec": spec}

    def _execute(self):
        """Submits the LeaderWorkerSet as a namespaced custom object.

        Returns:
            Whatever `create_namespaced_custom_object` returns.
        """
        lws_cfg: GKELeaderWorkerSet.Config = self.config

        api_kwargs = custom_leaderworkerset_kwargs()
        api_version = f"{api_kwargs['group']}/{api_kwargs['version']}"
        custom_object = dict(
            apiVersion=api_version,
            kind="LeaderWorkerSet",
            **self._build_leaderworkerset(),
        )
        logging.info("submitting LeaderWorkerSet: %s", custom_object)
        custom_objects_api = k8s.client.CustomObjectsApi()
        return custom_objects_api.create_namespaced_custom_object(
            namespace=lws_cfg.namespace,
            body=custom_object,
            **api_kwargs,
        )
364+
365+
366+
def exclusive_topology_annotations_leaderworkerset() -> dict:
    """Returns the subgroup-exclusive-topology annotation used for TPU GKELeaderWorkerSet.

    The exclusive topology annotation is intended to make all Pods receive affinity rules so
    that they are fully scheduled on the same pod-slice node-pools.

    Returns:
        A new single-entry dict mapping the LWS subgroup-exclusive-topology annotation key to
        the GKE node-pool topology key.
    """
    topology_key = "cloud.google.com/gke-nodepool"
    return {"leaderworkerset.sigs.k8s.io/subgroup-exclusive-topology": topology_key}

axlearn/cloud/gcp/job_test.py

Lines changed: 88 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
from axlearn.cloud.common.bundler import Bundler
1313
from axlearn.cloud.common.utils import define_flags, from_flags
14-
from axlearn.cloud.gcp import bundler, job, jobset_utils
14+
from axlearn.cloud.gcp import bundler, job, jobset_utils, pathways_utils
1515
from axlearn.cloud.gcp.bundler import ArtifactRegistryBundler, CloudBuildBundler
1616
from axlearn.cloud.gcp.test_utils import default_mock_settings, mock_gcp_settings
1717
from axlearn.common.config import REQUIRED, Required, config_class
@@ -211,3 +211,90 @@ def test_build_jobset(
211211
self.assertNotIn("kueue.x-k8s.io/queue-name", jobset_annotations)
212212
else:
213213
self.assertEqual(jobset_annotations["kueue.x-k8s.io/queue-name"], queue)
214+
215+
216+
class TPUGKELeaderWorkerSetTest(TestCase):
    """Exercises GKELeaderWorkerSet configured with a Pathways TPU builder."""

    def run(self, result=None):
        # Every test executes with mocked GCP settings patched into the listed modules.
        self._settings = default_mock_settings()
        patched_modules = [jobset_utils.__name__, bundler.__name__]
        with mock_gcp_settings(patched_modules, settings=self._settings):
            return super().run(result)

    def _job_config(
        self,
        *,
        command: str,
        bundler_cls: type[Bundler],
        **kwargs,
    ) -> tuple[job.GKELeaderWorkerSet.Config, Bundler.Config]:
        """Builds a (job config, bundler config) pair from freshly defined flags.

        Args:
            command: Command forwarded to `from_flags`.
            bundler_cls: Bundler class used to build the bundler config.
            **kwargs: Flag overrides; entries whose value is None are skipped.

        Returns:
            The populated job config and a bundler config with image "test-image".
        """
        fv = flags.FlagValues()
        lws_cfg = job.GKELeaderWorkerSet.default_config()
        lws_cfg.set(builder=pathways_utils.PathwaysLeaderWorkerTemplate.default_config())
        define_flags(lws_cfg, fv)

        # Assign flag values directly (setattr) rather than via set_default.
        overrides = {key: value for key, value in kwargs.items() if value is not None}
        for key, value in overrides.items():
            setattr(fv, key, value)
        fv.name = "fake-name"
        fv.output_dir = "FAKE"
        fv.instance_type = "tpu-v4-8"
        fv.mark_as_parsed()
        from_flags(lws_cfg, fv, command=command)

        # Retries should be configured on fv by default.
        for retry_flag in ("max_tries", "retry_interval"):
            self.assertIsNotNone(fv[retry_flag].default)

        bundler_cfg = bundler_cls.from_spec([], fv=fv).set(image="test-image")
        return lws_cfg, bundler_cfg

    @parameterized.product(
        reservation=[None, "test"],
        bundler_cls=[ArtifactRegistryBundler, CloudBuildBundler],
        wrap_bundler=[False, True],
    )
    def test_instantiate(
        self,
        reservation,
        bundler_cls: type[Bundler],
        wrap_bundler,
    ):
        class WrappedBundler(Bundler):
            @config_class
            class Config(Bundler.Config):
                inner: Required[Bundler.Config] = REQUIRED

        cfg, bundler_cfg = self._job_config(
            command="test-command",
            bundler_cls=bundler_cls,
            reservation=reservation,
            num_replicas=1,
        )

        builder_cfg_cls = pathways_utils.PathwaysLeaderWorkerTemplate.Config
        self.assertIsInstance(cfg.builder, builder_cfg_cls)
        cfg.builder = cast(builder_cfg_cls, cfg.builder)

        self.assertEqual(cfg.name, cfg.builder.name)
        self.assertEqual(cfg.project, self._settings["project"])
        self.assertEqual(cfg.zone, self._settings["zone"])
        # Without an explicit reservation, the mocked settings supply one.
        expected_reservation = reservation or self._settings["gke_reservation"]
        self.assertEqual(cfg.builder.inner.reservation, expected_reservation)
        self.assertEqual(cfg.num_replicas, 1)

        # Should work with wrapped bundlers.
        if wrap_bundler:
            bundler_cfg = WrappedBundler.default_config().set(inner=bundler_cfg)
        gke_job = cfg.instantiate(bundler=bundler_cfg.instantiate())
        # pylint: disable-next=protected-access
        self.assertEqual("v4-8", gke_job._builder._tpu_type)

    def test_delete(self):
        delete_target = f"{job.__name__}.delete_k8s_leaderworkerset"
        with mock.patch(delete_target) as mock_delete:
            cfg, _ = self._job_config(command="test-command", bundler_cls=CloudBuildBundler)
            gke_job = cfg.instantiate(bundler=mock.Mock())
            gke_job._delete()  # pylint: disable=protected-access
            mock_delete.assert_called()

axlearn/cloud/gcp/jobset_utils.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -308,12 +308,12 @@ def from_flags(cls, fv: flags.FlagValues, **kwargs):
308308
return cfg
309309

310310

311-
class TPUReplicatedJob(SingleReplicatedJob):
312-
"""Builds a replicated jobspec for TPU, to be used with JobSet API."""
311+
class TPUJobBuilder(SingleReplicatedJob):
312+
"""Common base class for TPU Specs"""
313313

314314
@config_class
315315
class Config(SingleReplicatedJob.Config):
316-
"""Configures TPUReplicatedJob.
316+
"""Configures TPUJobBuilder.
317317
318318
Attributes:
319319
reservation: If specified, the TPU reservation name. This is not necessarily specific to
@@ -380,7 +380,7 @@ def define_flags(cls, fv: flags.FlagValues):
380380

381381
@classmethod
382382
def from_flags(cls, fv: flags.FlagValues, **kwargs) -> Config:
383-
cfg: TPUReplicatedJob.Config = super().from_flags(fv, **kwargs)
383+
cfg: TPUJobBuilder.Config = super().from_flags(fv, **kwargs)
384384
default_env = get_default_env(
385385
tpu_type=infer_tpu_type(fv.instance_type),
386386
num_tpu_slices=fv.num_replicas,
@@ -404,7 +404,7 @@ def from_flags(cls, fv: flags.FlagValues, **kwargs) -> Config:
404404

405405
def __init__(self, cfg: Config, *, bundler: Bundler):
406406
super().__init__(cfg, bundler=bundler)
407-
cfg: TPUReplicatedJob.Config = self.config
407+
cfg: TPUJobBuilder.Config = self.config
408408
if cfg.output_dir is None:
409409
raise ValueError("cfg.output_dir is required.")
410410
self._tpu_type = infer_tpu_type(cfg.accelerator.instance_type)
@@ -433,7 +433,7 @@ def _build_container(self) -> Nested[Any]:
433433
Returns:
434434
A nested dict corresponding to a k8s Container config.
435435
"""
436-
cfg: TPUReplicatedJob.Config = self.config
436+
cfg: TPUJobBuilder.Config = self.config
437437
system = USER_FACING_NAME_TO_SYSTEM_CHARACTERISTICS[self._tpu_type]
438438
volume_mounts = [self._output_volume_mount]
439439

@@ -503,7 +503,7 @@ def _build_uploader_container(
503503
Returns:
504504
A nested dict corresponding to a k8s Container config.
505505
"""
506-
cfg: TPUReplicatedJob.Config = self.config
506+
cfg: TPUJobBuilder.Config = self.config
507507
output_volume_mount = output_volume_mount or self._output_volume_mount
508508
dst = f"{cfg.output_dir}/output/$HOSTNAME/"
509509
interval_s = 60
@@ -538,7 +538,7 @@ def _build_pod(self) -> Nested[Any]:
538538
Returns:
539539
A nested dict corresponding to a k8s Pod template, including the pod metadata and spec.
540540
"""
541-
cfg: TPUReplicatedJob.Config = self.config
541+
cfg: TPUJobBuilder.Config = self.config
542542
system = USER_FACING_NAME_TO_SYSTEM_CHARACTERISTICS[self._tpu_type]
543543
annotations, labels, selector, volumes, tolerations = {}, {}, {}, [], []
544544

@@ -727,6 +727,12 @@ def _build_pod(self) -> Nested[Any]:
727727
spec=spec,
728728
)
729729

730+
731+
class TPUReplicatedJob(TPUJobBuilder):
732+
"""Builds a replicated job spec for a generic TPU job to be used with the JobSet API"""
733+
734+
Config = TPUJobBuilder.Config
735+
730736
def __call__(self) -> Sequence[Nested[Any]]:
731737
"""See `BaseReplicatedJob` docstring for details."""
732738
cfg: TPUReplicatedJob.Config = self.config

axlearn/cloud/gcp/lws_utils.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
# Copyright © 2025 Apple Inc.
2+
3+
"""Utilities for building LeaderWorkerSet specs"""
4+
5+
from typing import Any, Optional, Sequence
6+
7+
from absl import flags
8+
9+
from axlearn.cloud.common.bundler import Bundler
10+
from axlearn.cloud.common.utils import AcceleratorConfig, FlagConfigurable, accelerator_flags
11+
from axlearn.cloud.gcp.config import gcp_settings
12+
from axlearn.cloud.gcp.jobset_utils import TPUJobBuilder
13+
from axlearn.cloud.gcp.system_characteristics import USER_FACING_NAME_TO_SYSTEM_CHARACTERISTICS
14+
from axlearn.common.config import REQUIRED, Required, config_class
15+
from axlearn.common.utils import Nested
16+
17+
18+
class BaseLeaderWorkerTemplate(FlagConfigurable):
    """Common base class for builders of LeaderWorkerTemplate specs.

    Subclasses implement `__call__` to return the template; this base class only declares the
    shared config fields and flags and stores the bundler.
    """

    @config_class
    class Config(FlagConfigurable.Config):
        """Configures BaseLeaderWorkerTemplate.

        Attributes:
            name: Name of the LeaderWorkerSet.
            command: Command to be executed.
            accelerator: Accelerator configuration.
            env_vars: Optional env vars to set.
            service_account: Optional service account to execute the job as.
            output_dir: An optional GCS path to upload LWS outputs to.
        """

        name: Required[str] = REQUIRED
        # TODO: Change this to a list of str to support different commands between the leader
        # and the workers.
        command: Required[str] = REQUIRED
        accelerator: AcceleratorConfig = AcceleratorConfig()
        env_vars: dict[str, str] = {}
        service_account: Optional[str] = None
        output_dir: Optional[str] = None

    @classmethod
    def define_flags(cls, fv: flags.FlagValues):
        """Defines the accelerator flags and the template's own flags on `fv`."""
        super().define_flags(fv)
        common_kwargs = dict(flag_values=fv, allow_override=True)
        accelerator_flags(**common_kwargs)
        # NOTE: the parent typically sets these flags, so we leave them as None.
        flags.DEFINE_string("name", None, "Name of the LWS.", **common_kwargs)
        flags.DEFINE_string("command", None, "Command to execute.", **common_kwargs)
        # NOTE(review): `from_flags` below does not read the `env` flag into `env_vars`;
        # presumably a subclass or the base class consumes it -- confirm.
        flags.DEFINE_multi_string("env", [], "Env var in the format key:value.", **common_kwargs)
        flags.DEFINE_string(
            "service_account",
            None,
            "If specified, will run job as the service account.",
            **common_kwargs,
        )
        flags.DEFINE_string(
            "output_dir",
            None,
            "If specified, the directory to store outputs (such as logs).",
            **common_kwargs,
        )
        # NOTE(review): no `enable_pre_provisioner` field is declared on Config in this class;
        # confirm which config consumes this flag.
        flags.DEFINE_boolean(
            "enable_pre_provisioner", None, "Whether to enable pre-provisioner.", **common_kwargs
        )

    @classmethod
    def from_flags(cls, fv: flags.FlagValues, **kwargs):
        """Builds the config from flags.

        If `service_account` is unset after the parent populates the config, it falls back to
        the `k8s_service_account` GCP setting (default "default"). The accelerator's
        `instance_type` and `num_replicas` are copied from the flags of the same names.
        """
        cfg: BaseLeaderWorkerTemplate.Config = super().from_flags(fv, **kwargs)
        cfg.service_account = cfg.service_account or gcp_settings(
            "k8s_service_account", default="default", fv=fv
        )
        cfg.accelerator.set(instance_type=fv.instance_type, num_replicas=fv.num_replicas)
        return cfg

    def __init__(self, cfg: Config, *, bundler: Bundler):
        super().__init__(cfg)
        self._bundler = bundler

    # The result is a single nested dict (it is used as the `leaderWorkerTemplate` value of the
    # LeaderWorkerSet spec), not a sequence of specs.
    def __call__(self) -> Nested[Any]:
        """Builds LeaderWorkerTemplate for the LWS API.

        Returns:
            A nested dict corresponding to a LeaderWorkerTemplate config.

        Raises:
            NotImplementedError: Always, in this base class; subclasses must override.
        """
        raise NotImplementedError(type(self))
90+
91+
92+
class TPULeaderWorkerTemplate(TPUJobBuilder):
    """Builds a LeaderWorkerTemplate spec for a generic TPU workload."""

    Config = TPUJobBuilder.Config

    # Returns one nested dict (not a sequence), so the annotation is `Nested[Any]`.
    def __call__(self) -> Nested[Any]:
        """Builds the LeaderWorkerTemplate for the LWS API.

        Returns:
            A dict with `size` set to the `vms_per_slice` of the system characteristics looked
            up for this builder's TPU type, and `workerTemplate` set to the pod template
            returned by `_build_pod()`.
        """
        system = USER_FACING_NAME_TO_SYSTEM_CHARACTERISTICS[self._tpu_type]
        return dict(
            size=system.vms_per_slice,
            workerTemplate=self._build_pod(),
        )

0 commit comments

Comments
 (0)