Skip to content

Commit 04d5fd7

Browse files
Alexandre Jameschanglan
authored andcommitted
Topology Assignment for Leader Worker Set
GitOrigin-RevId: 726c4815749fca66b2214ecfb8ef4e1edeef8855
1 parent 819e26a commit 04d5fd7

File tree

4 files changed

+296
-37
lines changed

4 files changed

+296
-37
lines changed

axlearn/cloud/gcp/job.py

Lines changed: 87 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,42 @@ class _ServiceType(enum.Enum):
6060
EXTERNAL_NAME = "ExternalName"
6161

6262

63+
def get_topology_assignment() -> Optional[list[list[str]]]:
64+
"""Retrieves TPU topology assignments from the environment variable.
65+
66+
When TPU slice auto-provisioning is enabled, Bastion passes topology assignments
67+
through an environment variable. These assignments specify which TPU slices should be
68+
used for the job, enabling precise control over TPU resource allocation.
69+
70+
Example topology assignment:
71+
[["sub-block-id", "sub-block-id"]]
72+
73+
This is the assignment for a job asking for tpu-7x-256, that needs 128 chips, using
74+
2 sub-blocks (64 chips per sub-block). This job will run on a TPU slice formed by
75+
2 sub-blocks. Each inner array represents the TPU slice info for a job's replica.
76+
77+
Returns:
78+
A list of lists of strings representing topology assignments, where each inner list
79+
contains slice identifiers for a particular job replica. Returns None if the
80+
environment variable is not set or if parsing fails.
81+
"""
82+
topology_assignments_env = os.environ.get(BASTION_JOB_TOPOLOGY_ASSIGNMENT_ENV_VAR)
83+
if not topology_assignments_env:
84+
logging.info("No %s environment variable set.", BASTION_JOB_TOPOLOGY_ASSIGNMENT_ENV_VAR)
85+
return None
86+
87+
try:
88+
return json.loads(topology_assignments_env)
89+
except json.JSONDecodeError as e:
90+
logging.warning(
91+
"Failed to parse topology assignments from env var %s, value: %s, error: %s",
92+
BASTION_JOB_TOPOLOGY_ASSIGNMENT_ENV_VAR,
93+
topology_assignments_env,
94+
e,
95+
)
96+
return None
97+
98+
6399
class GCPJob(Job):
64100
"""Base GCP Job definition."""
65101

@@ -173,41 +209,6 @@ def _delete(self):
173209
# fully blocking; after the call returns there can be a delay before everything is deleted.
174210
delete_k8s_jobset(cfg.name, namespace=cfg.namespace)
175211

176-
def _get_topology_assignment(self) -> Optional[list[list[str]]]:
177-
"""Retrieves TPU topology assignments from the environment variable.
178-
179-
When TPU slice auto-provisioning is enabled, Bastion passes topology assignments
180-
through an environment variable. These assignments specify which TPU slices should be
181-
used for the job, enabling precise control over TPU resource allocation.
182-
183-
Example topology assignment:
184-
[["sub-block-id", "sub-block-id"]]
185-
186-
This is the assignment for a job asking for tpu-7x-256, that needs 128 chips, using
187-
2 sub-blocks (64 chips per sub-block). This job will run on a TPU slice formed by
188-
2 sub-blocks. Each inner array represents the TPU slice info for a job's replica.
189-
190-
Returns:
191-
A list of lists of strings representing topology assignments, where each inner list
192-
contains slice identifiers for a particular job replica. Returns None if the
193-
environment variable is not set or if parsing fails.
194-
"""
195-
topology_assignments_env = os.environ.get(BASTION_JOB_TOPOLOGY_ASSIGNMENT_ENV_VAR)
196-
if not topology_assignments_env:
197-
logging.info("No %s environment variable set.", BASTION_JOB_TOPOLOGY_ASSIGNMENT_ENV_VAR)
198-
return None
199-
200-
try:
201-
return json.loads(topology_assignments_env)
202-
except json.JSONDecodeError as e:
203-
logging.warning(
204-
"Failed to parse topology assignments from env var %s, value: %s, error: %s",
205-
BASTION_JOB_TOPOLOGY_ASSIGNMENT_ENV_VAR,
206-
topology_assignments_env,
207-
e,
208-
)
209-
return None
210-
211212
def _lookup_system_by_node_selectors(
212213
self, node_selector: dict[str, str]
213214
) -> Optional[tuple[str, _SystemCharacteristics]]:
@@ -382,7 +383,7 @@ def _build_jobset(self) -> Nested[Any]:
382383
# Bastion passes the job metadata to the runner through env vars
383384
# If the job has topology assigned, its also in the env var
384385
# Try to parse the env var and get the topology assignments.
385-
topology_assignment = self._get_topology_assignment()
386+
topology_assignment = get_topology_assignment()
386387
if cfg.enable_tpu_slice_auto_provisioning and topology_assignment:
387388
slice_selection_dict = self._get_tpu_replicated_job_topology_selection(
388389
replicated_jobs, topology_assignment
@@ -563,6 +564,7 @@ class Config(GCPJob.Config):
563564
gke_gateway_route: bool = False
564565
http_route: Optional[LWSHTTPRoute.Config] = None
565566
health_check_policy: Optional[LWSHealthCheckPolicy.Config] = None
567+
enable_tpu_slice_auto_provisioning: Optional[bool] = None
566568

567569
@classmethod
568570
def set_defaults(cls, fv):
@@ -635,6 +637,12 @@ def define_flags(cls, fv: flags.FlagValues):
635637
"Enable gke_gateway_route with notary-proxy sidecars for direct gateway routing",
636638
**common_kwargs,
637639
)
640+
flags.DEFINE_boolean(
641+
"enable_tpu_slice_auto_provisioning",
642+
None,
643+
"Auto provision TPU slices based on the topology assignment.",
644+
**common_kwargs,
645+
)
638646

639647
@classmethod
640648
def from_flags(cls, fv: flags.FlagValues, **kwargs):
@@ -653,6 +661,15 @@ def __init__(self, cfg: Config, *, bundler: BaseDockerBundler):
653661
super().__init__(cfg)
654662
cfg: GKELeaderWorkerSet.Config = self.config
655663
self._bundler = bundler
664+
665+
# Pass enable_tpu_slice_auto_provisioning from GKEJob to the builder
666+
builder_cfg = cfg.builder
667+
if (
668+
hasattr(builder_cfg, "enable_tpu_slice_auto_provisioning")
669+
and cfg.enable_tpu_slice_auto_provisioning is not None
670+
):
671+
builder_cfg.enable_tpu_slice_auto_provisioning = cfg.enable_tpu_slice_auto_provisioning
672+
656673
# This instantiatees a builder for constructing replicated job specs, which will be managed
657674
# together under the leaderworkerset represented by this class.
658675
# Note the distinction from bundlers, which are responsible for bundling any code assets
@@ -683,9 +700,42 @@ def _build_leaderworkerset(self) -> Nested[Any]:
683700
"""
684701
cfg: GKELeaderWorkerSet.Config = self.config
685702
annotations = maybe_instantiate(cfg.annotations or {})
703+
labels = {}
704+
705+
# If the topology is set and slice auto provisioning is configured
706+
# set the necessary annotations
707+
topology_assignment = get_topology_assignment()
708+
if cfg.enable_tpu_slice_auto_provisioning and topology_assignment:
709+
# Add TPU slice selection
710+
logging.info("Adding slice selection: %s to leader worker set", topology_assignment)
711+
712+
# Note, we use async here rather than the jobset sync. Async will immediatly create
713+
# the pods before the slice has been created. Once sync is supported for leader worker
714+
# set we should consider switching.
715+
labels["tpu-provisioner.cloud.google.com/slice-autoprovisioning"] = "async"
716+
717+
# For Leader worker sets, we only support topology assignments to workers.
718+
# The format of the topology assignments (list of subblock groups) is what
719+
# is expected by the TPU provisioner.
720+
annotations.update(
721+
{
722+
"tpu-provisioner.cloud.google.com/slice-selection": json.dumps(
723+
{
724+
"workers": topology_assignment,
725+
}
726+
)
727+
}
728+
)
729+
730+
# Remove exclusive topology annotation, the tpu provisioner will ensure replica
731+
# affinity by injecting slice based node selectors, so we don't need to use
732+
# the exclusive topology annotations
733+
exclusive_topology_annotation = exclusive_topology_annotations_leaderworkerset()
734+
for key in exclusive_topology_annotation:
735+
annotations.pop(key, None)
686736

687737
return dict(
688-
metadata=dict(name=cfg.name, annotations=annotations),
738+
metadata=dict(name=cfg.name, annotations=annotations, labels=labels),
689739
spec=dict(
690740
replicas=cfg.num_replicas,
691741
leaderWorkerTemplate=self._builder(),

axlearn/cloud/gcp/job_test.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"""Tests jobs by launching commands on TPUs/VMs."""
44
# pylint: disable=protected-access
55

6+
import json
67
from typing import Optional, cast
78
from unittest import mock
89

@@ -543,3 +544,113 @@ def test_delete(self):
543544
gke_job = cfg.instantiate(bundler=mock.create_autospec(Bundler))
544545
gke_job._delete() # pylint: disable=protected-access
545546
mock_delete.assert_called()
547+
548+
@parameterized.parameters(
549+
# Test when auto provisioning is enabled with topology assignment
550+
dict(
551+
enable_tpu_slice_auto_provisioning=True,
552+
topology_assignment=[["subblock-1", "subblock-2"]],
553+
expect_label=True,
554+
expect_annotation=True,
555+
),
556+
# Test when auto provisioning is disabled
557+
dict(
558+
enable_tpu_slice_auto_provisioning=False,
559+
topology_assignment=[["subblock-1", "subblock-2"]],
560+
expect_label=False,
561+
expect_annotation=False,
562+
),
563+
# Test when auto provisioning is None (not set)
564+
dict(
565+
enable_tpu_slice_auto_provisioning=None,
566+
topology_assignment=[["subblock-1", "subblock-2"]],
567+
expect_label=False,
568+
expect_annotation=False,
569+
),
570+
# Test when auto provisioning is enabled but no topology assignment
571+
dict(
572+
enable_tpu_slice_auto_provisioning=True,
573+
topology_assignment=None,
574+
expect_label=False,
575+
expect_annotation=False,
576+
),
577+
)
578+
def test_build_leaderworkerset(
579+
self,
580+
enable_tpu_slice_auto_provisioning,
581+
topology_assignment,
582+
expect_label,
583+
expect_annotation,
584+
):
585+
"""Test _build_leaderworkerset with enable_tpu_slice_auto_provisioning."""
586+
cfg, bundler_cfg = self._job_config(
587+
command="test-command",
588+
bundler_cls=CloudBuildBundler,
589+
enable_tpu_slice_auto_provisioning=enable_tpu_slice_auto_provisioning,
590+
)
591+
592+
# Mock the builder to return a simple leader worker template
593+
mock_leader_worker_template = {
594+
"size": 8,
595+
"workerTemplate": {
596+
"metadata": {"labels": {"test-label": "test-value"}},
597+
"spec": {"containers": []},
598+
},
599+
}
600+
601+
# Create a mock builder that returns our mock template
602+
mock_builder = mock.Mock()
603+
mock_builder.return_value = mock_leader_worker_template
604+
605+
# Create the GKE job instance first
606+
gke_job = cfg.instantiate(bundler=bundler_cfg.instantiate())
607+
608+
# Replace the builder with our mock (this is what we're testing)
609+
gke_job._builder = mock_builder
610+
611+
# Mock get_topology_assignment
612+
with mock.patch(
613+
f"{job.__name__}.get_topology_assignment",
614+
return_value=topology_assignment,
615+
):
616+
# Build the leaderworkerset
617+
lws_spec = gke_job._build_leaderworkerset()
618+
619+
# Check metadata
620+
self.assertIn("metadata", lws_spec)
621+
self.assertIn("name", lws_spec["metadata"])
622+
self.assertEqual(cfg.name, lws_spec["metadata"]["name"])
623+
624+
# Check labels
625+
labels = lws_spec["metadata"].get("labels", {})
626+
slice_auto_provisioning_label = (
627+
"tpu-provisioner.cloud.google.com/slice-autoprovisioning"
628+
)
629+
if expect_label:
630+
self.assertIn(slice_auto_provisioning_label, labels)
631+
self.assertEqual("async", labels[slice_auto_provisioning_label])
632+
else:
633+
self.assertNotIn(slice_auto_provisioning_label, labels)
634+
635+
# Check annotations
636+
annotations = lws_spec["metadata"].get("annotations", {})
637+
slice_selection_annotation = "tpu-provisioner.cloud.google.com/slice-selection"
638+
if expect_annotation:
639+
self.assertIn(slice_selection_annotation, annotations)
640+
slice_selection = json.loads(annotations[slice_selection_annotation])
641+
self.assertIn("workers", slice_selection)
642+
self.assertEqual(topology_assignment, slice_selection["workers"])
643+
else:
644+
self.assertNotIn(slice_selection_annotation, annotations)
645+
646+
# Verify exclusive topology annotations are removed when auto provisioning
647+
if expect_annotation:
648+
self.assertNotIn(
649+
"leaderworkerset.sigs.k8s.io/subgroup-exclusive-topology",
650+
annotations,
651+
)
652+
653+
# Check spec
654+
self.assertIn("spec", lws_spec)
655+
self.assertIn("replicas", lws_spec["spec"])
656+
self.assertIn("leaderWorkerTemplate", lws_spec["spec"])

axlearn/cloud/gcp/lws_utils.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,20 @@ class TPULeaderWorkerTemplate(TPUJobBuilder):
9999

100100
Config = TPUJobBuilder.Config
101101

102+
def _build_pod(self) -> Nested[Any]:
103+
cfg: TPUJobBuilder.Config = self.config
104+
105+
# Add inject slice selector for slice auto provisioned jobs
106+
pod = super()._build_pod()
107+
if cfg.enable_tpu_slice_auto_provisioning:
108+
pod["metadata"]["labels"].update(
109+
{
110+
"tpu-provisioner.cloud.google.com/inject-slice-selector": "true",
111+
}
112+
)
113+
114+
return pod
115+
102116
def __call__(self) -> Sequence[Nested[Any]]:
103117
system = USER_FACING_NAME_TO_SYSTEM_CHARACTERISTICS[self._tpu_type]
104118
return dict( # pytype: disable=bad-return-type

0 commit comments

Comments
 (0)