Skip to content

Commit 52dc84e

Browse files
vmoens and cursoragent committed
[Feature] Auto-batching inference server: Ray transport (#3495)
Adds RayTransport using ray.util.queue.Queue for distributed inference across Ray actors. Ray is imported lazily at instantiation time.

Co-authored-by: Cursor <cursoragent@cursor.com>
ghstack-source-id: 660ee98
Pull-Request: #3495
1 parent 7abafc5 commit 52dc84e

File tree

3 files changed

+151
-0
lines changed

3 files changed

+151
-0
lines changed

test/test_inference_server.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,16 @@
2020
InferenceServer,
2121
InferenceTransport,
2222
MPTransport,
23+
RayTransport,
2324
ThreadingTransport,
2425
)
2526

27+
_has_ray = True
28+
try:
29+
import ray
30+
except ImportError:
31+
_has_ray = False
32+
2633

2734
# =============================================================================
2835
# Helpers
@@ -398,3 +405,88 @@ def bad_model(td):
398405
td = TensorDict({"observation": torch.randn(4)})
399406
with pytest.raises(ValueError, match="mp model error"):
400407
client(td)
408+
409+
410+
# =============================================================================
411+
# Tests: RayTransport (Commit 4)
412+
# =============================================================================
413+
414+
415+
@pytest.mark.skipif(not _has_ray, reason="ray not installed")
class TestRayTransport:
    """End-to-end tests for :class:`RayTransport` against a live Ray runtime."""

    @classmethod
    def setup_class(cls):
        # Bring up a small local Ray runtime once for the whole class.
        if not ray.is_initialized():
            ray.init(num_cpus=4, ignore_reinit_error=True)

    def test_single_request(self):
        # A single client round-trip through the server yields an action.
        transport = RayTransport()
        client = transport.client()
        policy = _make_policy()
        with InferenceServer(policy, transport, max_batch_size=4):
            out = client(TensorDict({"observation": torch.randn(4)}))
            assert "action" in out.keys()
            assert out["action"].shape == (2,)

    def test_concurrent_clients(self):
        """Multiple clients submit concurrently from threads (simulating Ray actors)."""
        transport = RayTransport()
        policy = _make_policy()
        num_workers = 4
        per_worker = 20

        handles = [transport.client() for _ in range(num_workers)]
        collected: list[list[TensorDictBase]] = [[] for _ in range(num_workers)]

        def worker(idx):
            # Each thread drives its own client handle end to end.
            for _ in range(per_worker):
                request = TensorDict({"observation": torch.randn(4)})
                collected[idx].append(handles[idx](request))

        with InferenceServer(policy, transport, max_batch_size=8):
            with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as pool:
                pending = [pool.submit(worker, i) for i in range(num_workers)]
                concurrent.futures.wait(pending)
                # Re-raise any exception captured inside a worker thread.
                for fut in pending:
                    fut.result()

        for outputs in collected:
            assert len(outputs) == per_worker
            for out in outputs:
                assert "action" in out.keys()
                assert out["action"].shape == (2,)

    def test_ray_remote_actor(self):
        """A Ray remote actor can use the client to get inference results."""
        transport = RayTransport()
        client = transport.client()
        policy = _make_policy()

        @ray.remote
        def remote_actor_fn(client, n_requests):
            # Collect only the action shapes so the return value is picklable.
            shapes = []
            for _ in range(n_requests):
                td = TensorDict({"observation": torch.randn(4)})
                shapes.append(client(td)["action"].shape)
            return shapes

        with InferenceServer(policy, transport, max_batch_size=8):
            shapes = ray.get(remote_actor_fn.remote(client, 5), timeout=30.0)
            assert len(shapes) == 5
            assert all(s == (2,) for s in shapes)

    def test_ray_exception_propagates(self):
        # Server-side model failures must surface on the client call site.
        def bad_model(td):
            raise ValueError("ray model error")

        transport = RayTransport()
        client = transport.client()
        with InferenceServer(bad_model, transport, max_batch_size=4):
            with pytest.raises(ValueError, match="ray model error"):
                client(TensorDict({"observation": torch.randn(4)}))

torchrl/modules/inference_server/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# LICENSE file in the root directory of this source tree.
55

66
from torchrl.modules.inference_server._mp import MPTransport
7+
from torchrl.modules.inference_server._ray import RayTransport
78
from torchrl.modules.inference_server._server import InferenceClient, InferenceServer
89
from torchrl.modules.inference_server._threading import ThreadingTransport
910
from torchrl.modules.inference_server._transport import InferenceTransport
@@ -13,5 +14,6 @@
1314
"InferenceServer",
1415
"InferenceTransport",
1516
"MPTransport",
17+
"RayTransport",
1618
"ThreadingTransport",
1719
]
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
from __future__ import annotations
6+
7+
from torchrl.modules.inference_server._queue_transport import (
8+
_QueueInferenceClient,
9+
QueueBasedTransport,
10+
)
11+
12+
13+
class RayTransport(QueueBasedTransport):
    """Transport using Ray queues for distributed inference.

    Uses ``ray.util.queue.Queue`` for both request submission and response
    routing. Per-actor response queues ensure correct result routing without
    serialising Queue objects through other queues.

    Ray is imported lazily at instantiation time; importing the class itself
    does not require Ray.

    Keyword Args:
        max_queue_size (int): maximum size of the request queue and of each
            per-client response queue. Default: ``1000``.

    Example:
        >>> import ray
        >>> ray.init()
        >>> transport = RayTransport()
        >>> client = transport.client()
        >>> # pass *client* to a Ray actor for remote inference requests
    """

    def __init__(self, *, max_queue_size: int = 1000):
        super().__init__()
        try:
            # Lazy import: Ray is only required once a transport is built.
            import ray.util.queue
        except ImportError as err:
            # Chain the original exception so the real import failure
            # (e.g. a broken partial install) stays visible in the traceback.
            raise ImportError(
                "Ray is required for RayTransport. Install it with: pip install ray"
            ) from err
        self._max_queue_size = max_queue_size
        self._request_queue = ray.util.queue.Queue(maxsize=max_queue_size)
        self._response_queues: dict[int, ray.util.queue.Queue] = {}
        # Keep a handle to the module so response queues can be created later
        # without re-importing in _make_response_queue.
        self._ray_queue_module = ray.util.queue

    def _make_response_queue(self):
        # Use the configured bound instead of a hard-coded 1000 so request and
        # response queues stay consistent (default is unchanged).
        return self._ray_queue_module.Queue(maxsize=self._max_queue_size)

    def client(self) -> _QueueInferenceClient:
        """Create an actor-side client with a dedicated Ray response queue.

        Returns:
            A :class:`_QueueInferenceClient` that can be used inside any Ray
            actor or the driver process.
        """
        return super().client()

0 commit comments

Comments (0)