Skip to content

Commit 80833ec

Browse files
vmoens and cursoragent
committed
[Feature] Auto-batching inference server: Monarch transport
Adds MonarchTransport for distributed inference on GPU clusters using Monarch's actor model and RDMA channels. Monarch is imported lazily at instantiation time. Co-authored-by: Cursor <cursoragent@cursor.com> ghstack-source-id: 973026e Pull-Request: #3496
1 parent 3289df6 commit 80833ec

File tree

3 files changed

+251
-0
lines changed

3 files changed

+251
-0
lines changed

test/test_inference_server.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,20 @@
2323
RayTransport,
2424
ThreadingTransport,
2525
)
26+
from torchrl.modules.inference_server._monarch import MonarchTransport
2627

2728
_has_ray = True
2829
try:
2930
import ray
3031
except ImportError:
3132
_has_ray = False
3233

34+
_has_monarch = True
35+
try:
36+
import monarch # noqa: F401
37+
except ImportError:
38+
_has_monarch = False
39+
3340

3441
# =============================================================================
3542
# Helpers
@@ -490,3 +497,63 @@ def bad_model(td):
490497
td = TensorDict({"observation": torch.randn(4)})
491498
with pytest.raises(ValueError, match="ray model error"):
492499
client(td)
500+
501+
502+
# =============================================================================
503+
# Tests: MonarchTransport (Commit 5)
504+
# =============================================================================
505+
506+
507+
@pytest.mark.skipif(not _has_monarch, reason="monarch not installed")
class TestMonarchTransport:
    # End-to-end tests requiring a working ``monarch`` installation; the
    # whole class is skipped otherwise.

    def test_single_request(self):
        # Single round-trip: one client submits one tensordict and the
        # server-side policy must add an "action" entry of shape (2,).
        transport = MonarchTransport()
        client = transport.client()
        policy = _make_policy()
        with InferenceServer(policy, transport, max_batch_size=4):
            td = TensorDict({"observation": torch.randn(4)})
            result = client(td)
            assert "action" in result.keys()
            assert result["action"].shape == (2,)

    def test_concurrent_clients(self):
        """Multiple Monarch clients submit concurrently."""
        transport = MonarchTransport()
        policy = _make_policy()
        n_clients = 4
        n_requests = 20

        # One dedicated client (and hence one response queue) per worker
        # thread; results are collected per client so per-client ordering
        # and counts can be checked afterwards.
        clients = [transport.client() for _ in range(n_clients)]
        results_per_client: list[list[TensorDictBase]] = [[] for _ in range(n_clients)]

        def client_fn(client_idx):
            # Each worker issues n_requests sequential blocking calls on
            # its own client.
            for _ in range(n_requests):
                td = TensorDict({"observation": torch.randn(4)})
                result = clients[client_idx](td)
                results_per_client[client_idx].append(result)

        with InferenceServer(policy, transport, max_batch_size=8):
            with concurrent.futures.ThreadPoolExecutor(max_workers=n_clients) as pool:
                futs = [pool.submit(client_fn, i) for i in range(n_clients)]
                concurrent.futures.wait(futs)
                # Re-raise any exception that occurred inside a worker;
                # wait() alone would silently swallow failures.
                for f in futs:
                    f.result()

        # Every client must have received all of its responses, each with
        # a well-formed "action" entry.
        for client_results in results_per_client:
            assert len(client_results) == n_requests
            for r in client_results:
                assert "action" in r.keys()
                assert r["action"].shape == (2,)
547+
548+
549+
class TestMonarchTransportImport:
    """Checks for the lazy-import pattern: importing the class never needs
    monarch; only instantiation does."""

    def test_import_without_monarch(self):
        """MonarchTransport class can be imported even without monarch."""
        # Only __init__ touches monarch, so simply referencing the class
        # must succeed regardless of the environment.
        transport_cls = MonarchTransport
        assert transport_cls is not None

    @pytest.mark.skipif(_has_monarch, reason="test requires monarch NOT installed")
    def test_instantiation_without_monarch_raises(self):
        """Constructing the transport without monarch gives a clear error."""
        with pytest.raises(ImportError, match="Monarch is required"):
            MonarchTransport()

torchrl/modules/inference_server/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
55

6+
from torchrl.modules.inference_server._monarch import MonarchTransport
67
from torchrl.modules.inference_server._mp import MPTransport
78
from torchrl.modules.inference_server._ray import RayTransport
89
from torchrl.modules.inference_server._server import InferenceClient, InferenceServer
@@ -13,6 +14,7 @@
1314
"InferenceClient",
1415
"InferenceServer",
1516
"InferenceTransport",
17+
"MonarchTransport",
1618
"MPTransport",
1719
"RayTransport",
1820
"ThreadingTransport",
Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
from __future__ import annotations
6+
7+
import queue
8+
import threading
9+
import time
10+
from typing import Any
11+
12+
from tensordict.base import TensorDictBase
13+
14+
from torchrl.modules.inference_server._transport import InferenceTransport
15+
16+
_SENTINEL = object()
17+
18+
19+
class _MonarchFuture:
20+
"""Future-like object for Monarch transport results.
21+
22+
Args:
23+
client: the :class:`_MonarchInferenceClient` that created this future.
24+
req_id: the unique request identifier within that client.
25+
"""
26+
27+
def __init__(self, client: _MonarchInferenceClient, req_id: int):
28+
self._client = client
29+
self._req_id = req_id
30+
self._result: Any = _SENTINEL
31+
32+
def result(self, timeout: float | None = None) -> TensorDictBase:
33+
"""Block until the result is available."""
34+
if self._result is _SENTINEL:
35+
item = self._client._get_result(self._req_id, timeout=timeout)
36+
if isinstance(item, BaseException):
37+
raise item
38+
self._result = item
39+
return self._result
40+
41+
42+
class _MonarchInferenceClient:
43+
"""Actor-side client for :class:`MonarchTransport`.
44+
45+
Each client owns a dedicated response queue and routes results by
46+
request-id.
47+
48+
Args:
49+
request_queue: the shared Monarch queue for requests.
50+
response_queue: this client's dedicated response queue.
51+
actor_id: the unique identifier assigned by the transport.
52+
"""
53+
54+
def __init__(self, request_queue, response_queue, actor_id: int):
55+
self._request_queue = request_queue
56+
self._response_queue = response_queue
57+
self._actor_id = actor_id
58+
self._next_req_id = 0
59+
self._buffered: dict[int, Any] = {}
60+
61+
def __call__(self, td: TensorDictBase) -> TensorDictBase:
62+
"""Submit a request and block until the result is ready."""
63+
return self.submit(td).result()
64+
65+
def submit(self, td: TensorDictBase) -> _MonarchFuture:
66+
"""Submit a request and return a :class:`_MonarchFuture`."""
67+
req_id = self._next_req_id
68+
self._next_req_id += 1
69+
self._request_queue.put((self._actor_id, req_id, td))
70+
return _MonarchFuture(self, req_id)
71+
72+
# -- internal -------------------------------------------------------------
73+
74+
def _get_result(self, req_id: int, timeout: float | None = None) -> Any:
75+
"""Return the result for *req_id*, buffering any earlier arrivals."""
76+
if req_id in self._buffered:
77+
return self._buffered.pop(req_id)
78+
deadline = None if timeout is None else time.monotonic() + timeout
79+
while True:
80+
remaining = None
81+
if deadline is not None:
82+
remaining = deadline - time.monotonic()
83+
if remaining <= 0:
84+
raise queue.Empty(f"Timeout waiting for result of request {req_id}")
85+
try:
86+
rid, result = self._response_queue.get(timeout=remaining)
87+
except Exception:
88+
raise queue.Empty(f"Timeout waiting for result of request {req_id}")
89+
if rid == req_id:
90+
return result
91+
self._buffered[rid] = result
92+
93+
94+
class MonarchTransport(InferenceTransport):
    """Transport using Monarch for distributed inference on GPU clusters.

    Uses Monarch's actor model and RDMA-capable channels for efficient
    cross-node communication. Monarch is imported lazily at instantiation
    time; importing the class itself does not require Monarch.

    .. note::
        This transport requires ``monarch`` to be installed. It is designed
        for large-scale GPU clusters where Monarch is the preferred
        communication layer.

    Keyword Args:
        max_queue_size (int): maximum size of the request queue and of each
            per-client response queue. Default: ``1000``.
    """

    def __init__(self, *, max_queue_size: int = 1000):
        try:
            import monarch  # noqa: F401
            from monarch.tools.queue import MonarchQueue
        except ImportError:
            raise ImportError(
                "Monarch is required for MonarchTransport. "
                "Install it following the Monarch documentation."
            )
        # Remember the bound so per-client response queues honor it too
        # (previously they were hard-coded to 1000, ignoring this setting).
        self._max_queue_size = max_queue_size
        self._request_queue = MonarchQueue(maxsize=max_queue_size)
        self._response_queues: dict[int, Any] = {}
        self._lock = threading.Lock()
        self._next_actor_id = 0
        # Keep the class around so client() can build queues without
        # re-importing monarch.
        self._MonarchQueue = MonarchQueue

    # -- actor API ------------------------------------------------------------

    def client(self) -> _MonarchInferenceClient:
        """Create an actor-side client with a dedicated response queue.

        Returns:
            A :class:`_MonarchInferenceClient` that can be passed to a Monarch
            actor.
        """
        # Lock guards the id counter and the registry against concurrent
        # client() calls from multiple threads.
        with self._lock:
            actor_id = self._next_actor_id
            self._next_actor_id += 1
            response_queue = self._MonarchQueue(maxsize=self._max_queue_size)
            self._response_queues[actor_id] = response_queue
        return _MonarchInferenceClient(self._request_queue, response_queue, actor_id)

    def submit(self, td: TensorDictBase):
        """Not supported -- use :meth:`client` to obtain an actor handle."""
        raise RuntimeError(
            "MonarchTransport.submit() is not supported. "
            "Call transport.client() to create a _MonarchInferenceClient."
        )

    # -- server API -----------------------------------------------------------

    def drain(
        self, max_items: int
    ) -> tuple[list[TensorDictBase], list[tuple[int, int]]]:
        """Dequeue up to *max_items* pending requests (non-blocking).

        Returns:
            A pair ``(items, callbacks)`` where ``callbacks[i]`` is the
            ``(actor_id, req_id)`` routing key for ``items[i]``.
        """
        items: list[TensorDictBase] = []
        callbacks: list[tuple[int, int]] = []
        for _ in range(max_items):
            try:
                actor_id, req_id, td = self._request_queue.get(block=False)
            except Exception:
                # Queue exhausted; the exact exception type depends on the
                # MonarchQueue implementation, hence the broad catch.
                break
            items.append(td)
            callbacks.append((actor_id, req_id))
        return items, callbacks

    def wait_for_work(self, timeout: float) -> None:
        """Block until at least one request is available or *timeout* elapses.

        NOTE(review): implemented as pop-then-push-back, presumably because
        MonarchQueue offers no peek/wait primitive -- confirm against the
        Monarch API. The re-queued request moves to the tail, so strict FIFO
        ordering across a wait is not guaranteed.
        """
        try:
            item = self._request_queue.get(timeout=timeout)
            self._request_queue.put(item)
        except Exception:
            # Timed out (or transient queue error): return and let the
            # server loop decide what to do next.
            pass

    def resolve(self, callback: tuple[int, int], result: TensorDictBase) -> None:
        """Route the result to the correct actor's response queue."""
        actor_id, req_id = callback
        self._response_queues[actor_id].put((req_id, result))

    def resolve_exception(self, callback: tuple[int, int], exc: BaseException) -> None:
        """Route an exception to the correct actor's response queue."""
        actor_id, req_id = callback
        self._response_queues[actor_id].put((req_id, exc))

0 commit comments

Comments
 (0)