Skip to content

Commit 73b4802

Browse files
Context manager support for RayClient (#1155)
* Context manager support for RayClient Signed-off-by: James Bourbeau <[email protected]> * Bump timeout Signed-off-by: James Bourbeau <[email protected]> --------- Signed-off-by: James Bourbeau <[email protected]> Co-authored-by: Sarah Yurick <[email protected]>
1 parent 40cb633 commit 73b4802

File tree

2 files changed

+34
-0
lines changed

2 files changed

+34
-0
lines changed

nemo_curator/core/client.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,3 +143,10 @@ def stop(self) -> None:
143143
logger.info(msg)
144144
# Clear the process to prevent double execution (atexit handler)
145145
self.ray_process = None
146+
147+
def __enter__(self):
148+
self.start()
149+
return self
150+
151+
def __exit__(self, *exc):
152+
self.stop()

tests/core/test_get_ray_client.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,3 +79,30 @@ def test_get_ray_client_multiple_start():
7979
os.environ["RAY_ADDRESS"] = initial_address
8080
else:
8181
os.environ.pop("RAY_ADDRESS", None)
82+
83+
84+
def wait_for_ray_cluster_start(client: RayClient, timeout: int = 30):
85+
t_start = time.perf_counter()
86+
while True:
87+
fn = os.path.join(client.ray_temp_dir, "ray_current_cluster")
88+
if os.path.exists(fn):
89+
# Cluster is up and running
90+
break
91+
elif time.perf_counter() - t_start > timeout:
92+
msg = f"Ray cluster didn't start after {timeout} seconds"
93+
raise RuntimeError(msg)
94+
else:
95+
time.sleep(1)
96+
97+
98+
def test_ray_client_context_manager(monkeypatch: pytest.MonkeyPatch):
99+
monkeypatch.delenv("RAY_ADDRESS")
100+
with tempfile.TemporaryDirectory(prefix="ray_test_ctx_manager_") as ray_tmp:
101+
with RayClient(ray_temp_dir=ray_tmp) as client:
102+
wait_for_ray_cluster_start(client)
103+
104+
with open(os.path.join(ray_tmp, "ray_current_cluster")) as f:
105+
content = f.read()
106+
assert content.split(":")[1].strip() == str(client.ray_port)
107+
108+
assert client.ray_process is None

0 commit comments

Comments
 (0)