Skip to content

Commit d3803bf

Browse files
committed
feat(RHOAIENG-26482): add GCS fault tolerance
1 parent 6585567 commit d3803bf

File tree

5 files changed

+421
-34
lines changed

5 files changed

+421
-34
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
from .rayjob import RayJob, RayJobClusterConfig
22
from .status import RayJobDeploymentStatus, CodeflareRayJobStatus, RayJobInfo
3+
from .cluster_config import RayJobClusterConfig
Lines changed: 286 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,286 @@
1+
# Copyright 2025 IBM, Red Hat
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""
16+
RayJobClusterConfig provides a focused configuration class for creating Ray clusters
17+
as part of RayJob submissions. This class maps directly to the KubeRay RayClusterSpec
18+
structure and removes legacy fields that aren't relevant for RayJob-based cluster creation.
19+
"""
20+
21+
from dataclasses import dataclass, field
22+
from typing import Dict, Optional, Union
23+
24+
from codeflare_sdk.common.utils.constants import CUDA_RUNTIME_IMAGE
25+
26+
27+
@dataclass
class RayJobClusterConfig:
    """
    Configuration for creating a RayCluster as part of a RayJob submission.

    This class provides a clean, focused interface with sensible defaults that covers
    95% of use cases. For advanced configurations, users can still access the underlying
    Kubernetes objects.

    Args:
        num_workers: Number of worker nodes to create
        head_cpu: CPU requests/limits for head node (e.g., "2" or 2)
        head_memory: Memory requests/limits for head node (e.g., "8Gi" or 8)
        gpu: GPU resources for head node (e.g., {"nvidia.com/gpu": 1})
        worker_cpu: CPU requests/limits for each worker (e.g., "1" or 1)
        worker_memory: Memory requests/limits for each worker (e.g., "4Gi" or 4)
        worker_gpu: GPU resources for each worker (e.g., {"nvidia.com/gpu": 1})
        image: Container image for all nodes (defaults to CUDA-enabled Ray image)

        # Advanced options (optional)
        envs: Additional environment variables to set on all pods
        enable_gcs_ft: Whether to enable GCS fault tolerance
        redis_address: Redis address for GCS fault tolerance (required if enable_gcs_ft=True)

        # Lifecycle management
        shutdown_after_job_finishes: Whether to automatically cleanup the cluster after job completion
        ttl_seconds_after_finished: Seconds to wait before cleanup after job finishes
        active_deadline_seconds: Maximum time the job can run before being terminated

    Raises:
        ValueError: If ``enable_gcs_ft`` is set without ``redis_address``, if
            ``num_workers`` is negative, or if a GPU count is not a
            non-negative int or a string of digits.

    Note:
        The lifecycle fields are stored on the instance but are not read by
        ``to_dict()``; presumably they are consumed by the RayJob submission
        code -- verify against the caller.
    """

    num_workers: int = 1
    head_cpu: Union[int, str] = 2
    head_memory: Union[int, str] = 8
    gpu: Dict[str, Union[int, str]] = field(default_factory=dict)
    worker_cpu: Union[int, str] = 1
    worker_memory: Union[int, str] = 4
    worker_gpu: Dict[str, Union[int, str]] = field(default_factory=dict)
    image: str = CUDA_RUNTIME_IMAGE  # Use CUDA-enabled Ray image

    # Advanced options
    envs: Dict[str, str] = field(default_factory=dict)
    enable_gcs_ft: bool = False
    redis_address: Optional[str] = None

    # Lifecycle management
    shutdown_after_job_finishes: bool = True
    ttl_seconds_after_finished: int = 0
    active_deadline_seconds: Optional[int] = None

    def __post_init__(self):
        """Validate the configuration, then derive the normalized internal state."""
        self._validate_config()
        self._normalize_resources()
        self._setup_gcs_ft()
        self._setup_default_storage()
        self._setup_default_labels()
        self._setup_default_ray_params()

    @staticmethod
    def _validate_gpu_mapping(gpu_resources):
        """Check one ``{resource_name: count}`` GPU mapping.

        Raises:
            ValueError: If a count is not an int or str, is a non-numeric
                string, or is a negative int.
        """
        for gpu_type, gpu_count in gpu_resources.items():
            # bool is a subclass of int; without the explicit check ``True``
            # would be emitted as the resource quantity "True".
            if isinstance(gpu_count, bool) or not isinstance(gpu_count, (int, str)):
                raise ValueError(
                    f"GPU count for {gpu_type} must be int or str, got {type(gpu_count)}"
                )
            if isinstance(gpu_count, str):
                if not gpu_count.isdigit():
                    raise ValueError(
                        f"GPU count string for {gpu_type} must be numeric, got {gpu_count}"
                    )
            elif gpu_count < 0:
                # "-1" is already rejected by isdigit(); keep ints consistent.
                raise ValueError(
                    f"GPU count for {gpu_type} cannot be negative, got {gpu_count}"
                )

    def _validate_gpu_resources(self):
        """Validate GPU resource specifications for the head and the workers."""
        self._validate_gpu_mapping(self.gpu)
        self._validate_gpu_mapping(self.worker_gpu)

    def _validate_config(self):
        """Validate configuration parameters.

        Raises:
            ValueError: On a missing ``redis_address`` when GCS fault tolerance
                is enabled, a negative ``num_workers``, or invalid GPU counts.
        """
        # GCS FT validation simplified - only redis_address is required
        if self.enable_gcs_ft and not self.redis_address:
            raise ValueError(
                "redis_address must be provided when enable_gcs_ft is True"
            )

        if self.num_workers < 0:
            raise ValueError("num_workers cannot be negative")

        self._validate_gpu_resources()

    def _normalize_resources(self):
        """Normalize int resource specifications to strings.

        Integer CPU values become their decimal string; integer memory values
        are interpreted as gibibytes and get a ``Gi`` suffix. String values are
        left untouched.
        """
        if isinstance(self.head_cpu, int):
            self.head_cpu = str(self.head_cpu)
        if isinstance(self.head_memory, int):
            self.head_memory = f"{self.head_memory}Gi"

        if isinstance(self.worker_cpu, int):
            self.worker_cpu = str(self.worker_cpu)
        if isinstance(self.worker_memory, int):
            self.worker_memory = f"{self.worker_memory}Gi"

    def _setup_gcs_ft(self):
        """Add the GCS fault tolerance environment variables to ``envs``."""
        if not self.enable_gcs_ft:
            return

        ft_envs = {"RAY_GCS_FT_ENABLED": "true"}
        if self.redis_address:
            # NOTE(review): to_dict() also emits gcsFaultToleranceOptions with
            # the same address; some KubeRay versions may reject a head pod
            # that sets RAY_REDIS_ADDRESS as well -- confirm against the
            # KubeRay version being targeted.
            ft_envs["RAY_REDIS_ADDRESS"] = self.redis_address

        # Rebind instead of mutating in place so a dict passed in by the caller
        # (and possibly shared with other configs) is never modified.
        self.envs = {**self.envs, **ft_envs}

    def _setup_default_storage(self):
        """Setup default storage - simplified for most use cases."""
        # Most users don't need custom volumes/mounts
        pass

    def _setup_default_labels(self):
        """Setup default labels for the cluster."""
        # Simplified labels - most users don't need custom ones
        pass

    def _setup_default_ray_params(self):
        """Setup default Ray start parameters for all nodes."""
        # These are copied into the group specs built by to_dict()
        self._default_head_params = {
            "dashboard-host": "0.0.0.0",
            "dashboard-port": "8265",
            "block": "true",
        }

        self._default_worker_params = {
            "block": "true",
        }

    def to_dict(self) -> Dict:
        """
        Convert the configuration to a dictionary that can be used
        to create the RayClusterSpec for a RayJob.

        ``workerGroupSpecs`` is only present when ``num_workers > 0`` and
        ``gcsFaultToleranceOptions`` only when ``enable_gcs_ft`` is set.

        Returns:
            Dictionary representation suitable for RayJob rayClusterSpec
        """
        config_dict = {
            # NOTE(review): hard-coded; confirm it matches the Ray version
            # shipped in ``image``.
            "rayVersion": "2.9.0",  # Use stable version
            "headGroupSpec": self._build_head_group_spec(),
        }

        if self.num_workers > 0:
            config_dict["workerGroupSpecs"] = [self._build_worker_group_spec()]

        if self.enable_gcs_ft:
            config_dict["gcsFaultToleranceOptions"] = self._build_gcs_ft_options()

        return config_dict

    def _build_head_group_spec(self) -> Dict:
        """Build the HeadGroupSpec for the RayCluster."""
        return {
            "template": self._build_pod_template(
                cpu=self.head_cpu,
                memory=self.head_memory,
                gpu=self.gpu,
                image=self.image,
                is_head=True,
            ),
            # Copy so that callers editing the returned spec cannot alter the
            # defaults held on this instance.
            "rayStartParams": dict(self._default_head_params),
            "serviceType": "ClusterIP",  # Always use ClusterIP
        }

    def _build_worker_group_spec(self) -> Dict:
        """Build the WorkerGroupSpec for the RayCluster."""
        return {
            "groupName": "default-worker-group",
            "replicas": self.num_workers,
            "template": self._build_pod_template(
                cpu=self.worker_cpu,
                memory=self.worker_memory,
                gpu=self.worker_gpu,
                image=self.image,
                is_head=False,
            ),
            # Copy for the same reason as in the head group spec.
            "rayStartParams": dict(self._default_worker_params),
        }

    def _build_pod_template(
        self,
        cpu: str,
        memory: str,
        gpu: Dict[str, Union[int, str]],
        image: str,
        is_head: bool,
    ) -> Dict:
        """Build a pod template specification.

        Args:
            cpu: CPU quantity used for both requests and limits.
            memory: Memory quantity used for both requests and limits.
            gpu: Mapping of GPU resource name to count; each entry is added to
                both requests and limits as a string.
            image: Container image for the single Ray container.
            is_head: Whether the template is for the head node; the head
                container additionally declares its ports.

        Returns:
            A dict with a ``spec`` holding one container and ``restartPolicy``.
        """
        # Requests and limits are kept identical for every resource.
        gpu_quantities = {gpu_type: str(count) for gpu_type, count in gpu.items()}
        resources = {
            "requests": {"cpu": cpu, "memory": memory, **gpu_quantities},
            "limits": {"cpu": cpu, "memory": memory, **gpu_quantities},
        }

        container = {
            "name": "ray-head" if is_head else "ray-worker",
            "image": image,
            "imagePullPolicy": "IfNotPresent",
            "resources": resources,
            "env": [{"name": k, "value": v} for k, v in self.envs.items()],
        }

        # Add head node specific configuration
        if is_head:
            container["ports"] = [
                {"name": "gcs", "containerPort": 6379},
                {"name": "dashboard", "containerPort": 8265},
                {"name": "client", "containerPort": 10001},
            ]

        # Head and worker containers share the same preStop hook.
        container["lifecycle"] = {
            "preStop": {"exec": {"command": ["/bin/sh", "-c", "ray stop"]}}
        }

        # Build pod template - simplified for most use cases
        return {
            "spec": {
                "containers": [container],
                "restartPolicy": "Never",  # RayJobs manage lifecycle, so Never is appropriate
            }
        }

    def _build_gcs_ft_options(self) -> Dict:
        """Build GCS fault tolerance options for the RayCluster."""
        return {
            "redisAddress": self.redis_address,
        }

src/codeflare_sdk/ray/rayjobs/config.py

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
"""
16-
The config sub-module contains the definition of the RayJobClusterConfigV2 dataclass,
16+
The config sub-module contains the definition of the RayJobClusterConfig dataclass,
1717
which is used to specify resource requirements and other details when creating a
1818
Cluster object.
1919
"""
@@ -139,6 +139,16 @@ class RayJobClusterConfig:
139139
A list of V1Volume objects to add to the Cluster
140140
volume_mounts:
141141
A list of V1VolumeMount objects to add to the Cluster
142+
enable_gcs_ft:
143+
A boolean indicating whether to enable GCS fault tolerance.
144+
enable_usage_stats:
145+
A boolean indicating whether to capture and send Ray usage stats externally.
146+
redis_address:
147+
The address of the Redis server to use for GCS fault tolerance, required when enable_gcs_ft is True.
148+
redis_password_secret:
149+
Kubernetes secret reference containing Redis password. ex: {"name": "secret-name", "key": "password-key"}
150+
external_storage_namespace:
151+
The storage namespace to use for GCS fault tolerance. By default, KubeRay sets it to the UID of RayCluster.
142152
"""
143153

144154
head_cpu_requests: Union[int, str] = 2
@@ -165,8 +175,39 @@ class RayJobClusterConfig:
165175
annotations: Dict[str, str] = field(default_factory=dict)
166176
volumes: list[V1Volume] = field(default_factory=list)
167177
volume_mounts: list[V1VolumeMount] = field(default_factory=list)
178+
enable_gcs_ft: bool = False
179+
enable_usage_stats: bool = False
180+
redis_address: Optional[str] = None
181+
redis_password_secret: Optional[Dict[str, str]] = None
182+
external_storage_namespace: Optional[str] = None
168183

169184
def __post_init__(self):
185+
if self.enable_usage_stats:
186+
self.envs["RAY_USAGE_STATS_ENABLED"] = "1"
187+
else:
188+
self.envs["RAY_USAGE_STATS_ENABLED"] = "0"
189+
190+
if self.enable_gcs_ft:
191+
if not self.redis_address:
192+
raise ValueError(
193+
"redis_address must be provided when enable_gcs_ft is True"
194+
)
195+
196+
if self.redis_password_secret and not isinstance(
197+
self.redis_password_secret, dict
198+
):
199+
raise ValueError(
200+
"redis_password_secret must be a dictionary with 'name' and 'key' fields"
201+
)
202+
203+
if self.redis_password_secret and (
204+
"name" not in self.redis_password_secret
205+
or "key" not in self.redis_password_secret
206+
):
207+
raise ValueError(
208+
"redis_password_secret must contain both 'name' and 'key' fields"
209+
)
210+
170211
self._validate_types()
171212
self._memory_to_string()
172213
self._validate_gpu_config(self.head_accelerators)
@@ -251,6 +292,11 @@ def build_ray_cluster_spec(self, cluster_name: str) -> Dict[str, Any]:
251292
"workerGroupSpecs": [self._build_worker_group_spec(cluster_name)],
252293
}
253294

295+
# Add GCS fault tolerance if enabled
296+
if self.enable_gcs_ft:
297+
gcs_ft_options = self._build_gcs_ft_options()
298+
ray_cluster_spec["gcsFaultToleranceOptions"] = gcs_ft_options
299+
254300
return ray_cluster_spec
255301

256302
def _build_head_group_spec(self) -> Dict[str, Any]:
@@ -453,3 +499,25 @@ def _generate_volumes(self) -> list:
453499
def _build_env_vars(self) -> list:
    """Return one V1EnvVar per entry of ``self.envs``, in mapping order."""
    return [
        V1EnvVar(name=env_name, value=env_value)
        for env_name, env_value in self.envs.items()
    ]
502+
503+
def _build_gcs_ft_options(self) -> Dict[str, Any]:
504+
"""Build GCS fault tolerance options."""
505+
gcs_ft_options = {"redisAddress": self.redis_address}
506+
507+
if (
508+
hasattr(self, "external_storage_namespace")
509+
and self.external_storage_namespace
510+
):
511+
gcs_ft_options["externalStorageNamespace"] = self.external_storage_namespace
512+
513+
if hasattr(self, "redis_password_secret") and self.redis_password_secret:
514+
gcs_ft_options["redisPassword"] = {
515+
"valueFrom": {
516+
"secretKeyRef": {
517+
"name": self.redis_password_secret["name"],
518+
"key": self.redis_password_secret["key"],
519+
}
520+
}
521+
}
522+
523+
return gcs_ft_options

0 commit comments

Comments
 (0)