Skip to content

Commit 1e41723

Browse files
committed
feat: workload policy label
1 parent e5cf2f3 commit 1e41723

File tree

4 files changed

+46
-15
lines changed

4 files changed

+46
-15
lines changed

src/xpk/commands/workload.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@
6363
get_cpu_affinity,
6464
get_gpu_scheduler,
6565
create_sub_slicing_annotations,
66+
create_placement_policy_label,
67+
is_placement_policy_supported,
6668
)
6769
from ..core.storage import (
6870
GCE_PD_TYPE,
@@ -143,6 +145,7 @@
143145
nodeSelector:
144146
{accelerator_label}
145147
{machine_label}
148+
{placement_policy_label}
146149
{autoprovisioning_args}
147150
priorityClassName: {args.priority}
148151
hostNetwork: true
@@ -272,6 +275,7 @@
272275
terminationGracePeriodSeconds: {args.termination_grace_period_seconds}
273276
priorityClassName: {args.priority}
274277
nodeSelector:
278+
{placement_policy_label}
275279
{autoprovisioning_args}
276280
pathwaysDir: {args.pathways_gcs_location} #This bucket needs to be created in advance.
277281
controller:
@@ -558,6 +562,11 @@ def workload_create(args) -> None:
558562
user_workload=get_user_workload_for_pathways(args, system),
559563
local_queue_name=LOCAL_QUEUE_NAME,
560564
autoprovisioning_args=autoprovisioning_args,
565+
placement_policy_label=(
566+
create_placement_policy_label(system)
567+
if is_placement_policy_supported(system)
568+
else ''
569+
),
561570
)
562571
else:
563572
container, debugging_dashboard_id = get_user_workload_container(
@@ -585,6 +594,11 @@ def workload_create(args) -> None:
585594
create_sub_slicing_annotations(args.sub_slicing_topology)
586595
)
587596
),
597+
placement_policy_label=(
598+
create_placement_policy_label(system)
599+
if is_placement_policy_supported(system)
600+
else ''
601+
),
588602
machine_label=create_machine_label(system.accelerator_type, system),
589603
local_queue_name=LOCAL_QUEUE_NAME,
590604
autoprovisioning_args=autoprovisioning_args,

src/xpk/core/nodepool.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616

1717
from typing import List
1818
from ..utils.console import ask_for_user_consent, xpk_print
19-
from ..utils.topology import get_topology_product, is_topology_valid
19+
from ..utils.topology import get_topology_product
20+
from .scheduling import get_placement_policy_name, is_placement_policy_supported
2021
from .capacity import (
2122
AUTOPROVISIONING_CONFIG_VALUE,
2223
H100_MEGA_DEVICE_TYPE,
@@ -258,10 +259,8 @@ def run_gke_node_pool_create_command(
258259
return 1
259260

260261
placement_args = ''
261-
if system.requires_workload_policy and is_topology_valid(system.topology):
262-
placement_policy = (
263-
f'{system.device_type}-{system.topology}-placement-policy'
264-
)
262+
if is_placement_policy_supported(system):
263+
placement_policy = get_placement_policy_name(system)
265264
ensure_resource_policy_exists(placement_policy, args, system.topology)
266265
placement_args = f' --placement-policy={placement_policy}'
267266

src/xpk/core/nodepool_test.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -146,21 +146,23 @@ def mock_nodepool_dependencies(mocker):
146146
)
147147
mocker.patch("xpk.core.nodepool.run_commands", return_value=0)
148148
mocker.patch("xpk.core.nodepool.ask_for_user_consent", return_value=True)
149-
mock_is_topology_valid = mocker.patch("xpk.core.nodepool.is_topology_valid")
149+
mock_is_placement_policy_supported = mocker.patch(
150+
"xpk.core.nodepool.is_placement_policy_supported"
151+
)
150152
mock_ensure_resource_policy = mocker.patch(
151153
"xpk.core.nodepool.ensure_resource_policy_exists"
152154
)
153-
return mock_is_topology_valid, mock_ensure_resource_policy
155+
return mock_is_placement_policy_supported, mock_ensure_resource_policy
154156

155157

156158
def test_placement_policy_created_for_gpu_with_valid_topology(
157159
mocker, mock_nodepool_dependencies
158160
):
159161
"""Tests that placement policy is created for GPUs with a valid topology."""
160-
mock_is_topology_valid, mock_ensure_resource_policy = (
162+
mock_is_placement_policy_supported, mock_ensure_resource_policy = (
161163
mock_nodepool_dependencies
162164
)
163-
mock_is_topology_valid.return_value = True
165+
mock_is_placement_policy_supported.return_value = True
164166
args = mocker.Mock(
165167
tpu_type=None,
166168
device_type="h100-80gb-8",
@@ -188,10 +190,10 @@ def test_placement_policy_not_created_for_gpu_with_invalid_topology(
188190
mocker, mock_nodepool_dependencies
189191
):
190192
"""Tests that placement policy is not created for GPUs with an invalid topology."""
191-
mock_is_topology_valid, mock_ensure_resource_policy = (
193+
mock_is_placement_policy_supported, mock_ensure_resource_policy = (
192194
mock_nodepool_dependencies
193195
)
194-
mock_is_topology_valid.return_value = False
196+
mock_is_placement_policy_supported.return_value = False
195197
args = mocker.Mock(
196198
tpu_type=None,
197199
device_type="h100-80gb-8",
@@ -218,10 +220,10 @@ def test_placement_policy_created_for_tpu7x_with_valid_topology(
218220
mocker, mock_nodepool_dependencies
219221
):
220222
"""Tests that placement policy is created for tpu7x with a valid topology."""
221-
mock_is_topology_valid, mock_ensure_resource_policy = (
223+
mock_is_placement_policy_supported, mock_ensure_resource_policy = (
222224
mock_nodepool_dependencies
223225
)
224-
mock_is_topology_valid.return_value = True
226+
mock_is_placement_policy_supported.return_value = True
225227
args = mocker.Mock(
226228
tpu_type="tpu7x-8",
227229
device_type=None,
@@ -251,10 +253,10 @@ def test_placement_policy_not_created_for_non7x_tpu(
251253
mocker, mock_nodepool_dependencies
252254
):
253255
"""Tests that placement policy is not created for non-tpu7x TPUs."""
254-
mock_is_topology_valid, mock_ensure_resource_policy = (
256+
mock_is_placement_policy_supported, mock_ensure_resource_policy = (
255257
mock_nodepool_dependencies
256258
)
257-
mock_is_topology_valid.return_value = True
259+
mock_is_placement_policy_supported.return_value = False
258260
args = mocker.Mock(
259261
tpu_type="v6e",
260262
device_type=None,

src/xpk/core/scheduling.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"""
1616

1717
from ..utils.console import xpk_print
18+
from ..utils.topology import is_topology_valid
1819
from ..utils.execution_context import is_dry_run
1920
from .capacity import AUTOPROVISIONING_CONFIG_MAXIMUM_KEY, AUTOPROVISIONING_CONFIG_VALUE
2021
from .resources import CLUSTER_RESOURCES_CONFIGMAP, get_cluster_configmap
@@ -303,3 +304,18 @@ def create_sub_slicing_annotations(sub_slicing_topology: str) -> list[str]:
303304
),
304305
f'cloud.google.com/gke-tpu-slice-topology: {sub_slicing_topology}',
305306
]
307+
308+
309+
def create_placement_policy_label(system: SystemCharacteristics) -> str:
310+
if system.accelerator_type != AcceleratorType.TPU:
311+
return ''
312+
name = get_placement_policy_name(system)
313+
return f'cloud.google.com/placement-policy-name: {name}'
314+
315+
316+
def get_placement_policy_name(system: SystemCharacteristics) -> str:
317+
return f'{system.device_type}-{system.topology}-placement-policy'
318+
319+
320+
def is_placement_policy_supported(system: SystemCharacteristics) -> bool:
321+
return system.requires_workload_policy and is_topology_valid(system.topology)

0 commit comments

Comments
 (0)