Skip to content

Commit 50d3e95

Browse files
committed
improvements to rayjobs
1 parent 3e692cc commit 50d3e95

File tree

5 files changed

+1141
-201
lines changed

5 files changed

+1141
-201
lines changed
Lines changed: 392 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,392 @@
1+
# Copyright 2025 IBM, Red Hat
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""
16+
RayCluster spec builder specifically for RayJobs.
17+
18+
This module builds the rayClusterSpec portion of a RayJob CR, which defines
19+
how the Ray cluster should be created when the job runs.
20+
"""
21+
22+
import copy
23+
import logging
24+
from typing import Dict, Any, Union
25+
26+
from kubernetes.client import (
27+
V1ObjectMeta,
28+
V1Container,
29+
V1ContainerPort,
30+
V1Lifecycle,
31+
V1ExecAction,
32+
V1LifecycleHandler,
33+
V1EnvVar,
34+
V1PodTemplateSpec,
35+
V1PodSpec,
36+
V1ResourceRequirements,
37+
V1Volume,
38+
V1VolumeMount,
39+
V1ConfigMapVolumeSource,
40+
V1KeyToPath,
41+
)
42+
43+
from codeflare_sdk.ray.rayjobs.config import RayJobClusterConfig
44+
45+
from ...common.utils.constants import CUDA_RUNTIME_IMAGE
46+
47+
logger = logging.getLogger(__name__)

# Default volume mounts for CA certificates.
# Each bundle file is mounted (via sub_path, i.e. as a single file) at both
# /etc/pki/tls/certs and /etc/ssl/certs. The ``name`` of every mount must match
# a volume name declared in the pod spec.
DEFAULT_VOLUME_MOUNTS = [
    # odh-trusted-ca-bundle.crt, from the "odh-trusted-ca-cert" volume
    V1VolumeMount(
        name="odh-trusted-ca-cert",
        sub_path="odh-trusted-ca-bundle.crt",
        mount_path="/etc/pki/tls/certs/odh-trusted-ca-bundle.crt",
    ),
    V1VolumeMount(
        name="odh-trusted-ca-cert",
        sub_path="odh-trusted-ca-bundle.crt",
        mount_path="/etc/ssl/certs/odh-trusted-ca-bundle.crt",
    ),
    # odh-ca-bundle.crt, from the "odh-ca-cert" volume
    V1VolumeMount(
        name="odh-ca-cert",
        sub_path="odh-ca-bundle.crt",
        mount_path="/etc/pki/tls/certs/odh-ca-bundle.crt",
    ),
    V1VolumeMount(
        name="odh-ca-cert",
        sub_path="odh-ca-bundle.crt",
        mount_path="/etc/ssl/certs/odh-ca-bundle.crt",
    ),
]
72+
73+
# Default volumes for CA certificates.
# Both volumes project a single key out of the same "odh-trusted-ca-bundle"
# ConfigMap and are flagged optional=True.
DEFAULT_VOLUMES = [
    # Key "ca-bundle.crt" exposed as the file odh-trusted-ca-bundle.crt
    V1Volume(
        name="odh-trusted-ca-cert",
        config_map=V1ConfigMapVolumeSource(
            name="odh-trusted-ca-bundle",
            optional=True,
            items=[V1KeyToPath(path="odh-trusted-ca-bundle.crt", key="ca-bundle.crt")],
        ),
    ),
    # Key "odh-ca-bundle.crt" exposed under the same file name
    V1Volume(
        name="odh-ca-cert",
        config_map=V1ConfigMapVolumeSource(
            name="odh-trusted-ca-bundle",
            optional=True,
            items=[V1KeyToPath(path="odh-ca-bundle.crt", key="odh-ca-bundle.crt")],
        ),
    ),
]
92+
93+
94+
def build_ray_cluster_spec(
    cluster_config: RayJobClusterConfig,
) -> Dict[str, Any]:
    """
    Build the RayCluster spec from RayJobClusterConfig for embedding in RayJob.

    The caller's config object is not modified: a shallow copy is taken and
    the ``appwrapper`` / ``write_to_file`` flags are forced to False on that
    copy before the head and worker group specs are generated from it.

    Args:
        cluster_config: The cluster configuration object (RayJobClusterConfig)

    Returns:
        Dict containing the RayCluster spec for embedding in RayJob. It always
        holds ``rayVersion``, ``enableInTreeAutoscaling`` (False),
        ``headGroupSpec`` and a one-element ``workerGroupSpecs`` list;
        ``gcsFaultToleranceOptions`` is added only when
        ``cluster_config.enable_gcs_ft`` is truthy.
    """
    # ``rayVersion`` has to be the Ray release string. The container image
    # reference (CUDA_RUNTIME_IMAGE) belongs on the head/worker containers,
    # where the container builders already place it.
    from ...common.utils.constants import RAY_VERSION

    # Create a copy to avoid modifying the original
    temp_config = copy.copy(cluster_config)

    temp_config.appwrapper = False
    temp_config.write_to_file = False

    ray_cluster_spec = {
        "rayVersion": RAY_VERSION,
        "enableInTreeAutoscaling": False,
        "headGroupSpec": _build_head_group_spec(temp_config),
        "workerGroupSpecs": [_build_worker_group_spec(temp_config)],
    }

    # Add GCS fault tolerance if enabled
    if temp_config.enable_gcs_ft:
        ray_cluster_spec["gcsFaultToleranceOptions"] = _build_gcs_ft_options(
            temp_config
        )

    # Lazy %-formatting: the message is only rendered if INFO is enabled.
    logger.info("Built RayCluster spec for cluster: %s", cluster_config.name)
    return ray_cluster_spec
126+
127+
128+
def _build_head_group_spec(cluster_config: RayJobClusterConfig) -> Dict[str, Any]:
    """Assemble the ``headGroupSpec`` entry of the RayCluster spec.

    Args:
        cluster_config: Cluster configuration the head group is derived from.

    Returns:
        Dict with a fixed ``ClusterIP`` service type, ingress disabled, the
        head ``rayStartParams`` and a ``V1PodTemplateSpec`` under ``template``.
    """
    start_params = _build_head_ray_params(cluster_config)

    # A config without an ``annotations`` attribute yields un-annotated pods.
    pod_annotations = getattr(cluster_config, "annotations", None)
    head_container = _build_head_container(cluster_config)
    pod_template = V1PodTemplateSpec(
        metadata=V1ObjectMeta(annotations=pod_annotations),
        spec=_build_pod_spec(cluster_config, head_container, is_head=True),
    )

    return {
        "serviceType": "ClusterIP",
        "enableIngress": False,
        "rayStartParams": start_params,
        "template": pod_template,
    }
147+
148+
149+
def _build_worker_group_spec(cluster_config: RayJobClusterConfig) -> Dict[str, Any]:
    """Assemble the single worker group entry of the RayCluster spec.

    Args:
        cluster_config: Cluster configuration the worker group is derived from.

    Returns:
        Dict whose ``replicas``, ``minReplicas`` and ``maxReplicas`` are all
        ``cluster_config.num_workers`` (a fixed-size group), plus the group
        name, worker ``rayStartParams`` and a ``V1PodTemplateSpec`` under
        ``template``.
    """
    worker_count = cluster_config.num_workers
    group_name = f"worker-group-{cluster_config.name}"
    start_params = _build_worker_ray_params(cluster_config)

    # A config without an ``annotations`` attribute yields un-annotated pods.
    pod_annotations = getattr(cluster_config, "annotations", None)
    worker_container = _build_worker_container(cluster_config)
    pod_template = V1PodTemplateSpec(
        metadata=V1ObjectMeta(annotations=pod_annotations),
        spec=_build_pod_spec(cluster_config, worker_container, is_head=False),
    )

    return {
        "replicas": worker_count,
        "minReplicas": worker_count,
        "maxReplicas": worker_count,
        "groupName": group_name,
        "rayStartParams": start_params,
        "template": pod_template,
    }
170+
171+
172+
def _build_head_ray_params(cluster_config: RayJobClusterConfig) -> Dict[str, str]:
    """Build Ray start parameters for head node.

    Args:
        cluster_config: Cluster configuration. Only the optional
            ``head_accelerators`` mapping (resource name -> count) is read.

    Returns:
        Dict with the dashboard host/port and ``block`` settings. ``num-gpus``
        is added when the counts of all accelerators whose resource name
        contains "gpu" (case-insensitive) sum to more than zero.

    Raises:
        ValueError: If a GPU count is a string that is not an integer.
    """
    params = {
        "dashboard-host": "0.0.0.0",
        "dashboard-port": "8265",
        "block": "true",
    }

    # Add GPU count if specified
    accelerators = getattr(cluster_config, "head_accelerators", None)
    if accelerators:
        # Counts may be given as ints or as strings (e.g. "1"); coerce so
        # that string quantities are summed instead of raising TypeError.
        gpu_count = sum(
            int(count)
            for resource_type, count in accelerators.items()
            if "gpu" in resource_type.lower()
        )
        if gpu_count > 0:
            params["num-gpus"] = str(gpu_count)

    return params
194+
195+
196+
def _build_worker_ray_params(cluster_config: RayJobClusterConfig) -> Dict[str, str]:
    """Build Ray start parameters for worker nodes.

    Args:
        cluster_config: Cluster configuration. Only the optional
            ``worker_accelerators`` mapping (resource name -> count) is read.

    Returns:
        Dict with the ``block`` setting. ``num-gpus`` is added when the counts
        of all accelerators whose resource name contains "gpu"
        (case-insensitive) sum to more than zero.

    Raises:
        ValueError: If a GPU count is a string that is not an integer.
    """
    params = {
        "block": "true",
    }

    # Add GPU count if specified
    accelerators = getattr(cluster_config, "worker_accelerators", None)
    if accelerators:
        # Counts may be given as ints or as strings (e.g. "1"); coerce so
        # that string quantities are summed instead of raising TypeError.
        gpu_count = sum(
            int(count)
            for resource_type, count in accelerators.items()
            if "gpu" in resource_type.lower()
        )
        if gpu_count > 0:
            params["num-gpus"] = str(gpu_count)

    return params
216+
217+
218+
def _build_head_container(cluster_config: RayJobClusterConfig) -> V1Container:
    """Create the ``ray-head`` container for the head pod.

    Args:
        cluster_config: Cluster configuration supplying the image, the
            ``head_*`` resource settings and, optionally, ``envs``.

    Returns:
        A V1Container exposing the gcs (6379), dashboard (8265) and client
        (10001) ports, with a pre-stop hook that runs ``ray stop``. ``env`` is
        set only when the config carries a non-empty ``envs`` mapping.
    """
    head_ports = [
        V1ContainerPort(name=port_name, container_port=port_number)
        for port_name, port_number in (
            ("gcs", 6379),
            ("dashboard", 8265),
            ("client", 10001),
        )
    ]
    stop_ray_hook = V1Lifecycle(
        pre_stop=V1LifecycleHandler(
            _exec=V1ExecAction(command=["/bin/sh", "-c", "ray stop"])
        )
    )
    head_resources = _build_resource_requirements(
        cluster_config.head_cpu_requests,
        cluster_config.head_cpu_limits,
        cluster_config.head_memory_requests,
        cluster_config.head_memory_limits,
        cluster_config.head_accelerators,
    )

    head = V1Container(
        name="ray-head",
        # Fall back to the default runtime image when none is configured.
        image=cluster_config.image or CUDA_RUNTIME_IMAGE,
        image_pull_policy="IfNotPresent",  # Always IfNotPresent for RayJobs
        ports=head_ports,
        lifecycle=stop_ray_hook,
        resources=head_resources,
        volume_mounts=_generate_volume_mounts(cluster_config),
    )

    # Environment variables are optional; leave ``env`` unset when absent.
    env_map = getattr(cluster_config, "envs", None)
    if env_map:
        head.env = _build_env_vars(env_map)

    return head
249+
250+
251+
def _build_worker_container(cluster_config: RayJobClusterConfig) -> V1Container:
    """Create the ``ray-worker`` container for worker pods.

    Args:
        cluster_config: Cluster configuration supplying the image, the
            ``worker_*`` resource settings and, optionally, ``envs``.

    Returns:
        A V1Container with a pre-stop hook that runs ``ray stop`` and no
        declared ports. ``env`` is set only when the config carries a
        non-empty ``envs`` mapping.
    """
    stop_ray_hook = V1Lifecycle(
        pre_stop=V1LifecycleHandler(
            _exec=V1ExecAction(command=["/bin/sh", "-c", "ray stop"])
        )
    )
    worker_resources = _build_resource_requirements(
        cluster_config.worker_cpu_requests,
        cluster_config.worker_cpu_limits,
        cluster_config.worker_memory_requests,
        cluster_config.worker_memory_limits,
        cluster_config.worker_accelerators,
    )

    worker = V1Container(
        name="ray-worker",
        # Fall back to the default runtime image when none is configured.
        image=cluster_config.image or CUDA_RUNTIME_IMAGE,
        image_pull_policy="IfNotPresent",  # Always IfNotPresent for RayJobs
        lifecycle=stop_ray_hook,
        resources=worker_resources,
        volume_mounts=_generate_volume_mounts(cluster_config),
    )

    # Environment variables are optional; leave ``env`` unset when absent.
    env_map = getattr(cluster_config, "envs", None)
    if env_map:
        worker.env = _build_env_vars(env_map)

    return worker
277+
278+
279+
def _build_resource_requirements(
    cpu_requests: Union[int, str],
    cpu_limits: Union[int, str],
    memory_requests: Union[int, str],
    memory_limits: Union[int, str],
    extended_resource_requests: Dict[str, Union[int, str]] = None,
) -> V1ResourceRequirements:
    """Build Kubernetes resource requirements.

    Args:
        cpu_requests: CPU request for the container.
        cpu_limits: CPU limit for the container.
        memory_requests: Memory request for the container.
        memory_limits: Memory limit for the container.
        extended_resource_requests: Optional mapping of extended resource
            name (e.g. a GPU resource) to amount.

    Returns:
        V1ResourceRequirements whose requests and limits hold the cpu/memory
        values; every extended resource is written with the same amount into
        both requests and limits.
    """
    requested = {"cpu": cpu_requests, "memory": memory_requests}
    capped = {"cpu": cpu_limits, "memory": memory_limits}

    # Add extended resources (e.g., GPUs): request and limit are kept equal.
    if extended_resource_requests:
        capped.update(extended_resource_requests)
        requested.update(extended_resource_requests)

    return V1ResourceRequirements(requests=requested, limits=capped)
299+
300+
301+
def _build_pod_spec(
    cluster_config: RayJobClusterConfig, container: V1Container, is_head: bool
) -> V1PodSpec:
    """Wrap a single container into a pod spec for a head or worker pod.

    Args:
        cluster_config: Cluster configuration; optional ``head_tolerations``,
            ``worker_tolerations`` and ``image_pull_secrets`` are honoured.
        container: The one container the pod will run.
        is_head: Selects head tolerations when truthy, worker tolerations
            otherwise.

    Returns:
        V1PodSpec with restart policy "Never", the generated volumes and,
        when configured, tolerations and image pull secrets.
    """
    spec = V1PodSpec(
        containers=[container],
        volumes=_generate_volumes(cluster_config),
        restart_policy="Never",  # RayJobs should not restart
    )

    # Only the toleration list matching this pod's role is considered.
    toleration_attr = "head_tolerations" if is_head else "worker_tolerations"
    tolerations = getattr(cluster_config, toleration_attr, None)
    if tolerations:
        spec.tolerations = tolerations

    # Secret names are wrapped into object references, one per name.
    secret_names = getattr(cluster_config, "image_pull_secrets", None)
    if secret_names:
        from kubernetes.client import V1LocalObjectReference

        spec.image_pull_secrets = [
            V1LocalObjectReference(name=secret_name) for secret_name in secret_names
        ]

    return spec
338+
339+
340+
def _generate_volume_mounts(cluster_config: RayJobClusterConfig) -> list:
    """Return the container's volume mounts: defaults first, then custom ones.

    Args:
        cluster_config: Cluster configuration; its optional ``volume_mounts``
            are appended after DEFAULT_VOLUME_MOUNTS.

    Returns:
        A new list, so the module-level default list itself is never extended.
    """
    custom_mounts = getattr(cluster_config, "volume_mounts", None)
    if not custom_mounts:
        return list(DEFAULT_VOLUME_MOUNTS)

    return [*DEFAULT_VOLUME_MOUNTS, *custom_mounts]
349+
350+
351+
def _generate_volumes(cluster_config: RayJobClusterConfig) -> list:
    """Return the pod's volumes: defaults first, then custom ones.

    Args:
        cluster_config: Cluster configuration; its optional ``volumes`` are
            appended after DEFAULT_VOLUMES.

    Returns:
        A new list, so the module-level default list itself is never extended.
    """
    custom_volumes = getattr(cluster_config, "volumes", None)
    if not custom_volumes:
        return list(DEFAULT_VOLUMES)

    return [*DEFAULT_VOLUMES, *custom_volumes]
360+
361+
362+
def _build_env_vars(envs: Dict[str, str]) -> list:
    """Convert a name -> value mapping into a list of V1EnvVar objects.

    Args:
        envs: Environment variables keyed by variable name.

    Returns:
        One V1EnvVar per entry, in the mapping's iteration order.
    """
    return [
        V1EnvVar(name=var_name, value=var_value)
        for var_name, var_value in envs.items()
    ]
365+
366+
367+
def _build_gcs_ft_options(cluster_config: RayJobClusterConfig) -> Dict[str, Any]:
    """Build the ``gcsFaultToleranceOptions`` section of the cluster spec.

    Args:
        cluster_config: Cluster configuration. ``redis_address`` is required;
            ``external_storage_namespace`` and ``redis_password_secret`` (a
            mapping with "name" and "key" entries) are optional.

    Returns:
        Dict that always holds ``redisAddress``; ``externalStorageNamespace``
        and a secret-backed ``redisPassword`` are included only when the
        corresponding config value is present and truthy.
    """
    options: Dict[str, Any] = {"redisAddress": cluster_config.redis_address}

    storage_namespace = getattr(cluster_config, "external_storage_namespace", None)
    if storage_namespace:
        options["externalStorageNamespace"] = storage_namespace

    password_secret = getattr(cluster_config, "redis_password_secret", None)
    if password_secret:
        # The password is passed as a reference to a Secret key, not inline.
        secret_ref = {
            "name": password_secret["name"],
            "key": password_secret["key"],
        }
        options["redisPassword"] = {"valueFrom": {"secretKeyRef": secret_ref}}

    return options

0 commit comments

Comments
 (0)