49 changes: 41 additions & 8 deletions poetry.lock

Some generated files are not rendered by default.

6 changes: 5 additions & 1 deletion pyproject.toml
@@ -33,7 +33,7 @@ cryptography = "43.0.3"
executing = "1.2.0"
pydantic = "< 2"
ipywidgets = "8.1.2"
python-client = { git = "https://github.com/ray-project/kuberay.git", subdirectory = "clients/python-client", rev = "d1e750d9beac612ad455b951c1a789f971409ab3" }
python-client = { git = "https://github.com/ray-project/kuberay.git", subdirectory = "clients/python-client", rev = "b2fd91b58c2bbe22f9b4f730c5a8f3180c05e570" }

[[tool.poetry.source]]
name = "pypi"
@@ -59,6 +59,10 @@ pytest-mock = "3.11.1"
pytest-timeout = "2.3.1"
jupyterlab = "4.3.1"


[tool.poetry.group.dev.dependencies]
diff-cover = "^9.6.0"

[tool.pytest.ini_options]
filterwarnings = [
"ignore::DeprecationWarning:pkg_resources",
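The new dev dependency, diff-cover, reports test coverage restricted to the lines a diff touches. A minimal sketch of the usual flow, assuming pytest-cov is available to produce coverage.xml (the flags shown are standard diff-cover options, not taken from this repo's CI):

```python
import subprocess

# Generate coverage data, then report coverage only for lines changed
# relative to the comparison branch. Assumes pytest-cov is installed.
subprocess.run(["pytest", "--cov=codeflare_sdk", "--cov-report=xml"], check=True)
subprocess.run(
    ["diff-cover", "coverage.xml", "--compare-branch=origin/main"], check=True
)
```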
10 changes: 9 additions & 1 deletion src/codeflare_sdk/ray/rayjobs/config.py
@@ -523,7 +523,15 @@ def build_script_configmap_spec(
return {
"apiVersion": "v1",
"kind": "ConfigMap",
"metadata": {"name": configmap_name, "namespace": namespace},
"metadata": {
"name": configmap_name,
"namespace": namespace,
"labels": {
"ray.io/job-name": job_name,
"app.kubernetes.io/managed-by": "codeflare-sdk",
"app.kubernetes.io/component": "rayjob-scripts",
},
},
"data": scripts,
}

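The labels added in build_script_configmap_spec make the generated script ConfigMaps discoverable, both per job and SDK-wide. A sketch of a lookup with the Kubernetes Python client, using the label keys from the diff above (the namespace and job name are illustrative):

```python
from kubernetes import client, config

config.load_kube_config()  # or config.load_incluster_config() inside a pod
v1 = client.CoreV1Api()

job_name = "demo-job"  # hypothetical RayJob name
configmaps = v1.list_namespaced_config_map(
    namespace="default",
    label_selector=(
        f"ray.io/job-name={job_name},"
        "app.kubernetes.io/managed-by=codeflare-sdk,"
        "app.kubernetes.io/component=rayjob-scripts"
    ),
)
for cm in configmaps.items:
    print(cm.metadata.name, sorted(cm.data or {}))
```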
112 changes: 86 additions & 26 deletions src/codeflare_sdk/ray/rayjobs/rayjob.py
@@ -154,33 +154,29 @@ def __init__(
logger.info(f"Initialized RayJob: {self.name} in namespace: {self.namespace}")

def submit(self) -> str:
# Validate required parameters
if not self.entrypoint:
raise ValueError("entrypoint must be provided to submit a RayJob")
raise ValueError("Entrypoint must be provided to submit a RayJob")

# Validate Ray version compatibility for both cluster_config and runtime_env
self._validate_ray_version_compatibility()
# Automatically handle script files for new clusters
if self._cluster_config is not None:
scripts = self._extract_script_files_from_entrypoint()
if scripts:
self._handle_script_volumes_for_new_cluster(scripts)

# Handle script files for existing clusters
elif self._cluster_name:
scripts = self._extract_script_files_from_entrypoint()
if scripts:
self._handle_script_volumes_for_existing_cluster(scripts)

# Build the RayJob custom resource
rayjob_cr = self._build_rayjob_cr()

# Submit the job - KubeRay operator handles everything else
logger.info(f"Submitting RayJob {self.name} to KubeRay operator")
logger.info(f"Submitting RayJob {self.name} to Kuberay operator")
result = self._api.submit_job(k8s_namespace=self.namespace, job=rayjob_cr)

if result:
logger.info(f"Successfully submitted RayJob {self.name}")

            # Handle script files after RayJob creation so we can set the owner reference
if self._cluster_config is not None:
scripts = self._extract_script_files_from_entrypoint()
if scripts:
self._handle_script_volumes_for_new_cluster(scripts, result)
elif self._cluster_name:
scripts = self._extract_script_files_from_entrypoint()
if scripts:
self._handle_script_volumes_for_existing_cluster(scripts, result)

if self.shutdown_after_job_finishes:
logger.info(
f"Cluster will be automatically cleaned up {self.ttl_seconds_after_finished}s after job completion"
@@ -189,11 +185,42 @@
else:
raise RuntimeError(f"Failed to submit RayJob {self.name}")

def stop(self):
"""
Suspend the Ray job.
"""
stopped = self._api.suspend_job(name=self.name, k8s_namespace=self.namespace)
if stopped:
logger.info(f"Successfully stopped the RayJob {self.name}")
return True
else:
raise RuntimeError(f"Failed to stop the RayJob {self.name}")

def resubmit(self):
"""
Resubmit the Ray job.
"""
if self._api.resubmit_job(name=self.name, k8s_namespace=self.namespace):
logger.info(f"Successfully resubmitted the RayJob {self.name}")
return True
else:
raise RuntimeError(f"Failed to resubmit the RayJob {self.name}")

def delete(self):
"""
Delete the Ray job.
"""
deleted = self._api.delete_job(name=self.name, k8s_namespace=self.namespace)
if deleted:
logger.info(f"Successfully deleted the RayJob {self.name}")
return True
else:
raise RuntimeError(f"Failed to delete the RayJob {self.name}")

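With stop(), resubmit(), and delete() alongside submit(), a RayJob can be driven through its whole lifecycle from the SDK. A hedged usage sketch (the import path and constructor arguments are assumptions; check the class signature above for the real parameters):

```python
from codeflare_sdk import RayJob  # import path assumed

# Constructor arguments are illustrative, not the verified signature.
job = RayJob(job_name="demo-job", entrypoint="python main.py", namespace="default")

job.submit()    # creates the RayJob CR; raises RuntimeError on failure
job.stop()      # suspends the job via the KubeRay API
job.resubmit()  # runs the job again after a stop
job.delete()    # removes the RayJob CR; owned ConfigMaps are garbage-collected
```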
def _build_rayjob_cr(self) -> Dict[str, Any]:
"""
Build the RayJob custom resource specification using native RayJob capabilities.
"""
# Basic RayJob custom resource structure
rayjob_cr = {
"apiVersion": "ray.io/v1",
"kind": "RayJob",
@@ -449,7 +476,9 @@ def _find_local_imports(
except (SyntaxError, ValueError) as e:
logger.debug(f"Could not parse imports from {script_path}: {e}")

def _handle_script_volumes_for_new_cluster(self, scripts: Dict[str, str]):
def _handle_script_volumes_for_new_cluster(
self, scripts: Dict[str, str], rayjob_result: Dict[str, Any] = None
):
"""Handle script volumes for new clusters (uses ManagedClusterConfig)."""
# Validate ConfigMap size before creation
self._cluster_config.validate_configmap_size(scripts)
@@ -459,15 +488,17 @@ def _handle_script_volumes_for_new_cluster(self, scripts: Dict[str, str]):
job_name=self.name, namespace=self.namespace, scripts=scripts
)

# Create ConfigMap via Kubernetes API
configmap_name = self._create_configmap_from_spec(configmap_spec)
# Create ConfigMap via Kubernetes API with owner reference
configmap_name = self._create_configmap_from_spec(configmap_spec, rayjob_result)

# Add volumes to cluster config (config.py handles spec building)
self._cluster_config.add_script_volumes(
configmap_name=configmap_name, mount_path=MOUNT_PATH
)

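add_script_volumes itself is not shown in this diff, but mounting the scripts comes down to a ConfigMap-backed volume plus a mount at MOUNT_PATH. A sketch of the objects it presumably appends, using the Kubernetes Python client (the volume name and mount path are assumptions):

```python
from kubernetes import client

MOUNT_PATH = "/home/ray/scripts"  # assumed value of the MOUNT_PATH constant

volume = client.V1Volume(
    name="ray-job-scripts",  # hypothetical volume name
    config_map=client.V1ConfigMapVolumeSource(name="demo-job-scripts"),
)
volume_mount = client.V1VolumeMount(name="ray-job-scripts", mount_path=MOUNT_PATH)
```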
def _handle_script_volumes_for_existing_cluster(self, scripts: Dict[str, str]):
def _handle_script_volumes_for_existing_cluster(
self, scripts: Dict[str, str], rayjob_result: Dict[str, Any] = None
):
"""Handle script volumes for existing clusters (updates RayCluster CR)."""
# Create config builder for utility methods
config_builder = ManagedClusterConfig()
@@ -480,28 +511,57 @@ def _handle_script_volumes_for_existing_cluster(self, scripts: Dict[str, str]):
job_name=self.name, namespace=self.namespace, scripts=scripts
)

# Create ConfigMap via Kubernetes API
configmap_name = self._create_configmap_from_spec(configmap_spec)
# Create ConfigMap via Kubernetes API with owner reference
configmap_name = self._create_configmap_from_spec(configmap_spec, rayjob_result)

# Update existing RayCluster
self._update_existing_cluster_for_scripts(configmap_name, config_builder)

def _create_configmap_from_spec(self, configmap_spec: Dict[str, Any]) -> str:
def _create_configmap_from_spec(
self, configmap_spec: Dict[str, Any], rayjob_result: Dict[str, Any] = None
) -> str:
"""
Create ConfigMap from specification via Kubernetes API.

Args:
configmap_spec: ConfigMap specification dictionary
rayjob_result: The result from RayJob creation containing UID

Returns:
str: Name of the created ConfigMap
"""

configmap_name = configmap_spec["metadata"]["name"]

metadata = client.V1ObjectMeta(**configmap_spec["metadata"])

# Add owner reference if we have the RayJob result
if (
rayjob_result
and isinstance(rayjob_result, dict)
and rayjob_result.get("metadata", {}).get("uid")
):
logger.info(
f"Adding owner reference to ConfigMap '{configmap_name}' with RayJob UID: {rayjob_result['metadata']['uid']}"
)
metadata.owner_references = [
client.V1OwnerReference(
api_version="ray.io/v1",
kind="RayJob",
name=self.name,
uid=rayjob_result["metadata"]["uid"],
controller=True,
block_owner_deletion=True,
)
]
else:
logger.warning(
f"No valid RayJob result with UID found, ConfigMap '{configmap_name}' will not have owner reference. Result: {rayjob_result}"
)

# Convert dict spec to V1ConfigMap
configmap = client.V1ConfigMap(
metadata=client.V1ObjectMeta(**configmap_spec["metadata"]),
metadata=metadata,
data=configmap_spec["data"],
)

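Because the ConfigMap now carries a controller owner reference to the RayJob, Kubernetes garbage-collects it when the RayJob is deleted. One way to verify the wiring, sketched with the Kubernetes Python client (names and namespace are illustrative):

```python
from kubernetes import client, config

config.load_kube_config()
v1 = client.CoreV1Api()

cm = v1.read_namespaced_config_map(name="demo-job-scripts", namespace="default")
for ref in cm.metadata.owner_references or []:
    # Expect kind="RayJob", controller=True, and the submitting RayJob's UID.
    print(ref.kind, ref.name, ref.uid, ref.controller)
```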
57 changes: 57 additions & 0 deletions src/codeflare_sdk/ray/rayjobs/test_config.py
@@ -82,6 +82,40 @@ def test_gpu_validation_fails_with_unsupported_accelerator():
ManagedClusterConfig(head_accelerators={"unsupported.com/accelerator": 1})


def test_config_type_validation_errors(mocker):
"""Test that type validation properly raises errors with incorrect types."""
# Mock the _is_type method to return False for type checking
mocker.patch.object(
ManagedClusterConfig,
"_is_type",
side_effect=lambda value, expected_type: False, # Always fail type check
)

# This should raise TypeError during initialization
with pytest.raises(TypeError, match="Type validation failed"):
ManagedClusterConfig()


def test_config_is_type_method():
"""Test the _is_type static method for type checking."""
# Test basic types
assert ManagedClusterConfig._is_type("test", str) is True
assert ManagedClusterConfig._is_type(123, int) is True
assert ManagedClusterConfig._is_type(123, str) is False

# Test optional types (Union with None)
from typing import Optional

assert ManagedClusterConfig._is_type(None, Optional[str]) is True
assert ManagedClusterConfig._is_type("test", Optional[str]) is True
assert ManagedClusterConfig._is_type(123, Optional[str]) is False

# Test dict types
assert ManagedClusterConfig._is_type({}, dict) is True
assert ManagedClusterConfig._is_type({"key": "value"}, dict) is True
assert ManagedClusterConfig._is_type([], dict) is False

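For context, an implementation of _is_type consistent with these assertions could use plain isinstance for concrete types and typing.get_origin/get_args to unwrap Optional/Union. This is a sketch written as a module-level function, not necessarily the SDK's actual code:

```python
from typing import Union, get_args, get_origin


def _is_type(value, annotation) -> bool:
    """Return True if value matches the (possibly Optional/Union) annotation."""
    origin = get_origin(annotation)
    if origin is Union:
        # Optional[X] is Union[X, None]; accept a match on any member.
        return any(_is_type(value, arg) for arg in get_args(annotation))
    if origin is not None:
        # Parameterized generics such as Dict[str, str]: check the container.
        return isinstance(value, origin)
    return isinstance(value, annotation)
```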

def test_ray_usage_stats_always_disabled_by_default():
"""Test that RAY_USAGE_STATS_ENABLED is always set to '0' by default"""
config = ManagedClusterConfig()
@@ -170,3 +204,26 @@ def test_add_script_volumes_existing_mount_early_return():
# Should still have only one mount, no volume added
assert len(config.volumes) == 0
assert len(config.volume_mounts) == 1


def test_build_script_configmap_spec_labels():
"""Test that build_script_configmap_spec creates ConfigMap with correct labels."""
config = ManagedClusterConfig()

job_name = "test-job"
namespace = "test-namespace"
scripts = {"script.py": "print('hello')", "helper.py": "# helper code"}

configmap_spec = config.build_script_configmap_spec(job_name, namespace, scripts)

assert configmap_spec["apiVersion"] == "v1"
assert configmap_spec["kind"] == "ConfigMap"
assert configmap_spec["metadata"]["name"] == f"{job_name}-scripts"
assert configmap_spec["metadata"]["namespace"] == namespace

labels = configmap_spec["metadata"]["labels"]
assert labels["ray.io/job-name"] == job_name
assert labels["app.kubernetes.io/managed-by"] == "codeflare-sdk"
assert labels["app.kubernetes.io/component"] == "rayjob-scripts"

assert configmap_spec["data"] == scripts