|
| 1 | +# Copyright 2025 IBM, Red Hat |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +""" |
| 16 | +RayJobClusterConfig provides a focused configuration class for creating Ray clusters |
| 17 | +as part of RayJob submissions. This class maps directly to the KubeRay RayClusterSpec |
| 18 | +structure and removes legacy fields that aren't relevant for RayJob-based cluster creation. |
| 19 | +""" |
| 20 | + |
| 21 | +from dataclasses import dataclass, field |
| 22 | +from typing import Dict, Optional, Union |
| 23 | + |
| 24 | +from codeflare_sdk.common.utils.constants import CUDA_RUNTIME_IMAGE |
| 25 | + |
| 26 | + |
@dataclass
class RayJobClusterConfig:
    """
    Configuration for creating a RayCluster as part of a RayJob submission.

    This class provides a clean, focused interface with sensible defaults that covers
    95% of use cases. For advanced configurations, users can still access the underlying
    Kubernetes objects.

    Args:
        num_workers: Number of worker nodes to create
        head_cpu: CPU requests/limits for head node (e.g., "2" or 2)
        head_memory: Memory requests/limits for head node (e.g., "8Gi" or 8)
        gpu: GPU resources for head node (e.g., {"nvidia.com/gpu": 1})
        worker_cpu: CPU requests/limits for each worker (e.g., "1" or 1)
        worker_memory: Memory requests/limits for each worker (e.g., "4Gi" or 4)
        worker_gpu: GPU resources for each worker (e.g., {"nvidia.com/gpu": 1})
        image: Container image for all nodes (defaults to CUDA-enabled Ray image)

        # Advanced options (optional)
        envs: Additional environment variables to set on all pods
        enable_gcs_ft: Whether to enable GCS fault tolerance
        redis_address: Redis address for GCS fault tolerance (required if enable_gcs_ft=True)

        # Lifecycle management
        shutdown_after_job_finishes: Whether to automatically cleanup the cluster after job completion
        ttl_seconds_after_finished: Seconds to wait before cleanup after job finishes
        active_deadline_seconds: Maximum time the job can run before being terminated

    Note:
        The lifecycle fields (shutdown_after_job_finishes, ttl_seconds_after_finished,
        active_deadline_seconds) are NOT serialized by to_dict(); they are intended to be
        read by the RayJob builder that consumes this config.
    """

    # Ray version written into the generated rayClusterSpec. This is an
    # unannotated class attribute, so the dataclass machinery does not treat
    # it as an instance field.
    RAY_VERSION = "2.9.0"

    num_workers: int = 1
    head_cpu: Union[int, str] = 2
    head_memory: Union[int, str] = 8
    gpu: Dict[str, Union[int, str]] = field(default_factory=dict)
    worker_cpu: Union[int, str] = 1
    worker_memory: Union[int, str] = 4
    worker_gpu: Dict[str, Union[int, str]] = field(default_factory=dict)
    image: str = CUDA_RUNTIME_IMAGE  # Use CUDA-enabled Ray image

    # Advanced options
    envs: Dict[str, str] = field(default_factory=dict)
    enable_gcs_ft: bool = False
    redis_address: Optional[str] = None

    # Lifecycle management (consumed by the RayJob builder, not to_dict())
    shutdown_after_job_finishes: bool = True
    ttl_seconds_after_finished: int = 0
    active_deadline_seconds: Optional[int] = None

    def __post_init__(self):
        """Post-initialization validation and setup.

        Order matters: validation runs first so later steps only see a
        consistent config; normalization rewrites int resources to k8s
        quantity strings before any spec is built.
        """
        self._validate_config()
        self._normalize_resources()
        self._setup_gcs_ft()
        self._setup_default_storage()
        self._setup_default_labels()
        self._setup_default_ray_params()

    @staticmethod
    def _validate_gpu_dict(gpu_resources: Dict[str, Union[int, str]]) -> None:
        """Validate one GPU resource mapping (resource name -> count).

        Raises:
            ValueError: if a count is neither int nor str, or is a
                non-numeric string (negative/decimal strings are rejected
                because str.isdigit() only accepts unsigned integers).
        """
        for gpu_type, gpu_count in gpu_resources.items():
            if not isinstance(gpu_count, (int, str)):
                raise ValueError(
                    f"GPU count for {gpu_type} must be int or str, got {type(gpu_count)}"
                )
            if isinstance(gpu_count, str) and not gpu_count.isdigit():
                raise ValueError(
                    f"GPU count string for {gpu_type} must be numeric, got {gpu_count}"
                )

    def _validate_gpu_resources(self):
        """Validate GPU resource specifications for head and worker nodes."""
        # Same rules apply to both mappings; delegate to one shared helper
        # instead of duplicating the loop.
        self._validate_gpu_dict(self.gpu)
        self._validate_gpu_dict(self.worker_gpu)

    def _validate_config(self):
        """Validate configuration parameters.

        Raises:
            ValueError: if GCS FT is enabled without a redis_address, if
                num_workers is negative, or if GPU specs are malformed.
        """
        if self.enable_gcs_ft and not self.redis_address:
            raise ValueError(
                "redis_address must be provided when enable_gcs_ft is True"
            )

        # GCS FT validation simplified - only redis_address is required
        if self.num_workers < 0:
            raise ValueError("num_workers cannot be negative")

        # Validate GPU resources
        self._validate_gpu_resources()

    def _normalize_resources(self):
        """Normalize resource specifications to string format.

        Integer CPU counts become plain strings ("2"); integer memory values
        are interpreted as GiB and get a "Gi" suffix ("8" -> "8Gi"). String
        values are assumed to already be valid k8s quantities and pass through.
        """
        if isinstance(self.head_cpu, int):
            self.head_cpu = str(self.head_cpu)
        if isinstance(self.head_memory, int):
            self.head_memory = f"{self.head_memory}Gi"

        if isinstance(self.worker_cpu, int):
            self.worker_cpu = str(self.worker_cpu)
        if isinstance(self.worker_memory, int):
            self.worker_memory = f"{self.worker_memory}Gi"

    def _setup_gcs_ft(self):
        """Inject GCS fault tolerance environment variables into self.envs."""
        if self.enable_gcs_ft:
            self.envs["RAY_GCS_FT_ENABLED"] = "true"
            if self.redis_address:
                self.envs["RAY_REDIS_ADDRESS"] = self.redis_address

    def _setup_default_storage(self):
        """Setup default storage - simplified for most use cases."""
        # Most users don't need custom volumes/mounts
        pass

    def _setup_default_labels(self):
        """Setup default labels for the cluster."""
        # Simplified labels - most users don't need custom ones
        pass

    def _setup_default_ray_params(self):
        """Setup default Ray start parameters for all nodes."""
        # These will be used in the to_dict() method
        self._default_head_params = {
            "dashboard-host": "0.0.0.0",
            "dashboard-port": "8265",
            "block": "true",
        }

        self._default_worker_params = {
            "block": "true",
        }

    def to_dict(self) -> Dict:
        """
        Convert the configuration to a dictionary that can be used
        to create the RayClusterSpec for a RayJob.

        Returns:
            Dictionary representation suitable for RayJob rayClusterSpec
        """
        config_dict = {
            "rayVersion": self.RAY_VERSION,  # Pinned stable version
            "headGroupSpec": self._build_head_group_spec(),
        }

        # Omit workerGroupSpecs entirely for head-only clusters.
        if self.num_workers > 0:
            config_dict["workerGroupSpecs"] = [self._build_worker_group_spec()]

        if self.enable_gcs_ft:
            config_dict["gcsFaultToleranceOptions"] = self._build_gcs_ft_options()

        return config_dict

    def _build_head_group_spec(self) -> Dict:
        """Build the HeadGroupSpec for the RayCluster."""
        return {
            "template": self._build_pod_template(
                cpu=self.head_cpu,
                memory=self.head_memory,
                gpu=self.gpu,
                image=self.image,
                is_head=True,
            ),
            "rayStartParams": self._default_head_params,
            "serviceType": "ClusterIP",  # Always use ClusterIP
        }

    def _build_worker_group_spec(self) -> Dict:
        """Build the WorkerGroupSpec for the RayCluster."""
        return {
            "groupName": "default-worker-group",
            "replicas": self.num_workers,
            "template": self._build_pod_template(
                cpu=self.worker_cpu,
                memory=self.worker_memory,
                gpu=self.worker_gpu,
                image=self.image,
                is_head=False,
            ),
            "rayStartParams": self._default_worker_params,
        }

    def _build_pod_template(
        self,
        cpu: str,
        memory: str,
        gpu: Dict[str, Union[int, str]],
        image: str,
        is_head: bool,
    ) -> Dict:
        """Build a pod template specification.

        Args:
            cpu: CPU quantity (already normalized to a string).
            memory: Memory quantity (already normalized, e.g. "8Gi").
            gpu: GPU resource mapping to add to both requests and limits.
            image: Container image for the pod.
            is_head: True for the head pod (adds service ports).

        Returns:
            A pod template dict for the group spec.
        """
        # Requests and limits are intentionally identical (guaranteed QoS).
        resources = {
            "requests": {
                "cpu": cpu,
                "memory": memory,
            },
            "limits": {
                "cpu": cpu,
                "memory": memory,
            },
        }

        # Add GPU resources if specified
        for gpu_type, gpu_count in gpu.items():
            resources["requests"][gpu_type] = str(gpu_count)
            resources["limits"][gpu_type] = str(gpu_count)

        # Build container spec
        container = {
            "name": "ray-head" if is_head else "ray-worker",
            "image": image,
            "imagePullPolicy": "IfNotPresent",
            "resources": resources,
            "env": [{"name": k, "value": v} for k, v in self.envs.items()],
        }

        # Only the head exposes the GCS, dashboard, and client ports.
        if is_head:
            container["ports"] = [
                {"name": "gcs", "containerPort": 6379},
                {"name": "dashboard", "containerPort": 8265},
                {"name": "client", "containerPort": 10001},
            ]

        # Both head and worker get the same graceful-shutdown hook, so set it
        # once instead of duplicating it in an if/else.
        container["lifecycle"] = {
            "preStop": {"exec": {"command": ["/bin/sh", "-c", "ray stop"]}}
        }

        # Build pod template - simplified for most use cases
        return {
            "spec": {
                "containers": [container],
                "restartPolicy": "Never",  # RayJobs manage lifecycle, so Never is appropriate
            }
        }

    def _build_gcs_ft_options(self) -> Dict:
        """Build GCS fault tolerance options for the RayCluster."""
        return {
            "redisAddress": self.redis_address,
        }