Skip to content

Commit 8920adb

Browse files
committed
Changes as per review
Signed-off-by: Pat O'Connor <[email protected]>
1 parent 430eefb commit 8920adb

File tree

3 files changed: +42 -30 lines changed

3 files changed

+42
-30
lines changed

src/codeflare_sdk/ray/rayjobs/config.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -559,6 +559,13 @@ def add_script_volumes(
559559
logger.info(
560560
f"Added script volume '{configmap_name}' to cluster config: mount_path={mount_path}"
561561
)
562+
563+
def validate_configmap_size(self, scripts: Dict[str, str]) -> None:
564+
total_size = sum(len(content.encode("utf-8")) for content in scripts.values())
565+
if total_size > 1024 * 1024: # 1MB
566+
raise ValueError(
567+
f"ConfigMap size exceeds 1MB limit. Total size: {total_size} bytes"
568+
)
562569

563570
def build_script_configmap_spec(
564571
self, job_name: str, namespace: str, scripts: Dict[str, str]

src/codeflare_sdk/ray/rayjobs/rayjob.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from kubernetes import client
2525
from ...common.kubernetes_cluster.auth import get_api_client
2626
from python_client.kuberay_job_api import RayjobApi
27-
27+
from python_client.kuberay_cluster_api import RayClusterApi
2828
from codeflare_sdk.ray.rayjobs.config import ManagedClusterConfig
2929

3030
from ...common.utils import get_current_namespace
@@ -39,6 +39,8 @@
3939

4040
logger = logging.getLogger(__name__)
4141

42+
mount_path = "/home/ray/scripts"
43+
4244

4345
class RayJob:
4446
"""
@@ -146,6 +148,7 @@ def __init__(
146148
logger.info(f"Using existing cluster: {self.cluster_name}")
147149

148150
self._api = RayjobApi()
151+
self._cluster_api = RayClusterApi()
149152

150153
logger.info(f"Initialized RayJob: {self.name} in namespace: {self.namespace}")
151154

@@ -311,7 +314,7 @@ def _extract_script_files_from_entrypoint(self) -> Optional[Dict[str, str]]:
311314
return None
312315

313316
scripts = {}
314-
mount_path = "/home/ray/scripts"
317+
# mount_path = "/home/ray/scripts"
315318
processed_files = set() # Avoid infinite loops
316319

317320
# Look for Python file patterns in entrypoint (e.g., "python script.py", "python /path/to/script.py")
@@ -410,6 +413,9 @@ def _find_local_imports(
410413

411414
def _handle_script_volumes_for_new_cluster(self, scripts: Dict[str, str]):
412415
"""Handle script volumes for new clusters (uses ManagedClusterConfig)."""
416+
# Validate ConfigMap size before creation
417+
self._cluster_config.validate_configmap_size(scripts)
418+
413419
# Build ConfigMap spec using config.py
414420
configmap_spec = self._cluster_config.build_script_configmap_spec(
415421
job_name=self.name, namespace=self.namespace, scripts=scripts
@@ -427,6 +433,9 @@ def _handle_script_volumes_for_existing_cluster(self, scripts: Dict[str, str]):
427433
"""Handle script volumes for existing clusters (updates RayCluster CR)."""
428434
# Create config builder for utility methods
429435
config_builder = ManagedClusterConfig()
436+
437+
# Validate ConfigMap size before creation
438+
config_builder.validate_configmap_size(scripts)
430439

431440
# Build ConfigMap spec using config.py
432441
configmap_spec = config_builder.build_script_configmap_spec(
@@ -495,12 +504,9 @@ def _update_existing_cluster_for_scripts(
495504
# Get existing RayCluster
496505
api_instance = client.CustomObjectsApi(get_api_client())
497506
try:
498-
ray_cluster = api_instance.get_namespaced_custom_object(
499-
group="ray.io",
500-
version="v1",
501-
namespace=self.namespace,
502-
plural="rayclusters",
507+
ray_cluster = self._cluster_api.get_ray_cluster(
503508
name=self.cluster_name,
509+
k8s_namespace=self.namespace,
504510
)
505511
except client.ApiException as e:
506512
raise RuntimeError(f"Failed to get RayCluster '{self.cluster_name}': {e}")
@@ -546,13 +552,10 @@ def mount_exists(mounts_list, mount_name):
546552

547553
# Update the RayCluster
548554
try:
549-
api_instance.patch_namespaced_custom_object(
550-
group="ray.io",
551-
version="v1",
552-
namespace=self.namespace,
553-
plural="rayclusters",
555+
self._cluster_api.patch_ray_cluster(
554556
name=self.cluster_name,
555-
body=ray_cluster,
557+
ray_cluster=ray_cluster,
558+
k8s_namespace=self.namespace,
556559
)
557560
logger.info(
558561
f"Updated RayCluster '{self.cluster_name}' with script volumes from ConfigMap '{configmap_name}'"

src/codeflare_sdk/ray/rayjobs/test_rayjob.py

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@ def test_rayjob_submit_success(mocker):
3131
mock_api_instance = MagicMock()
3232
mock_api_class.return_value = mock_api_instance
3333

34+
# Mock the RayClusterApi class
35+
mocker.patch("codeflare_sdk.ray.rayjobs.rayjob.RayClusterApi")
36+
3437
# Configure the mock to return success when submit is called
3538
mock_api_instance.submit.return_value = {"metadata": {"name": "test-rayjob"}}
3639

@@ -75,6 +78,9 @@ def test_rayjob_submit_failure(mocker):
7578
mock_api_instance = MagicMock()
7679
mock_api_class.return_value = mock_api_instance
7780

81+
# Mock the RayClusterApi class
82+
mocker.patch("codeflare_sdk.ray.rayjobs.rayjob.RayClusterApi")
83+
7884
# Configure the mock to return failure (False/None) when submit_job is called
7985
mock_api_instance.submit_job.return_value = None
8086

@@ -1580,21 +1586,19 @@ def test_create_configmap_api_error_non_409(mocker):
15801586
def test_update_existing_cluster_get_cluster_error(mocker):
15811587
"""Test _update_existing_cluster_for_scripts handles get cluster errors."""
15821588
mocker.patch("kubernetes.config.load_kube_config")
1583-
mock_api_class = mocker.patch("codeflare_sdk.ray.rayjobs.rayjob.RayjobApi")
1589+
mock_rayjob_api = mocker.patch("codeflare_sdk.ray.rayjobs.rayjob.RayjobApi")
15841590

1585-
# Mock CustomObjectsApi with error
1586-
mock_custom_api = mocker.patch("kubernetes.client.CustomObjectsApi")
1587-
mock_api_instance = mocker.Mock()
1588-
mock_custom_api.return_value = mock_api_instance
1591+
# Mock RayClusterApi with error
1592+
mock_cluster_api_class = mocker.patch("codeflare_sdk.ray.rayjobs.rayjob.RayClusterApi")
1593+
mock_cluster_api_instance = mocker.Mock()
1594+
mock_cluster_api_class.return_value = mock_cluster_api_instance
15891595

15901596
from kubernetes.client import ApiException
15911597

1592-
mock_api_instance.get_namespaced_custom_object.side_effect = ApiException(
1598+
mock_cluster_api_instance.get_ray_cluster.side_effect = ApiException(
15931599
status=404
15941600
)
15951601

1596-
mocker.patch("codeflare_sdk.ray.rayjobs.rayjob.get_api_client")
1597-
15981602
from codeflare_sdk.ray.rayjobs.config import ManagedClusterConfig
15991603

16001604
config_builder = ManagedClusterConfig()
@@ -1614,15 +1618,15 @@ def test_update_existing_cluster_get_cluster_error(mocker):
16141618
def test_update_existing_cluster_patch_error(mocker):
16151619
"""Test _update_existing_cluster_for_scripts handles patch errors."""
16161620
mocker.patch("kubernetes.config.load_kube_config")
1617-
mock_api_class = mocker.patch("codeflare_sdk.ray.rayjobs.rayjob.RayjobApi")
1621+
mock_rayjob_api = mocker.patch("codeflare_sdk.ray.rayjobs.rayjob.RayjobApi")
16181622

1619-
# Mock CustomObjectsApi
1620-
mock_custom_api = mocker.patch("kubernetes.client.CustomObjectsApi")
1621-
mock_api_instance = mocker.Mock()
1622-
mock_custom_api.return_value = mock_api_instance
1623+
# Mock RayClusterApi
1624+
mock_cluster_api_class = mocker.patch("codeflare_sdk.ray.rayjobs.rayjob.RayClusterApi")
1625+
mock_cluster_api_instance = mocker.Mock()
1626+
mock_cluster_api_class.return_value = mock_cluster_api_instance
16231627

16241628
# Mock successful get but failed patch
1625-
mock_api_instance.get_namespaced_custom_object.return_value = {
1629+
mock_cluster_api_instance.get_ray_cluster.return_value = {
16261630
"spec": {
16271631
"headGroupSpec": {
16281632
"template": {
@@ -1641,12 +1645,10 @@ def test_update_existing_cluster_patch_error(mocker):
16411645

16421646
from kubernetes.client import ApiException
16431647

1644-
mock_api_instance.patch_namespaced_custom_object.side_effect = ApiException(
1648+
mock_cluster_api_instance.patch_ray_cluster.side_effect = ApiException(
16451649
status=500
16461650
)
16471651

1648-
mocker.patch("codeflare_sdk.ray.rayjobs.rayjob.get_api_client")
1649-
16501652
from codeflare_sdk.ray.rayjobs.config import ManagedClusterConfig
16511653

16521654
config_builder = ManagedClusterConfig()

0 commit comments

Comments (0)