Skip to content

Commit 757fa35

Browse files
Add Jobset controller patching for MTC cluster (#475)
1 parent 4b52b16 commit 757fa35

File tree

2 files changed

+123
-4
lines changed

2 files changed

+123
-4
lines changed

pytype-conf.cfg

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ exclude =
1717
src/xpk/core/kueue.py
1818
src/xpk/core/nap.py
1919
src/xpk/core/storage.py
20+
src/xpk/core/mtc.py
2021
src/xpk/core/pathways.py
2122
src/xpk/core/system_characteristics.py
2223
src/xpk/parser

src/xpk/core/mtc.py

Lines changed: 122 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,16 @@
1414
limitations under the License.
1515
"""
1616

17-
from ..utils.console import xpk_exit, xpk_print
17+
import requests
18+
import yaml
19+
20+
from ..core.cluster import JOBSET_VERSION
21+
from ..core.cluster import setup_k8s_env
1822
from ..utils import templates
23+
from ..utils.console import xpk_exit
24+
from ..utils.console import xpk_print
1925
from ..utils.kubectl import apply_kubectl_manifest
20-
from ..core.cluster import setup_k8s_env
26+
2127

2228
MTC_CPC_PATH = "/../templates/mtc-cpc.yaml"
2329

@@ -28,6 +34,17 @@ def create_mtc_cpc(
2834
mtc_toleration_key: str,
2935
mtc_ramdisk_size: str,
3036
) -> dict:
37+
"""Create MTC Checkpoint Configuration.
38+
39+
Args:
40+
mtc_gcs_bucket: GCS bucket for MTC
41+
mtc_machine_type: Machine type for MTC
42+
mtc_toleration_key: Toleration key for MTC
43+
mtc_ramdisk_size: Ramdisk size for MTC
44+
45+
Returns:
46+
MTC Checkpoint Configuration
47+
"""
3148
data = templates.load(MTC_CPC_PATH)
3249

3350
data["spec"]["cloudStorageBucketName"] = mtc_gcs_bucket
@@ -41,10 +58,11 @@ def create_mtc_cpc(
4158

4259

4360
def install_mtc_on_cluster(args, system) -> int:
44-
"""Install MTC on the cluster
61+
"""Install MTC on the cluster.
4562
4663
Args:
4764
args: user provided arguments for running the command.
65+
system: system related information.
4866
4967
Returns:
5068
return code of the command.
@@ -62,16 +80,116 @@ def install_mtc_on_cluster(args, system) -> int:
6280
if args.mtc_toleration_key is None:
6381
args.mtc_toleration_key = "google.com/tpu"
6482

83+
k8s_api_client = setup_k8s_env(args)
84+
jobset_manifest = update_jobset_manifest()
85+
if jobset_manifest is None:
86+
xpk_print(
87+
"Updated jobset manifest is empty, not updating the jobset controller."
88+
)
89+
90+
xpk_print("Applying Jobset with MTC Configuration")
91+
return_code = apply_kubectl_manifest(k8s_api_client, [jobset_manifest])
92+
if return_code != 0:
93+
return return_code
94+
6595
mtc_checkpoint_configuration_crd_data = create_mtc_cpc(
6696
args.mtc_gcs_bucket,
6797
system.gce_machine_type,
6898
args.mtc_toleration_key,
6999
args.mtc_ramdisk_size,
70100
)
71101
xpk_print("Applying MTC Checkpoint Configuration")
72-
k8s_api_client = setup_k8s_env(args)
73102
return_code = apply_kubectl_manifest(
74103
k8s_api_client, [mtc_checkpoint_configuration_crd_data]
75104
)
76105

77106
return return_code
107+
108+
109+
def update_jobset_manifest():
  """Fetch the JobSet release manifest and raise the controller's resources.

  Downloads the manifests.yaml published for JOBSET_VERSION, looks for the
  apps/v1 Deployment named "jobset-controller-manager" and, for its "manager"
  container, raises the CPU request, the memory request and the memory limit
  to the minimums defined in _raise_manager_resources. Values that already
  meet the minimum are left untouched.

  A download failure is reported through xpk_print and followed by
  xpk_exit(1).

  Returns:
    The modified Deployment document (a dict) when at least one value was
    raised, otherwise None.
  """
  manifest_url = (
      "https://github.com/kubernetes-sigs/jobset/releases/download/"
      f"{JOBSET_VERSION}/manifests.yaml"
  )
  manifest_content = None
  # Fetch the manifest content.
  try:
    response = requests.get(manifest_url, timeout=10)
    response.raise_for_status()  # Raise an exception for HTTP errors
    manifest_content = response.text
  except requests.exceptions.Timeout as e:
    xpk_print(
        f"Error: Request to {manifest_url} timed out after 10 seconds: {e}"
    )
    xpk_exit(1)
  except requests.exceptions.RequestException as e:
    xpk_print(f"Error fetching manifest from {manifest_url}: {e}")
    xpk_exit(1)

  # Defensive guard in case xpk_exit returns instead of terminating.
  if manifest_content is None:
    xpk_print("Manifest content not found.")
    xpk_exit(1)

  # The release manifest is a multi-document YAML stream; only the controller
  # Deployment is of interest.
  for document in yaml.safe_load_all(manifest_content):
    if not _is_jobset_controller_deployment(document):
      continue
    containers = document["spec"]["template"]["spec"]["containers"]
    manager = next(
        (c for c in containers if c.get("name") == "manager"), None
    )
    if manager is not None and _raise_manager_resources(manager):
      xpk_print("Jobset controller updation required.")
      return document

  xpk_print("Jobset controller no updation required.")
  return None


def _is_jobset_controller_deployment(document) -> bool:
  """Return True if document is the jobset-controller-manager Deployment."""
  # Empty documents load as None and other non-mapping documents are possible.
  if not isinstance(document, dict):
    return False
  metadata = document.get("metadata") or {}
  return (
      document.get("apiVersion") == "apps/v1"
      and document.get("kind") == "Deployment"
      and metadata.get("name") == "jobset-controller-manager"
  )


def _cpu_millicores(value) -> int:
  """Convert a CPU quantity ("500m", "2", 0.5) to millicores."""
  text = str(value).strip()
  if text.endswith("m"):
    return parse_resource_value(text)
  # A CPU quantity without the "m" suffix is expressed in whole cores.
  return int(float(text) * 1000)


def _raise_manager_resources(container) -> bool:
  """Raise the container's resources in place; return True if any changed."""
  # (section, resource, minimum, value assumed when the entry is absent)
  minimums = (
      ("requests", "cpu", "1000m", "0m"),
      ("requests", "memory", "1Gi", "0Mi"),
      ("limits", "memory", "2Gi", "0Mi"),
  )
  resources = container.get("resources") or {}
  changed = False
  for section, resource, minimum, missing_default in minimums:
    values = resources.get(section) or {}
    current = values.get(resource)
    if current is None:
      current = missing_default
    if resource == "cpu":
      below_minimum = _cpu_millicores(current) < _cpu_millicores(minimum)
    else:
      # str() because YAML may load a bare number as int/float.
      below_minimum = parse_resource_value(
          str(current).strip()
      ) < parse_resource_value(minimum)
    if below_minimum:
      values[resource] = minimum
      # Attach the (possibly newly created) mappings so the write is visible
      # even when "resources" or the section was missing from the manifest.
      resources[section] = values
      container["resources"] = resources
      changed = True
  return changed
186+
187+
188+
def parse_resource_value(value) -> int:
  """Convert a Kubernetes resource quantity into a comparable integer.

  The unit of the result depends on the suffix, so results are only
  meaningful when compared against values of the same resource kind:

    * "<n>m"  -> n (CPU millicores)
    * "<n>Ki" -> n / 1024 (memory, in Mi)
    * "<n>Mi" -> n (memory, in Mi)
    * "<n>Gi" -> n * 1024 (memory, in Mi)
    * "<n>Ti" -> n * 1024 * 1024 (memory, in Mi)
    * "<n>"   -> n, returned unscaled

  Args:
    value: the quantity, as a string ("500m", "1.5Gi") or as a bare number
      (YAML loads unquoted numeric quantities as int/float).

  Returns:
    The quantity as an int, truncated toward zero.

  Raises:
    ValueError: if the numeric part cannot be parsed.
  """
  scales = (
      ("m", 1),
      ("Ki", 1 / 1024),
      ("Mi", 1),
      ("Gi", 1024),
      ("Ti", 1024 * 1024),
  )
  # Accept non-string scalars and tolerate surrounding whitespace.
  text = str(value).strip()
  for suffix, scale in scales:
    if text.endswith(suffix):
      return int(_quantity_number(text[: -len(suffix)]) * scale)
  return int(_quantity_number(text))


def _quantity_number(text):
  """Parse the numeric part of a quantity, keeping integers exact."""
  try:
    return int(text)
  except ValueError:
    # Fractional quantities such as "1.5" are valid; float() still raises
    # ValueError for genuinely malformed input.
    return float(text)

0 commit comments

Comments
 (0)