Skip to content

Commit 5b48d7e

Browse files
authored
[Ray] Support new ray cluster through ray client (#2981)
1 parent 4e351dd commit 5b48d7e

File tree

2 files changed

+67
-6
lines changed

2 files changed

+67
-6
lines changed

mars/deploy/oscar/tests/test_ray.py

Lines changed: 53 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@
1515
import asyncio
1616
import copy
1717
import os
18+
import subprocess
19+
import sys
20+
import tempfile
21+
import threading
1822
import time
1923

2024
import numpy as np
@@ -236,11 +240,55 @@ def new_ray_session_test():
236240

237241
@require_ray
238242
def test_ray_client(ray_start_regular):
239-
from ray.util.client.ray_client_helpers import ray_start_client_server
240-
from ray._private.client_mode_hook import enable_client_mode
241-
242-
with ray_start_client_server(), enable_client_mode():
243-
new_ray_session_test()
243+
server_code = """import time
244+
import ray.util.client.server.server as ray_client_server
245+
246+
server = ray_client_server.init_and_serve("{address}", num_cpus=20)
247+
print("OK", flush=True)
248+
while True:
249+
time.sleep(1)
250+
"""
251+
252+
address = "127.0.0.1:50051"
253+
254+
with tempfile.NamedTemporaryFile(mode="w", suffix=".py") as f:
255+
f.write(server_code.format(address=address))
256+
f.flush()
257+
258+
proc = subprocess.Popen([sys.executable, "-u", f.name], stdout=subprocess.PIPE)
259+
260+
try:
261+
262+
def _check_ready(expect_exit=False):
263+
while True:
264+
line = proc.stdout.readline()
265+
if proc.returncode is not None:
266+
if expect_exit:
267+
break
268+
raise Exception(
269+
f"Failed to start ray server at {address}, "
270+
f"the return code is {proc.returncode}."
271+
)
272+
if b"OK" in line:
273+
break
274+
275+
# Avoid ray.init timeout.
276+
_check_ready()
277+
278+
# Avoid blocking the subprocess when the stdout pipe is full.
279+
t = threading.Thread(target=_check_ready, args=(True,))
280+
t.start()
281+
282+
ray.init(f"ray://{address}")
283+
ray._inside_client_test = True
284+
try:
285+
new_ray_session_test()
286+
finally:
287+
ray._inside_client_test = False
288+
ray.shutdown()
289+
finally:
290+
proc.kill()
291+
proc.wait()
244292

245293

246294
@require_ray

mars/oscar/backends/ray/communication.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import asyncio
1616
import concurrent.futures as futures
17+
import functools
1718
import itertools
1819
import logging
1920
import time
@@ -25,6 +26,7 @@
2526

2627
from ....oscar.profiling import ProfilingData
2728
from ....serialization import serialize, deserialize
29+
from ....serialization.ray import register_ray_serializers
2830
from ....metrics import Metrics
2931
from ....utils import lazy_import, implements, classproperty, Timer
3032
from ...debug import debug_async_timeout
@@ -79,7 +81,11 @@ def msg_to_simple_str(msg): # pragma: no cover
7981
return str(type(msg))
8082

8183

84+
_register_ray_serializers_once = functools.lru_cache(1)(register_ray_serializers)
85+
86+
8287
def _argwrapper_unpickler(serialized_message):
88+
_register_ray_serializers_once()
8389
return _ArgWrapper(deserialize(*serialized_message))
8490

8591

@@ -91,6 +97,7 @@ def __init__(self, message):
9197
self.message = message
9298

9399
def __reduce__(self):
100+
_register_ray_serializers_once()
94101
return _argwrapper_unpickler, (serialize(self.message),)
95102

96103

@@ -273,9 +280,15 @@ async def handle_task(message: Any, object_ref: "ray.ObjectRef"):
273280
try:
274281
result = await object_ref
275282
except Exception as e: # pragma: no cover
283+
# str() on a failed ClientObjectRef can itself raise, so
284+
# we log the placeholder string `ClientObjectRef` instead.
285+
try:
286+
object_ref_str = str(object_ref)
287+
except Exception:
288+
object_ref_str = "ClientObjectRef"
276289
logger.exception(
277290
"Get object %s from %s failed, got exception %s.",
278-
object_ref,
291+
object_ref_str,
279292
self.dest_address,
280293
e,
281294
)

0 commit comments

Comments
 (0)