@@ -60,6 +60,42 @@ class _ServiceType(enum.Enum):
6060 EXTERNAL_NAME = "ExternalName"
6161
6262
def get_topology_assignment() -> Optional[list[list[str]]]:
    """Retrieves TPU topology assignments from the environment variable.

    When TPU slice auto-provisioning is enabled, Bastion passes topology assignments
    through an environment variable. These assignments specify which TPU slices should be
    used for the job, enabling precise control over TPU resource allocation.

    Example topology assignment:
        [["sub-block-id", "sub-block-id"]]

    This is the assignment for a job asking for tpu-7x-256, that needs 128 chips, using
    2 sub-blocks (64 chips per sub-block). This job will run on a TPU slice formed by
    2 sub-blocks. Each inner array represents the TPU slice info for a job's replica.

    Returns:
        A list of lists of strings representing topology assignments, where each inner list
        contains slice identifiers for a particular job replica. Returns None if the
        environment variable is not set (or empty), if it is not valid JSON, or if the
        decoded value is not a list of lists of strings.
    """
    topology_assignments_env = os.environ.get(BASTION_JOB_TOPOLOGY_ASSIGNMENT_ENV_VAR)
    if not topology_assignments_env:
        logging.info("No %s environment variable set.", BASTION_JOB_TOPOLOGY_ASSIGNMENT_ENV_VAR)
        return None

    try:
        topology_assignments = json.loads(topology_assignments_env)
    except json.JSONDecodeError as e:
        logging.warning(
            "Failed to parse topology assignments from env var %s, value: %s, error: %s",
            BASTION_JOB_TOPOLOGY_ASSIGNMENT_ENV_VAR,
            topology_assignments_env,
            e,
        )
        return None

    # json.loads accepts any JSON value (object, string, number, flat list, ...). Callers
    # only test the result for truthiness before forwarding it, so reject anything that is
    # not a list of lists of strings here instead of letting a malformed value through.
    is_well_formed = isinstance(topology_assignments, list) and all(
        isinstance(replica, list) and all(isinstance(item, str) for item in replica)
        for replica in topology_assignments
    )
    if not is_well_formed:
        logging.warning(
            "Ignoring topology assignments from env var %s, expected a list of lists of "
            "strings, value: %s",
            BASTION_JOB_TOPOLOGY_ASSIGNMENT_ENV_VAR,
            topology_assignments_env,
        )
        return None

    return topology_assignments
97+
98+
6399class GCPJob (Job ):
64100 """Base GCP Job definition."""
65101
@@ -173,41 +209,6 @@ def _delete(self):
173209 # fully blocking; after the call returns there can be a delay before everything is deleted.
174210 delete_k8s_jobset (cfg .name , namespace = cfg .namespace )
175211
def _get_topology_assignment(self) -> Optional[list[list[str]]]:
    """Retrieves TPU topology assignments from the environment variable.

    When TPU slice auto-provisioning is enabled, Bastion passes topology assignments
    through an environment variable. These assignments specify which TPU slices should be
    used for the job, enabling precise control over TPU resource allocation.

    Example topology assignment:
        [["sub-block-id", "sub-block-id"]]

    This is the assignment for a job asking for tpu-7x-256, that needs 128 chips, using
    2 sub-blocks (64 chips per sub-block). This job will run on a TPU slice formed by
    2 sub-blocks. Each inner array represents the TPU slice info for a job's replica.

    Returns:
        A list of lists of strings representing topology assignments, where each inner list
        contains slice identifiers for a particular job replica. Returns None if the
        environment variable is not set (or empty), if it is not valid JSON, or if the
        decoded value is not a list of lists of strings.
    """
    topology_assignments_env = os.environ.get(BASTION_JOB_TOPOLOGY_ASSIGNMENT_ENV_VAR)
    if not topology_assignments_env:
        logging.info("No %s environment variable set.", BASTION_JOB_TOPOLOGY_ASSIGNMENT_ENV_VAR)
        return None

    try:
        topology_assignments = json.loads(topology_assignments_env)
    except json.JSONDecodeError as e:
        logging.warning(
            "Failed to parse topology assignments from env var %s, value: %s, error: %s",
            BASTION_JOB_TOPOLOGY_ASSIGNMENT_ENV_VAR,
            topology_assignments_env,
            e,
        )
        return None

    # Valid JSON is not necessarily a valid assignment: json.loads may yield a dict,
    # string, number or flat list. Only a list of lists of strings honors the return type.
    is_well_formed = isinstance(topology_assignments, list) and all(
        isinstance(replica, list) and all(isinstance(item, str) for item in replica)
        for replica in topology_assignments
    )
    if not is_well_formed:
        logging.warning(
            "Ignoring topology assignments from env var %s, expected a list of lists of "
            "strings, value: %s",
            BASTION_JOB_TOPOLOGY_ASSIGNMENT_ENV_VAR,
            topology_assignments_env,
        )
        return None

    return topology_assignments
211212 def _lookup_system_by_node_selectors (
212213 self , node_selector : dict [str , str ]
213214 ) -> Optional [tuple [str , _SystemCharacteristics ]]:
@@ -382,7 +383,7 @@ def _build_jobset(self) -> Nested[Any]:
382383 # Bastion passes the job metadata to the runner through env vars
383384 # If the job has topology assigned, it's also in the env var
384385 # Try to parse the env var and get the topology assignments.
385- topology_assignment = self . _get_topology_assignment ()
386+ topology_assignment = get_topology_assignment ()
386387 if cfg .enable_tpu_slice_auto_provisioning and topology_assignment :
387388 slice_selection_dict = self ._get_tpu_replicated_job_topology_selection (
388389 replicated_jobs , topology_assignment
@@ -563,6 +564,7 @@ class Config(GCPJob.Config):
563564 gke_gateway_route : bool = False
564565 http_route : Optional [LWSHTTPRoute .Config ] = None
565566 health_check_policy : Optional [LWSHealthCheckPolicy .Config ] = None
567+ enable_tpu_slice_auto_provisioning : Optional [bool ] = None
566568
567569 @classmethod
568570 def set_defaults (cls , fv ):
@@ -635,6 +637,12 @@ def define_flags(cls, fv: flags.FlagValues):
635637 "Enable gke_gateway_route with notary-proxy sidecars for direct gateway routing" ,
636638 ** common_kwargs ,
637639 )
640+ flags .DEFINE_boolean (
641+ "enable_tpu_slice_auto_provisioning" ,
642+ None ,
643+ "Auto provision TPU slices based on the topology assignment." ,
644+ ** common_kwargs ,
645+ )
638646
639647 @classmethod
640648 def from_flags (cls , fv : flags .FlagValues , ** kwargs ):
@@ -653,6 +661,15 @@ def __init__(self, cfg: Config, *, bundler: BaseDockerBundler):
653661 super ().__init__ (cfg )
654662 cfg : GKELeaderWorkerSet .Config = self .config
655663 self ._bundler = bundler
664+
665+ # Pass enable_tpu_slice_auto_provisioning from GKEJob to the builder
666+ builder_cfg = cfg .builder
667+ if (
668+ hasattr (builder_cfg , "enable_tpu_slice_auto_provisioning" )
669+ and cfg .enable_tpu_slice_auto_provisioning is not None
670+ ):
671+ builder_cfg .enable_tpu_slice_auto_provisioning = cfg .enable_tpu_slice_auto_provisioning
672+
656673 # This instantiates a builder for constructing replicated job specs, which will be managed
657674 # together under the leaderworkerset represented by this class.
658675 # Note the distinction from bundlers, which are responsible for bundling any code assets
@@ -683,9 +700,42 @@ def _build_leaderworkerset(self) -> Nested[Any]:
683700 """
684701 cfg : GKELeaderWorkerSet .Config = self .config
685702 annotations = maybe_instantiate (cfg .annotations or {})
703+ labels = {}
704+
705+ # If the topology is set and slice auto provisioning is configured
706+ # set the necessary annotations
707+ topology_assignment = get_topology_assignment ()
708+ if cfg .enable_tpu_slice_auto_provisioning and topology_assignment :
709+ # Add TPU slice selection
710+ logging .info ("Adding slice selection: %s to leader worker set" , topology_assignment )
711+
712+ # Note, we use async here rather than the jobset sync. Async will immediately create
713+ # the pods before the slice has been created. Once sync is supported for leader worker
714+ # set we should consider switching.
715+ labels ["tpu-provisioner.cloud.google.com/slice-autoprovisioning" ] = "async"
716+
717+ # For Leader worker sets, we only support topology assignments to workers.
718+ # The format of the topology assignments (list of subblock groups) is what
719+ # is expected by the TPU provisioner.
720+ annotations .update (
721+ {
722+ "tpu-provisioner.cloud.google.com/slice-selection" : json .dumps (
723+ {
724+ "workers" : topology_assignment ,
725+ }
726+ )
727+ }
728+ )
729+
730+ # Remove exclusive topology annotation, the tpu provisioner will ensure replica
731+ # affinity by injecting slice based node selectors, so we don't need to use
732+ # the exclusive topology annotations
733+ exclusive_topology_annotation = exclusive_topology_annotations_leaderworkerset ()
734+ for key in exclusive_topology_annotation :
735+ annotations .pop (key , None )
686736
687737 return dict (
688- metadata = dict (name = cfg .name , annotations = annotations ),
738+ metadata = dict (name = cfg .name , annotations = annotations , labels = labels ),
689739 spec = dict (
690740 replicas = cfg .num_replicas ,
691741 leaderWorkerTemplate = self ._builder (),
0 commit comments