Skip to content

Commit 1166939

Browse files
committed
feat(RHOAIENG-29354): Set ImagePullPolicy to IfNotPresent
1 parent f69c792 commit 1166939

File tree

2 files changed

+322
-0
lines changed

2 files changed

+322
-0
lines changed
Lines changed: 291 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,291 @@
1+
# Copyright 2022 IBM, Red Hat
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""
16+
The config sub-module contains the definition of the ClusterConfigurationV2 dataclass,
17+
which is used to specify resource requirements and other details when creating a
18+
Cluster object.
19+
"""
20+
21+
import pathlib
22+
import warnings
23+
from dataclasses import dataclass, field, fields
24+
from typing import Dict, List, Optional, Union, get_args, get_origin
25+
from kubernetes.client import V1Toleration, V1Volume, V1VolumeMount
26+
27+
# Absolute (resolved) path of the directory two levels above this file,
# i.e. the parent of the directory that contains this module.
# NOTE(review): this module-level name shadows the built-in `dir()`.
dir = pathlib.Path(__file__).parent.parent.resolve()
28+
29+
# https://docs.ray.io/en/latest/ray-core/scheduling/accelerators.html
# Default mapping from extended resource names (the keys accepted in
# head_extended_resource_requests / worker_extended_resource_requests) to the
# resource names used on the Ray side. ClusterConfigurationV2 merges any
# user-supplied extended_resource_mapping on top of this dictionary.
30+
DEFAULT_RESOURCE_MAPPING = {
31+
"nvidia.com/gpu": "GPU",
32+
"intel.com/gpu": "GPU",
33+
"amd.com/gpu": "GPU",
34+
"aws.amazon.com/neuroncore": "neuron_cores",
35+
"google.com/tpu": "TPU",
36+
"habana.ai/gaudi": "HPU",
37+
"huawei.com/Ascend910": "NPU",
38+
"huawei.com/Ascend310": "NPU",
39+
}
40+
41+
42+
@dataclass
class ClusterConfigurationV2:
    """
    This dataclass is used to specify resource requirements and other details, and
    is passed in as an argument when creating a Cluster object.

    All fields are type-checked against their annotations in ``__post_init__``;
    memory values are then normalised to strings carrying a unit (``"8"`` or
    ``8`` becomes ``"8G"``) and the deprecated ``head_cpus`` / ``head_memory``
    values are copied onto the corresponding ``*_requests`` / ``*_limits`` fields.

    Args:
        name:
            The name of the cluster.
        namespace:
            The namespace in which the cluster should be created.
        head_cpu_requests:
            CPU requests for the head node.
        head_cpu_limits:
            CPU limits for the head node.
        head_cpus:
            Deprecated. When set, overrides both head_cpu_requests and head_cpu_limits.
        head_memory_requests:
            Memory requests for the head node. Integers and unit-less numeric strings are interpreted as "<n>G".
        head_memory_limits:
            Memory limits for the head node. Integers and unit-less numeric strings are interpreted as "<n>G".
        head_memory:
            Deprecated. When set, overrides both head_memory_requests and head_memory_limits.
        head_extended_resource_requests:
            A dictionary of extended resource requests for the head node. ex: {"nvidia.com/gpu": 1}
        head_tolerations:
            List of tolerations for head nodes.
        worker_cpu_requests:
            CPU requests for each worker.
        worker_cpu_limits:
            CPU limits for each worker.
        num_workers:
            The number of workers to create.
        worker_memory_requests:
            Memory requests for each worker. Integers and unit-less numeric strings are interpreted as "<n>G".
        worker_memory_limits:
            Memory limits for each worker. Integers and unit-less numeric strings are interpreted as "<n>G".
        worker_tolerations:
            List of tolerations for worker nodes.
        appwrapper:
            A boolean indicating whether to use an AppWrapper.
        envs:
            A dictionary of environment variables to set for the cluster.
            RAY_USAGE_STATS_ENABLED is always written into it based on enable_usage_stats.
        image:
            The image to use for the cluster.
        image_pull_secrets:
            A list of image pull secrets to use for the cluster.
        write_to_file:
            A boolean indicating whether to write the cluster configuration to a file.
        verify_tls:
            A boolean indicating whether to verify TLS when connecting to the cluster.
        labels:
            A dictionary of labels to apply to the cluster.
        worker_extended_resource_requests:
            A dictionary of extended resource requests for each worker. ex: {"nvidia.com/gpu": 1}
        extended_resource_mapping:
            A dictionary of custom resource mappings to map extended resource requests to RayCluster resource names
        overwrite_default_resource_mapping:
            A boolean indicating whether to overwrite the default resource mapping.
        local_queue:
            Optional name of the local queue to use.
        annotations:
            A dictionary of annotations to apply to the cluster.
        volumes:
            A list of V1Volume objects to add to the Cluster
        volume_mounts:
            A list of V1VolumeMount objects to add to the Cluster
        enable_gcs_ft:
            A boolean indicating whether to enable GCS fault tolerance.
        enable_usage_stats:
            A boolean indicating whether to capture and send Ray usage stats externally.
        redis_address:
            The address of the Redis server to use for GCS fault tolerance, required when enable_gcs_ft is True.
        redis_password_secret:
            Kubernetes secret reference containing Redis password. ex: {"name": "secret-name", "key": "password-key"}
        external_storage_namespace:
            The storage namespace to use for GCS fault tolerance. By default, KubeRay sets it to the UID of RayCluster.

    Raises:
        TypeError:
            If any field does not match its annotated type.
        ValueError:
            If the GCS fault tolerance settings are incomplete, if a default resource
            mapping is overridden without overwrite_default_resource_mapping, or if an
            extended resource request names a resource with no mapping.
    """

    name: str
    namespace: Optional[str] = None
    head_cpu_requests: Union[int, str] = 2
    head_cpu_limits: Union[int, str] = 2
    head_cpus: Optional[Union[int, str]] = None  # Deprecating
    head_memory_requests: Union[int, str] = 8
    head_memory_limits: Union[int, str] = 8
    head_memory: Optional[Union[int, str]] = None  # Deprecating
    head_extended_resource_requests: Dict[str, Union[str, int]] = field(
        default_factory=dict
    )
    head_tolerations: Optional[List[V1Toleration]] = None
    worker_cpu_requests: Union[int, str] = 1
    worker_cpu_limits: Union[int, str] = 1
    num_workers: int = 1
    worker_memory_requests: Union[int, str] = 2
    worker_memory_limits: Union[int, str] = 2
    worker_tolerations: Optional[List[V1Toleration]] = None
    appwrapper: bool = False
    envs: Dict[str, str] = field(default_factory=dict)
    image: str = ""
    image_pull_secrets: List[str] = field(default_factory=list)
    write_to_file: bool = False
    verify_tls: bool = True
    labels: Dict[str, str] = field(default_factory=dict)
    worker_extended_resource_requests: Dict[str, Union[str, int]] = field(
        default_factory=dict
    )
    extended_resource_mapping: Dict[str, str] = field(default_factory=dict)
    overwrite_default_resource_mapping: bool = False
    local_queue: Optional[str] = None
    annotations: Dict[str, str] = field(default_factory=dict)
    volumes: list[V1Volume] = field(default_factory=list)
    volume_mounts: list[V1VolumeMount] = field(default_factory=list)
    enable_gcs_ft: bool = False
    enable_usage_stats: bool = False
    redis_address: Optional[str] = None
    redis_password_secret: Optional[Dict[str, str]] = None
    external_storage_namespace: Optional[str] = None

    def __post_init__(self):
        """Validate the configuration and normalise resource fields after init."""
        if not self.verify_tls:
            print(
                "Warning: TLS verification has been disabled - Endpoint checks will be bypassed"
            )

        # The usage-stats switch is always written explicitly, on or off.
        self.envs["RAY_USAGE_STATS_ENABLED"] = "1" if self.enable_usage_stats else "0"

        if self.enable_gcs_ft:
            if not self.redis_address:
                raise ValueError(
                    "redis_address must be provided when enable_gcs_ft is True"
                )

            if self.redis_password_secret and not isinstance(
                self.redis_password_secret, dict
            ):
                raise ValueError(
                    "redis_password_secret must be a dictionary with 'name' and 'key' fields"
                )

            if self.redis_password_secret and (
                "name" not in self.redis_password_secret
                or "key" not in self.redis_password_secret
            ):
                raise ValueError(
                    "redis_password_secret must contain both 'name' and 'key' fields"
                )

        # Order matters: the deprecated head_memory value must be copied onto the
        # requests/limits fields before those fields are converted to unit strings.
        self._validate_types()
        self._memory_to_resource()
        self._memory_to_string()
        self._str_mem_no_unit_add_GB()
        self._cpu_to_resource()
        self._combine_extended_resource_mapping()
        self._validate_extended_resource_requests(self.head_extended_resource_requests)
        self._validate_extended_resource_requests(
            self.worker_extended_resource_requests
        )

    def _combine_extended_resource_mapping(self):
        """
        Merge the user-supplied mapping on top of DEFAULT_RESOURCE_MAPPING.

        Raises:
            ValueError:
                If a default key is redefined while
                overwrite_default_resource_mapping is False.
        """
        if overwritten := set(self.extended_resource_mapping.keys()).intersection(
            DEFAULT_RESOURCE_MAPPING.keys()
        ):
            if self.overwrite_default_resource_mapping:
                warnings.warn(
                    f"Overwriting default resource mapping for {overwritten}",
                    UserWarning,
                )
            else:
                raise ValueError(
                    f"Resource mapping already exists for {overwritten}, set overwrite_default_resource_mapping to True to overwrite"
                )
        # User entries come last so they win over the defaults.
        self.extended_resource_mapping = {
            **DEFAULT_RESOURCE_MAPPING,
            **self.extended_resource_mapping,
        }

    def _validate_extended_resource_requests(
        self, extended_resources: Dict[str, Union[str, int]]
    ):
        """
        Ensure every requested extended resource has an entry in extended_resource_mapping.

        Raises:
            ValueError:
                If a requested resource name has no mapping.
        """
        for k in extended_resources:
            if k not in self.extended_resource_mapping:
                raise ValueError(
                    f"extended resource '{k}' not found in extended_resource_mapping, available resources are {list(self.extended_resource_mapping.keys())}, to add more supported resources use extended_resource_mapping. i.e. extended_resource_mapping = {{'{k}': 'FOO_BAR'}}"
                )

    def _str_mem_no_unit_add_GB(self):
        """
        Append "G" to every memory field given as a unit-less numeric string.

        The head request/limit fields are covered as well as the deprecated
        head_memory, so that a value such as "8" is treated the same way for
        head and worker nodes.
        """
        memory_fields = (
            "head_memory",
            "head_memory_requests",
            "head_memory_limits",
            "worker_memory_requests",
            "worker_memory_limits",
        )
        for attr in memory_fields:
            current = getattr(self, attr)
            if isinstance(current, str) and current.isdecimal():
                setattr(self, attr, f"{current}G")

    def _memory_to_string(self):
        """Convert integer memory requests/limits to "<n>G" strings."""
        memory_fields = (
            "head_memory_requests",
            "head_memory_limits",
            "worker_memory_requests",
            "worker_memory_limits",
        )
        for attr in memory_fields:
            current = getattr(self, attr)
            if isinstance(current, int):
                setattr(self, attr, f"{current}G")

    def _cpu_to_resource(self):
        """Copy the deprecated head_cpus value onto head_cpu_requests and head_cpu_limits."""
        if self.head_cpus:
            warnings.warn(
                "head_cpus is being deprecated, use head_cpu_requests and head_cpu_limits"
            )
            self.head_cpu_requests = self.head_cpu_limits = self.head_cpus

    def _memory_to_resource(self):
        """Copy the deprecated head_memory value onto head_memory_requests and head_memory_limits."""
        if self.head_memory:
            warnings.warn(
                "head_memory is being deprecated, use head_memory_requests and head_memory_limits"
            )
            self.head_memory_requests = self.head_memory_limits = self.head_memory

    def _validate_types(self):
        """
        Validate the types of all fields in the ClusterConfigurationV2 dataclass.

        Raises:
            TypeError:
                Listing every field whose value does not match its annotation.
        """
        errors = [
            f"'{field_info.name}' should be of type {field_info.type}."
            for field_info in fields(self)
            if not self._is_type(getattr(self, field_info.name), field_info.type)
        ]

        if errors:
            raise TypeError("Type validation failed:\n" + "\n".join(errors))

    @staticmethod
    def _is_type(value, expected_type):
        """
        Check if the value matches the expected type.

        Supports Union/Optional, parameterised list, dict and tuple annotations
        and plain classes. ``None`` is accepted for list and dict annotations.
        A value of the wrong container type yields ``False`` instead of raising
        while being iterated, so it is reported as a normal validation error.
        """

        def check_type(candidate, annotation):
            origin = get_origin(annotation)
            args = get_args(annotation)
            if origin is Union:
                return any(check_type(candidate, member) for member in args)
            if origin is list:
                if candidate is None:
                    return True
                # Reject e.g. a str passed for List[str]: iterating it would
                # otherwise "validate" each character.
                if not isinstance(candidate, list):
                    return False
                if not args:
                    return True
                return all(check_type(item, args[0]) for item in candidate)
            if origin is dict:
                if candidate is None:
                    return True
                if not isinstance(candidate, dict):
                    return False
                if len(args) < 2:
                    return True
                return all(
                    check_type(key, args[0]) and check_type(item, args[1])
                    for key, item in candidate.items()
                )
            if origin is tuple:
                if not isinstance(candidate, tuple):
                    return False
                return all(
                    check_type(item, item_type)
                    for item, item_type in zip(candidate, args)
                )
            if annotation is int:
                # bool is a subclass of int but must not pass as an integer.
                return isinstance(candidate, int) and not isinstance(candidate, bool)
            if annotation is bool:
                return isinstance(candidate, bool)
            return isinstance(candidate, annotation)

        return check_type(value, expected_type)

src/codeflare_sdk/ray/rayjobs/rayjob.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,9 @@ def _build_ray_cluster_spec(self) -> Dict[str, Any]:
224224
# Note: CodeFlare Operator should still create dashboard routes for the RayCluster
225225
ray_cluster_spec = ray_cluster_dict["spec"]
226226

227+
# Override imagePullPolicy to "IfNotPresent" for RayJob RayCluster specs
228+
self._override_image_pull_policy(ray_cluster_spec)
229+
227230
logger.info(
228231
f"Built RayCluster spec using existing build logic for cluster: {self.cluster_name}"
229232
)
@@ -298,3 +301,31 @@ def _map_to_codeflare_status(
298301
return status_mapping.get(
299302
deployment_status, (CodeflareRayJobStatus.UNKNOWN, False)
300303
)
304+
305+
def _override_image_pull_policy(self, ray_cluster_spec: Dict[str, Any]) -> None:
    """
    Override the imagePullPolicy to "IfNotPresent" for all containers in the RayCluster spec.
    This is specifically for RayJob RayCluster specs as requested.

    The head group ("headGroupSpec") and every worker group ("workerGroupSpecs")
    are visited; for each, the entries under template.spec.containers are
    updated in place. Sections that are missing, empty or None are skipped.

    Args:
        ray_cluster_spec:
            The "spec" dictionary of a RayCluster; it is mutated in place.
    """
    # Gather head and worker group specs so they share one traversal.
    group_specs = []
    head_group = ray_cluster_spec.get("headGroupSpec")
    if head_group:
        group_specs.append(head_group)
    group_specs.extend(ray_cluster_spec.get("workerGroupSpecs") or [])

    for group in group_specs:
        # `or {}` / `or []` tolerate keys that are present but set to None.
        template = (group or {}).get("template") or {}
        pod_spec = template.get("spec") or {}
        for container in pod_spec.get("containers") or []:
            container["imagePullPolicy"] = "IfNotPresent"

    logger.debug(
        "Updated imagePullPolicy to 'IfNotPresent' for all containers in RayCluster spec"
    )

0 commit comments

Comments
 (0)