Skip to content

Commit 4083720

Browse files
vmoens and cursoragent committed
[Feature] Auto-batching inference server: Monarch transport (#3496)
Adds MonarchTransport for distributed inference on GPU clusters using Monarch's actor model and RDMA channels. Monarch is imported lazily at instantiation time.

ghstack-source-id: 8ea2d20
Pull-Request: #3496
Co-authored-by: Cursor <cursoragent@cursor.com>
1 parent 52dc84e commit 4083720

File tree

3 files changed

+123
-0
lines changed

3 files changed

+123
-0
lines changed

test/test_inference_server.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,20 @@
2323
RayTransport,
2424
ThreadingTransport,
2525
)
26+
from torchrl.modules.inference_server._monarch import MonarchTransport
2627

2728
_has_ray = True
2829
try:
2930
import ray
3031
except ImportError:
3132
_has_ray = False
3233

34+
_has_monarch = True
35+
try:
36+
import monarch # noqa: F401
37+
except ImportError:
38+
_has_monarch = False
39+
3340

3441
# =============================================================================
3542
# Helpers
@@ -490,3 +497,63 @@ def bad_model(td):
490497
td = TensorDict({"observation": torch.randn(4)})
491498
with pytest.raises(ValueError, match="ray model error"):
492499
client(td)
500+
501+
502+
# =============================================================================
503+
# Tests: MonarchTransport (Commit 5)
504+
# =============================================================================
505+
506+
507+
@pytest.mark.skipif(not _has_monarch, reason="monarch not installed")
508+
class TestMonarchTransport:
509+
def test_single_request(self):
510+
transport = MonarchTransport()
511+
client = transport.client()
512+
policy = _make_policy()
513+
with InferenceServer(policy, transport, max_batch_size=4):
514+
td = TensorDict({"observation": torch.randn(4)})
515+
result = client(td)
516+
assert "action" in result.keys()
517+
assert result["action"].shape == (2,)
518+
519+
def test_concurrent_clients(self):
520+
"""Multiple Monarch clients submit concurrently."""
521+
transport = MonarchTransport()
522+
policy = _make_policy()
523+
n_clients = 4
524+
n_requests = 20
525+
526+
clients = [transport.client() for _ in range(n_clients)]
527+
results_per_client: list[list[TensorDictBase]] = [[] for _ in range(n_clients)]
528+
529+
def client_fn(client_idx):
530+
for _ in range(n_requests):
531+
td = TensorDict({"observation": torch.randn(4)})
532+
result = clients[client_idx](td)
533+
results_per_client[client_idx].append(result)
534+
535+
with InferenceServer(policy, transport, max_batch_size=8):
536+
with concurrent.futures.ThreadPoolExecutor(max_workers=n_clients) as pool:
537+
futs = [pool.submit(client_fn, i) for i in range(n_clients)]
538+
concurrent.futures.wait(futs)
539+
for f in futs:
540+
f.result()
541+
542+
for client_results in results_per_client:
543+
assert len(client_results) == n_requests
544+
for r in client_results:
545+
assert "action" in r.keys()
546+
assert r["action"].shape == (2,)
547+
548+
549+
class TestMonarchTransportImport:
550+
def test_import_without_monarch(self):
551+
"""MonarchTransport class can be imported even without monarch."""
552+
# This test verifies the lazy import pattern works.
553+
# The class itself is importable; only instantiation requires monarch.
554+
assert MonarchTransport is not None
555+
556+
@pytest.mark.skipif(_has_monarch, reason="test requires monarch NOT installed")
557+
def test_instantiation_without_monarch_raises(self):
558+
with pytest.raises(ImportError, match="Monarch is required"):
559+
MonarchTransport()

torchrl/modules/inference_server/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
55

6+
from torchrl.modules.inference_server._monarch import MonarchTransport
67
from torchrl.modules.inference_server._mp import MPTransport
78
from torchrl.modules.inference_server._ray import RayTransport
89
from torchrl.modules.inference_server._server import InferenceClient, InferenceServer
@@ -13,6 +14,7 @@
1314
"InferenceClient",
1415
"InferenceServer",
1516
"InferenceTransport",
17+
"MonarchTransport",
1618
"MPTransport",
1719
"RayTransport",
1820
"ThreadingTransport",
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
from __future__ import annotations
6+
7+
from torchrl.modules.inference_server._queue_transport import (
8+
_QueueInferenceClient,
9+
QueueBasedTransport,
10+
)
11+
12+
13+
class MonarchTransport(QueueBasedTransport):
14+
"""Transport using Monarch for distributed inference on GPU clusters.
15+
16+
Uses Monarch's actor model and RDMA-capable channels for efficient
17+
cross-node communication. Monarch is imported lazily at instantiation
18+
time; importing the class itself does not require Monarch.
19+
20+
.. note::
21+
This transport requires ``monarch`` to be installed. It is designed
22+
for large-scale GPU clusters where Monarch is the preferred
23+
communication layer.
24+
25+
Keyword Args:
26+
max_queue_size (int): maximum size of the request queue.
27+
Default: ``1000``.
28+
"""
29+
30+
def __init__(self, *, max_queue_size: int = 1000):
31+
super().__init__()
32+
try:
33+
import monarch # noqa: F401
34+
from monarch.tools.queue import MonarchQueue
35+
except ImportError:
36+
raise ImportError(
37+
"Monarch is required for MonarchTransport. "
38+
"Install it following the Monarch documentation."
39+
)
40+
self._request_queue = MonarchQueue(maxsize=max_queue_size)
41+
self._response_queues: dict[int, MonarchQueue] = {}
42+
self._MonarchQueue = MonarchQueue
43+
44+
def _make_response_queue(self):
45+
return self._MonarchQueue(maxsize=1000)
46+
47+
def client(self) -> _QueueInferenceClient:
48+
"""Create an actor-side client with a dedicated response queue.
49+
50+
Returns:
51+
A :class:`_QueueInferenceClient` that can be passed to a Monarch
52+
actor.
53+
"""
54+
return super().client()

0 commit comments

Comments
 (0)