Skip to content

Commit ea5ff77

Browse files
authored
Set network names based on cluster configuration when running workloads on A3 Mega, A3 Ultra and A4 (#457)
* automatically detect cluster network names when running workloads
* fix imports after merge
1 parent 3fa16b4 commit ea5ff77

File tree

7 files changed

+55
-69
lines changed

7 files changed

+55
-69
lines changed

src/xpk/commands/kjob_common.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -35,7 +35,7 @@ def add_gpu_networking_annotations_to_command(args, cmd: str) -> str:
3535
elif gpu_type == H200_DEVICE_TYPE:
3636
annotations = get_a3ultra_pod_template_annotations(args)
3737
elif gpu_type == B200_DEVICE_TYPE:
38-
annotations = get_a4_pod_template_annotations()
38+
annotations = get_a4_pod_template_annotations(args)
3939
else:
4040
annotations = []
4141

src/xpk/commands/workload.py

Lines changed: 6 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -14,11 +14,6 @@
1414
limitations under the License.
1515
"""
1616

17-
from ..core.blueprint.blueprint_generator import (
18-
get_subnetworks_for_a3mega,
19-
get_subnetworks_for_a3ultra,
20-
get_subnetworks_for_a4,
21-
)
2217
from ..core.cluster import (
2318
XPK_SA,
2419
create_xpk_k8s_service_account,
@@ -43,6 +38,7 @@
4338
get_autoprovisioning_node_selector_args,
4439
is_autoprovisioning_enabled,
4540
)
41+
from ..core.network import get_cluster_subnetworks
4642
from ..core.pathways import (
4743
append_custom_colocated_python_sidecar,
4844
append_custom_pathways_proxy_server,
@@ -460,16 +456,13 @@ def workload_create(args) -> None:
460456
pod_failure_policy=pod_failure_policy,
461457
)
462458

459+
sub_networks = get_cluster_subnetworks(args)
463460
if args.device_type == cluster_gcluster.a3mega_device_type:
464-
sub_networks = get_subnetworks_for_a3mega(args.cluster)
465461
yml_string = tcpxo_decorator.decorate_jobset(yml_string, sub_networks)
466-
467-
if args.device_type == cluster_gcluster.a3ultra_device_type:
468-
sub_networks = get_subnetworks_for_a3ultra(args.cluster)
469-
yml_string = rdma_decorator.decorate_jobset(yml_string, sub_networks)
470-
471-
if args.device_type == cluster_gcluster.a4_device_type:
472-
sub_networks = get_subnetworks_for_a4()
462+
elif args.device_type in [
463+
cluster_gcluster.a3ultra_device_type,
464+
cluster_gcluster.a4_device_type,
465+
]:
473466
yml_string = rdma_decorator.decorate_jobset(yml_string, sub_networks)
474467

475468
if all_storages:

src/xpk/core/blueprint/blueprint_generator.py

Lines changed: 0 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -52,20 +52,6 @@
5252
cluster_toolkit_version = "v1.48.0"
5353

5454

55-
def get_subnetworks_for_a3mega(cluster_name: str) -> list[str]:
56-
return [f"{cluster_name}-gpunet-{i}-subnet" for i in range(8)]
57-
58-
59-
def get_subnetworks_for_a3ultra(cluster_name: str) -> list[str]:
60-
return [f"{cluster_name}-sub-1"] + [
61-
f"{cluster_name}-rdma-sub-{i}" for i in range(8)
62-
]
63-
64-
65-
def get_subnetworks_for_a4() -> list[str]:
66-
return ["gvnic-1"] + [f"rdma-{i}" for i in range(8)]
67-
68-
6955
class BlueprintGeneratorOutput:
7056
"""BlueprintGeneratorOutput is a class containing fields with output blueprint file path and path to blueprint dependencies.
7157
Atributes:

src/xpk/core/kjob.py

Lines changed: 13 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -22,24 +22,9 @@
2222
from kubernetes.client import ApiClient
2323
from kubernetes.client.rest import ApiException
2424

25-
from ..core.blueprint.blueprint_generator import (
26-
get_subnetworks_for_a3mega,
27-
get_subnetworks_for_a3ultra,
28-
get_subnetworks_for_a4,
29-
)
30-
from ..core.capacity import (
31-
H100_DEVICE_TYPE,
32-
H100_MEGA_DEVICE_TYPE,
33-
H200_DEVICE_TYPE,
34-
)
35-
from ..core.storage import GCS_FUSE_ANNOTATIONS, PARALLELSTORE_ANNOTATIONS
36-
from ..core.workload_decorators import (
37-
rdma_decorator,
38-
tcpx_decorator,
39-
tcpxo_decorator,
40-
)
4125
from ..utils import templates
4226
from ..utils.console import xpk_exit, xpk_print
27+
from .capacity import H100_DEVICE_TYPE, H100_MEGA_DEVICE_TYPE, H200_DEVICE_TYPE
4328
from .cluster import DEFAULT_NAMESPACE, XPK_SA, setup_k8s_env
4429
from .commands import (
4530
run_command_for_value,
@@ -54,16 +39,24 @@
5439
KJOB_SHELL_WORKING_DIRECTORY,
5540
XpkConfig,
5641
)
42+
from .network import get_cluster_subnetworks
5743
from .resources import (
5844
AcceleratorType,
5945
SystemCharacteristics,
6046
get_cluster_system_characteristics,
6147
)
6248
from .storage import (
49+
GCS_FUSE_ANNOTATIONS,
50+
PARALLELSTORE_ANNOTATIONS,
6351
get_auto_mount_gcsfuse_storages,
6452
get_auto_mount_parallelstore_storages,
6553
get_auto_mount_storages,
6654
)
55+
from .workload_decorators import (
56+
rdma_decorator,
57+
tcpx_decorator,
58+
tcpxo_decorator,
59+
)
6760
from .workload_decorators.tcpxo_decorator import get_tcpxo_deamon_entry
6861

6962
KJOB_API_GROUP_NAME = "kjobctl.x-k8s.io"
@@ -176,8 +169,8 @@ class PodTemplateDefaults(Enum):
176169
default_interface_annotation = "networking.gke.io/default-interface=eth0"
177170

178171

179-
def get_a4_pod_template_annotations() -> tuple[str, str]:
180-
sub_networks = get_subnetworks_for_a4()
172+
def get_a4_pod_template_annotations(args) -> tuple[str, str]:
173+
sub_networks = get_cluster_subnetworks(args)
181174
interfaces_key, interfaces_value = rdma_decorator.get_interfaces_entry(
182175
sub_networks
183176
)
@@ -189,7 +182,7 @@ def get_a4_pod_template_annotations() -> tuple[str, str]:
189182

190183

191184
def get_a3ultra_pod_template_annotations(args: Namespace) -> tuple[str, str]:
192-
sub_networks = get_subnetworks_for_a3ultra(args.cluster)
185+
sub_networks = get_cluster_subnetworks(args)
193186
interfaces_key, interfaces_value = rdma_decorator.get_interfaces_entry(
194187
sub_networks
195188
)
@@ -204,7 +197,7 @@ def get_a3mega_pod_template_annotations(
204197
args: Namespace,
205198
) -> tuple[str, str, str]:
206199
"""Adds or updates annotations in the Pod template."""
207-
sub_networks = get_subnetworks_for_a3mega(args.cluster)
200+
sub_networks = get_cluster_subnetworks(args)
208201
tcpxo_deamon_key, tcpxo_deamon_paths = get_tcpxo_deamon_entry()
209202
interfaces_key, interfaces_value = tcpxo_decorator.get_interfaces_entry(
210203
sub_networks

src/xpk/core/network.py

Lines changed: 23 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,7 @@
1414
limitations under the License.
1515
"""
1616

17-
from ..utils.console import xpk_print
17+
from ..utils.console import xpk_exit, xpk_print
1818
from ..utils.file import write_tmp_file
1919
from .commands import run_command_for_value, run_command_with_updates
2020
from .gcloud_context import zone_to_region
@@ -235,6 +235,28 @@ def create_cluster_network_config(args) -> int:
235235
return 0
236236

237237

238+
def get_cluster_subnetworks(args) -> list[str]:
239+
"""Gets the list of cluster networks.
240+
241+
Args:
242+
args: user provided arguments for running the command.
243+
244+
Returns:
245+
list[str]: list of cluster networks
246+
"""
247+
command = 'kubectl get GKENetworkParamSet'
248+
return_code, stdout = run_command_for_value(
249+
command, 'Get Cluster Networks', args
250+
)
251+
if return_code != 0:
252+
xpk_print('GKE Cluster Get NetworkParamSet failed')
253+
xpk_exit(return_code)
254+
255+
networks = [line.split()[0] for line in stdout.splitlines()][1:]
256+
257+
return networks
258+
259+
238260
def set_up_cluster_network_for_a3(args) -> int:
239261
"""Set up GKE Cluster networks, subnets and firewall rules for A3.
240262
Note: there are 4 NICs for GPU-GPU bw and 1 NIC for host in an A3 node.

src/xpk/core/workload_decorators/rdma_decorator.py

Lines changed: 6 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -68,16 +68,12 @@ def decorate_jobset(jobset_manifest_str: str, sub_networks: list[str]) -> str:
6868

6969

7070
def get_interfaces_entry(sub_networks: list[str]) -> tuple[str, str]:
71-
interfaces = [
72-
'[',
73-
' {"interfaceName":"eth0","network":"default"},',
74-
*[
75-
f' {{"interfaceName":"eth{i + 1}","network":"{sub_networks[i]}"}}{"," if i<8 else ""}'
76-
for i in range(9)
77-
],
78-
']',
79-
]
80-
return 'networking.gke.io/interfaces', literal_string('\n'.join(interfaces))
71+
entries = ',\n'.join([
72+
f' {{"interfaceName":"eth{i}","network":"{network}"}}'
73+
for i, network in enumerate(sub_networks)
74+
])
75+
interfaces = f'[\n{entries}\n]'
76+
return 'networking.gke.io/interfaces', literal_string(interfaces)
8177

8278

8379
def add_annotations(job_manifest: dict, sub_networks: list[str]):

src/xpk/core/workload_decorators/tcpxo_decorator.py

Lines changed: 6 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -77,16 +77,12 @@ def decorate_jobset(jobset_manifest_str: str, sub_networks: list[str]) -> str:
7777

7878

7979
def get_interfaces_entry(sub_networks: list[str]) -> tuple[str, str]:
80-
interfaces = [
81-
'[',
82-
' {"interfaceName":"eth0","network":"default"},',
83-
*[
84-
f' {{"interfaceName":"eth{i + 1}","network":"{sub_networks[i]}"}}{"," if i<7 else ""}'
85-
for i in range(8)
86-
],
87-
']',
88-
]
89-
return 'networking.gke.io/interfaces', literal_string('\n'.join(interfaces))
80+
entries = ',\n'.join([
81+
f' {{"interfaceName":"eth{i}","network":"{network}"}}'
82+
for i, network in enumerate(sub_networks)
83+
])
84+
interfaces = f'[\n{entries}\n]'
85+
return 'networking.gke.io/interfaces', literal_string(interfaces)
9086

9187

9288
def get_tcpxo_deamon_entry() -> tuple[str, str]:

0 commit comments

Comments
 (0)