Skip to content

Commit db16cb4

Browse files
Adds an option to set the object-store size when starting a Ray cluster; uses a process group for managing Ray child processes (#1274)
* Updates stop() to kill the ray process group created by init_cluster(). Signed-off-by: rlratzel <[email protected]> * Creates a process group when running the 'ray start' command and kills the process group in RayClient.stop() to ensure all child processes are terminated. Signed-off-by: rlratzel <[email protected]> * Adds support for creating a process group for the ray subprocesses and killing the group to better ensure all processes are killed. Signed-off-by: rlratzel <[email protected]> * Wip ab/final metrics (#996) * Adding prometheus and grafana to the nemo curator metrics path Signed-off-by: Abhinav Garg <[email protected]> * Implement safe extraction for tar files in file_utils.py - Added functions `_is_safe_path` and `tar_safe_extract` to ensure safe extraction of tar files, preventing path traversal attacks. - Included necessary imports and updated the file structure by removing outdated files from the ray-curator module. Signed-off-by: [Your Name] [[email protected]] Signed-off-by: Abhinav Garg <[email protected]> * Refactor references from ray_curator to nemo_curator across multiple files - Updated file paths and comments in api-design.md, __init__.py, client.py, and start_prometheus_grafana.py to reflect the new nemo_curator namespace. - Changed package name in package_info.py from ray_curator to nemo_curator. Signed-off-by: [Your Name] [[email protected]] Signed-off-by: Abhinav Garg <[email protected]> * Rename function `get_ray_client` to `start_prometheus_grafana` in start_prometheus_grafana.py for clarity and consistency with the new metrics path. Update the function call in the main execution block accordingly. 
Signed-off-by: Abhinav Garg <[email protected]> * Adding prometheus and grafana to the nemo curator metrics path Signed-off-by: Abhinav Garg <[email protected]> * Adding README for metrics Signed-off-by: Abhinav Garg <[email protected]> --------- Signed-off-by: Abhinav Garg <[email protected]> Signed-off-by: [Your Name] [[email protected]] Co-authored-by: Sarah Yurick <[email protected]> Signed-off-by: rlratzel <[email protected]> * undoes merge mistakes. Signed-off-by: rlratzel <[email protected]> * Fixes typo. Signed-off-by: rlratzel <[email protected]> * Adds OSError to exception handlers to handle process groups that have already terminated. Signed-off-by: rlratzel <[email protected]> --------- Signed-off-by: rlratzel <[email protected]> Signed-off-by: Abhinav Garg <[email protected]> Signed-off-by: [Your Name] [[email protected]] Co-authored-by: Sarah Yurick <[email protected]>
1 parent b323a3a commit db16cb4

File tree

3 files changed

+52
-18
lines changed

3 files changed

+52
-18
lines changed

nemo_curator/core/client.py

Lines changed: 42 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import atexit
1616
import os
17+
import signal
1718
import socket
1819
import subprocess
1920
from dataclasses import dataclass, field
@@ -56,7 +57,9 @@ class RayClient:
5657
ray_dashboard_host: The host of the Ray dashboard.
5758
num_gpus: The number of GPUs to use.
5859
num_cpus: The number of CPUs to use.
60+
object_store_memory: The amount of memory to use for the object store.
5961
enable_object_spilling: Whether to enable object spilling.
62+
ray_stdouterr_capture_file: The file to capture stdout/stderr to.
6063
6164
Note:
6265
To start monitoring services (Prometheus and Grafana), use the standalone
@@ -72,18 +75,19 @@ class RayClient:
7275
ray_dashboard_host: str = DEFAULT_RAY_DASHBOARD_HOST
7376
num_gpus: int | None = None
7477
num_cpus: int | None = None
78+
object_store_memory: int | None = None
7579
enable_object_spilling: bool = False
7680
ray_stdouterr_capture_file: str | None = None
7781

78-
ray_process: subprocess.Popen | None = field(init=False)
82+
ray_process: subprocess.Popen | None = field(init=False, default=None)
7983

8084
def __post_init__(self) -> None:
8185
if self.ray_stdouterr_capture_file and os.path.exists(self.ray_stdouterr_capture_file):
8286
msg = f"Capture file {self.ray_stdouterr_capture_file} already exists."
8387
raise FileExistsError(msg)
8488

8589
def start(self) -> None:
86-
"""Start the Ray cluster, optionally capturing stdout/stderr to a file."""
90+
"""Start the Ray cluster if not already started, optionally capturing stdout/stderr to a file."""
8791
if self.include_dashboard:
8892
# Add Ray metrics service discovery to existing Prometheus configuration
8993
if is_prometheus_running() and is_grafana_running():
@@ -101,6 +105,17 @@ def start(self) -> None:
101105
)
102106
logger.warning(msg)
103107

108+
# Use the RAY_ADDRESS environment variable to determine if Ray is already running.
109+
# If a Ray cluster is not running:
110+
# RAY_ADDRESS will be set below when the Ray cluster is started and self.ray_process
111+
# will be assigned the cluster process
112+
# If a Ray cluster is already running:
113+
# RAY_ADDRESS will have been set prior to calling start(), presumably by a user starting
114+
# it externally, which means a cluster was already running and self.ray_process will be None.
115+
#
116+
# Note that the stop() method will stop the cluster only if it was started here and
117+
# self.ray_process was assigned, otherwise it leaves it running with the assumption it
118+
# was started externally and should not be stopped.
104119
if os.environ.get("RAY_ADDRESS"):
105120
logger.info("Ray is already running. Skipping the setup.")
106121
else:
@@ -119,15 +134,16 @@ def start(self) -> None:
119134
ip_address = socket.gethostbyname(socket.gethostname())
120135

121136
self.ray_process = init_cluster(
122-
self.ray_port,
123-
self.ray_temp_dir,
124-
self.ray_dashboard_port,
125-
self.ray_metrics_port,
126-
self.ray_client_server_port,
127-
self.ray_dashboard_host,
128-
self.num_gpus,
129-
self.num_cpus,
130-
self.enable_object_spilling,
137+
ray_port=self.ray_port,
138+
ray_temp_dir=self.ray_temp_dir,
139+
ray_dashboard_port=self.ray_dashboard_port,
140+
ray_metrics_port=self.ray_metrics_port,
141+
ray_client_server_port=self.ray_client_server_port,
142+
ray_dashboard_host=self.ray_dashboard_host,
143+
num_gpus=self.num_gpus,
144+
num_cpus=self.num_cpus,
145+
object_store_memory=self.object_store_memory,
146+
enable_object_spilling=self.enable_object_spilling,
131147
block=True,
132148
ip_address=ip_address,
133149
stdouterr_capture_file=self.ray_stdouterr_capture_file,
@@ -140,8 +156,21 @@ def start(self) -> None:
140156

141157
def stop(self) -> None:
142158
if self.ray_process:
143-
self.ray_process.kill()
144-
self.ray_process.wait()
159+
# Kill the entire process group to ensure child processes are terminated
160+
try:
161+
os.killpg(os.getpgid(self.ray_process.pid), signal.SIGTERM)
162+
self.ray_process.wait(timeout=5)
163+
except subprocess.TimeoutExpired:
164+
# Force kill if graceful termination doesn't work
165+
try:
166+
os.killpg(os.getpgid(self.ray_process.pid), signal.SIGKILL)
167+
self.ray_process.wait()
168+
except (ProcessLookupError, OSError):
169+
# Process group not found or process group already terminated
170+
pass
171+
except (ProcessLookupError, OSError):
172+
# Process group not found or process group already terminated
173+
pass
145174
# Reset the environment variable for RAY_ADDRESS
146175
os.environ.pop("RAY_ADDRESS", None)
147176
# Currently there is no good way of stopping a particular Ray cluster. https://github.com/ray-project/ray/issues/54989

nemo_curator/core/utils.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ def init_cluster( # noqa: PLR0913
7777
ray_dashboard_host: str,
7878
num_gpus: int | None = None,
7979
num_cpus: int | None = None,
80+
object_store_memory: int | None = None,
8081
enable_object_spilling: bool = False,
8182
block: bool = True,
8283
ip_address: str | None = None,
@@ -99,6 +100,8 @@ def init_cluster( # noqa: PLR0913
99100
ray_command.extend(["--dashboard-port", str(ray_dashboard_port)])
100101
ray_command.extend(["--ray-client-server-port", str(ray_client_server_port)])
101102
ray_command.extend(["--temp-dir", ray_temp_dir])
103+
if object_store_memory is not None:
104+
ray_command.extend(["--object-store-memory", str(object_store_memory)])
102105
ray_command.extend(["--disable-usage-stats"])
103106
if enable_object_spilling:
104107
ray_command.extend(
@@ -124,8 +127,10 @@ def init_cluster( # noqa: PLR0913
124127

125128
if stdouterr_capture_file:
126129
with open(stdouterr_capture_file, "w") as f:
127-
proc = subprocess.Popen(ray_command, shell=False, stdout=f, stderr=subprocess.STDOUT) # noqa: S603
130+
proc = subprocess.Popen( # noqa: S603
131+
ray_command, shell=False, stdout=f, stderr=subprocess.STDOUT, start_new_session=True
132+
)
128133
else:
129-
proc = subprocess.Popen(ray_command, shell=False) # noqa: S603
134+
proc = subprocess.Popen(ray_command, shell=False, start_new_session=True) # noqa: S603
130135
logger.info(f"Ray start command: {' '.join(ray_command)}")
131136
return proc

tests/core/test_get_ray_client.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,11 @@ def _assert_ray_stdouterr_output(stdouterr_capture_file: str) -> None:
5151
with open(stdouterr_capture_file) as f:
5252
if "Ray runtime started." in f.read():
5353
break
54-
if elapsed >= timeout:
55-
msg = f"Expected output not found in {stdouterr_capture_file} after {timeout} seconds"
56-
raise AssertionError(msg)
5754
time.sleep(1)
5855
elapsed += 1
56+
if elapsed >= timeout:
57+
msg = f"Expected output not found in {stdouterr_capture_file} after {timeout} seconds"
58+
raise AssertionError(msg)
5959

6060

6161
@pytest.fixture(scope="module")

0 commit comments

Comments
 (0)