Skip to content

Commit 28e94c5

Browse files
committed
task(RHOAIENG-33283): Utilize runtime_env from Ray Public API
1 parent d418d24 commit 28e94c5

File tree

4 files changed: 53 additions (+53) and 29 deletions (-29).

src/codeflare_sdk/ray/rayjobs/rayjob.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,8 @@
1919
import logging
2020
import warnings
2121
from typing import Dict, Any, Optional, Tuple
22+
23+
from ray.runtime_env import RuntimeEnv
2224
from codeflare_sdk.common.kueue.kueue import get_default_kueue_name
2325
from codeflare_sdk.common.utils.constants import MOUNT_PATH
2426

@@ -61,7 +63,7 @@ def __init__(
6163
cluster_name: Optional[str] = None,
6264
cluster_config: Optional[ManagedClusterConfig] = None,
6365
namespace: Optional[str] = None,
64-
runtime_env: Optional[Dict[str, Any]] = None,
66+
runtime_env: Optional[RuntimeEnv] = None,
6567
ttl_seconds_after_finished: int = 0,
6668
active_deadline_seconds: Optional[int] = None,
6769
local_queue: Optional[str] = None,
@@ -75,7 +77,7 @@ def __init__(
7577
cluster_name: The name of an existing Ray cluster (optional if cluster_config provided)
7678
cluster_config: Configuration for creating a new cluster (optional if cluster_name provided)
7779
namespace: The Kubernetes namespace (auto-detected if not specified)
78-
runtime_env: Ray runtime environment configuration (optional)
80+
runtime_env: Ray runtime environment configuration as RuntimeEnv object (optional)
7981
ttl_seconds_after_finished: Seconds to wait before cleanup after job finishes (default: 0)
8082
active_deadline_seconds: Maximum time the job can run before being terminated (optional)
8183
local_queue: The Kueue LocalQueue to submit the job to (optional)

src/codeflare_sdk/ray/rayjobs/runtime_env.py

Lines changed: 32 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from typing import Dict, Any, Optional, List, TYPE_CHECKING
99
from codeflare_sdk.common.utils.constants import MOUNT_PATH
1010
from kubernetes import client
11+
from ray.runtime_env import RuntimeEnv
1112

1213
from codeflare_sdk.ray.rayjobs.config import ManagedClusterConfig
1314
from ...common.kubernetes_cluster.auth import get_api_client
@@ -22,6 +23,14 @@
2223
PYTHON_FILE_PATTERN = r"(?:python\s+)?([./\w/]+\.py)"
2324

2425

26+
def _normalize_runtime_env(
27+
runtime_env: Optional[RuntimeEnv],
28+
) -> Optional[Dict[str, Any]]:
29+
if runtime_env is None:
30+
return None
31+
return runtime_env.to_dict()
32+
33+
2534
def extract_all_local_files(job: RayJob) -> Optional[Dict[str, str]]:
2635
"""
2736
Extract all local files from both entrypoint and runtime_env working_dir.
@@ -32,14 +41,17 @@ def extract_all_local_files(job: RayJob) -> Optional[Dict[str, str]]:
3241
Returns:
3342
Dict of {file_name: file_content} if local files found, None otherwise
3443
"""
44+
# Convert RuntimeEnv to dict for processing
45+
runtime_env_dict = _normalize_runtime_env(job.runtime_env)
46+
3547
# If there's a remote working_dir, don't extract local files to avoid conflicts
3648
if (
37-
job.runtime_env
38-
and "working_dir" in job.runtime_env
39-
and not os.path.isdir(job.runtime_env["working_dir"])
49+
runtime_env_dict
50+
and "working_dir" in runtime_env_dict
51+
and not os.path.isdir(runtime_env_dict["working_dir"])
4052
):
4153
logger.info(
42-
f"Remote working_dir detected: {job.runtime_env['working_dir']}. "
54+
f"Remote working_dir detected: {runtime_env_dict['working_dir']}. "
4355
"Skipping local file extraction - all files should come from remote source."
4456
)
4557
return None
@@ -55,18 +67,18 @@ def extract_all_local_files(job: RayJob) -> Optional[Dict[str, str]]:
5567

5668
# Extract files from runtime_env working_dir if it's a local directory
5769
if (
58-
job.runtime_env
59-
and "working_dir" in job.runtime_env
60-
and os.path.isdir(job.runtime_env["working_dir"])
70+
runtime_env_dict
71+
and "working_dir" in runtime_env_dict
72+
and os.path.isdir(runtime_env_dict["working_dir"])
6173
):
6274
working_dir_files = extract_working_dir_files(
63-
job.runtime_env["working_dir"], processed_files
75+
runtime_env_dict["working_dir"], processed_files
6476
)
6577
if working_dir_files:
6678
files.update(working_dir_files)
6779

6880
# If no working_dir specified in runtime_env, try to infer and extract files from inferred directory
69-
elif not job.runtime_env or "working_dir" not in job.runtime_env:
81+
elif not runtime_env_dict or "working_dir" not in runtime_env_dict:
7082
inferred_working_dir = infer_working_dir_from_entrypoint(job)
7183
if inferred_working_dir:
7284
working_dir_files = extract_working_dir_files(
@@ -221,24 +233,27 @@ def process_runtime_env(
221233
Returns:
222234
Processed runtime environment as YAML string, or None if no processing needed
223235
"""
236+
# Convert RuntimeEnv to dict for processing
237+
runtime_env_dict = _normalize_runtime_env(job.runtime_env)
238+
224239
processed_env = {}
225240

226241
# Handle env_vars
227-
if job.runtime_env and "env_vars" in job.runtime_env:
228-
processed_env["env_vars"] = job.runtime_env["env_vars"]
242+
if runtime_env_dict and "env_vars" in runtime_env_dict:
243+
processed_env["env_vars"] = runtime_env_dict["env_vars"]
229244
logger.info(
230-
f"Added {len(job.runtime_env['env_vars'])} environment variables to runtime_env"
245+
f"Added {len(runtime_env_dict['env_vars'])} environment variables to runtime_env"
231246
)
232247

233248
# Handle pip dependencies
234-
if job.runtime_env and "pip" in job.runtime_env:
235-
pip_deps = process_pip_dependencies(job, job.runtime_env["pip"])
249+
if runtime_env_dict and "pip" in runtime_env_dict:
250+
pip_deps = process_pip_dependencies(job, runtime_env_dict["pip"])
236251
if pip_deps:
237252
processed_env["pip"] = pip_deps
238253

239254
# Handle working_dir - if it's a local path, set it to mount path
240-
if job.runtime_env and "working_dir" in job.runtime_env:
241-
working_dir = job.runtime_env["working_dir"]
255+
if runtime_env_dict and "working_dir" in runtime_env_dict:
256+
working_dir = runtime_env_dict["working_dir"]
242257
if os.path.isdir(working_dir):
243258
# Local working directory - will be mounted at MOUNT_PATH
244259
processed_env["working_dir"] = MOUNT_PATH
@@ -252,7 +267,7 @@ def process_runtime_env(
252267
logger.info(f"Using remote working directory: {working_dir}")
253268

254269
# If no working_dir specified but we have files, set working_dir to mount path
255-
elif not job.runtime_env or "working_dir" not in job.runtime_env:
270+
elif not runtime_env_dict or "working_dir" not in runtime_env_dict:
256271
if files:
257272
# Local files found - will be mounted at MOUNT_PATH
258273
processed_env["working_dir"] = MOUNT_PATH

src/codeflare_sdk/ray/rayjobs/test/test_rayjob.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import pytest
1616
from unittest.mock import MagicMock
1717
from codeflare_sdk.common.utils.constants import RAY_VERSION
18+
from ray.runtime_env import RuntimeEnv
1819

1920
from codeflare_sdk.ray.rayjobs.rayjob import RayJob
2021
from codeflare_sdk.ray.cluster.config import ClusterConfiguration
@@ -35,7 +36,7 @@ def test_rayjob_submit_success(auto_mock_setup):
3536
cluster_name="test-ray-cluster",
3637
namespace="test-namespace",
3738
entrypoint="python -c 'print(\"hello world\")'",
38-
runtime_env={"pip": ["requests"]},
39+
runtime_env=RuntimeEnv(pip=["requests"]),
3940
)
4041

4142
job_id = rayjob.submit()
@@ -68,7 +69,7 @@ def test_rayjob_submit_failure(auto_mock_setup):
6869
cluster_name="test-ray-cluster",
6970
namespace="default",
7071
entrypoint="python test.py",
71-
runtime_env={"pip": ["numpy"]},
72+
runtime_env=RuntimeEnv(pip=["numpy"]),
7273
)
7374

7475
with pytest.raises(RuntimeError, match="Failed to submit RayJob test-rayjob"):
@@ -474,7 +475,7 @@ def test_rayjob_with_runtime_env(auto_mock_setup):
474475
"""
475476
Test RayJob with runtime environment configuration.
476477
"""
477-
runtime_env = {"pip": ["numpy", "pandas"]}
478+
runtime_env = RuntimeEnv(pip=["numpy", "pandas"])
478479

479480
rayjob = RayJob(
480481
job_name="test-job",
@@ -567,7 +568,7 @@ def test_rayjob_constructor_parameter_validation(auto_mock_setup):
567568
entrypoint="python test.py",
568569
cluster_name="test-cluster",
569570
namespace="test-ns",
570-
runtime_env={"pip": ["numpy"]},
571+
runtime_env=RuntimeEnv(pip=["numpy"]),
571572
ttl_seconds_after_finished=300,
572573
active_deadline_seconds=600,
573574
)
@@ -576,7 +577,12 @@ def test_rayjob_constructor_parameter_validation(auto_mock_setup):
576577
assert rayjob.entrypoint == "python test.py"
577578
assert rayjob.cluster_name == "test-cluster"
578579
assert rayjob.namespace == "test-ns"
579-
assert rayjob.runtime_env == {"pip": ["numpy"]}
580+
# Check that runtime_env is a RuntimeEnv object and contains pip dependencies
581+
assert isinstance(rayjob.runtime_env, RuntimeEnv)
582+
runtime_env_dict = rayjob.runtime_env.to_dict()
583+
assert "pip" in runtime_env_dict
584+
# Ray transforms pip to dict format with 'packages' key
585+
assert runtime_env_dict["pip"]["packages"] == ["numpy"]
580586
assert rayjob.ttl_seconds_after_finished == 300
581587
assert rayjob.active_deadline_seconds == 600
582588

src/codeflare_sdk/ray/rayjobs/test/test_runtime_env.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import os
1818
from unittest.mock import MagicMock, patch
1919
from codeflare_sdk.common.utils.constants import MOUNT_PATH, RAY_VERSION
20+
from ray.runtime_env import RuntimeEnv
2021

2122
from codeflare_sdk.ray.rayjobs.rayjob import RayJob
2223
from codeflare_sdk.ray.cluster.config import ClusterConfiguration
@@ -43,11 +44,11 @@ def test_rayjob_with_remote_working_dir(auto_mock_setup):
4344
Test RayJob with remote working directory in runtime_env.
4445
Should not extract local files and should pass through remote URL.
4546
"""
46-
runtime_env = {
47-
"working_dir": "https://github.com/org/repo/archive/refs/heads/main.zip",
48-
"pip": ["numpy", "pandas"],
49-
"env_vars": {"TEST_VAR": "test_value"},
50-
}
47+
runtime_env = RuntimeEnv(
48+
working_dir="https://github.com/org/repo/archive/refs/heads/main.zip",
49+
pip=["numpy", "pandas"],
50+
env_vars={"TEST_VAR": "test_value"},
51+
)
5152

5253
rayjob = RayJob(
5354
job_name="test-job",

0 commit comments

Comments
 (0)