# =============================================================================
# Tests: MonarchTransport (Commit 5)
# =============================================================================


@pytest.mark.skipif(not _has_monarch, reason="monarch not installed")
class TestMonarchTransport:
    """End-to-end tests for :class:`MonarchTransport`.

    Only runs when ``monarch`` is installed; mirrors the Ray/MP transport
    tests elsewhere in this file.
    """

    def test_single_request(self):
        """A single client round-trips one request through the server."""
        transport = MonarchTransport()
        client = transport.client()
        policy = _make_policy()
        with InferenceServer(policy, transport, max_batch_size=4):
            td = TensorDict({"observation": torch.randn(4)})
            result = client(td)
            assert "action" in result.keys()
            assert result["action"].shape == (2,)

    def test_concurrent_clients(self):
        """Multiple Monarch clients submit concurrently."""
        transport = MonarchTransport()
        policy = _make_policy()
        n_clients = 4
        n_requests = 20

        clients = [transport.client() for _ in range(n_clients)]
        results_per_client: list[list[TensorDictBase]] = [[] for _ in range(n_clients)]

        def client_fn(client_idx):
            # Each worker thread drives its own client and appends only to
            # its own result slot, so threads never mutate a shared list.
            for _ in range(n_requests):
                td = TensorDict({"observation": torch.randn(4)})
                result = clients[client_idx](td)
                results_per_client[client_idx].append(result)

        with InferenceServer(policy, transport, max_batch_size=8):
            with concurrent.futures.ThreadPoolExecutor(max_workers=n_clients) as pool:
                futs = [pool.submit(client_fn, i) for i in range(n_clients)]
                concurrent.futures.wait(futs)
                # Re-raise any exception that occurred inside a worker;
                # wait() alone would silently swallow failures.
                for f in futs:
                    f.result()

        for client_results in results_per_client:
            assert len(client_results) == n_requests
            for r in client_results:
                assert "action" in r.keys()
                assert r["action"].shape == (2,)


class TestMonarchTransportImport:
    """Import-time behavior of MonarchTransport without monarch installed."""

    def test_import_without_monarch(self):
        """MonarchTransport class can be imported even without monarch."""
        # This verifies the lazy-import pattern: the class itself is
        # importable; only instantiation requires monarch.
        assert MonarchTransport is not None

    @pytest.mark.skipif(_has_monarch, reason="test requires monarch NOT installed")
    def test_instantiation_without_monarch_raises(self):
        """Instantiating without monarch raises a descriptive ImportError."""
        with pytest.raises(ImportError, match="Monarch is required"):
            MonarchTransport()
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

from torchrl.modules.inference_server._queue_transport import (
    _QueueInferenceClient,
    QueueBasedTransport,
)


class MonarchTransport(QueueBasedTransport):
    """Transport using Monarch for distributed inference on GPU clusters.

    Uses Monarch's actor model and RDMA-capable channels for efficient
    cross-node communication. Monarch is imported lazily at instantiation
    time; importing the class itself does not require Monarch.

    .. note::
        This transport requires ``monarch`` to be installed. It is designed
        for large-scale GPU clusters where Monarch is the preferred
        communication layer.

    Keyword Args:
        max_queue_size (int): maximum size of the request queue.
            Default: ``1000``.

    Raises:
        ImportError: on instantiation, if ``monarch`` is not installed.
    """

    def __init__(self, *, max_queue_size: int = 1000):
        # Validate the optional dependency *before* touching base-class
        # state, so a missing monarch never leaves a half-initialized object.
        try:
            import monarch  # noqa: F401
            from monarch.tools.queue import MonarchQueue
        except ImportError as err:
            # Chain the original error so the user can see which sub-import
            # actually failed (monarch itself vs. monarch.tools.queue).
            raise ImportError(
                "Monarch is required for MonarchTransport. "
                "Install it following the Monarch documentation."
            ) from err
        super().__init__()
        self._request_queue = MonarchQueue(maxsize=max_queue_size)
        # Per-client response queues. Presumably keyed by client id and
        # managed by QueueBasedTransport machinery — TODO confirm against
        # _queue_transport (not visible from this file).
        self._response_queues: dict[int, MonarchQueue] = {}
        # Keep the lazily imported class around so later calls (e.g.
        # _make_response_queue) do not re-import monarch.
        self._MonarchQueue = MonarchQueue

    def _make_response_queue(self):
        # NOTE(review): maxsize=1000 mirrors the original; a response queue
        # likely needs far less capacity, but the bound is kept unchanged.
        return self._MonarchQueue(maxsize=1000)

    def client(self) -> _QueueInferenceClient:
        """Create an actor-side client with a dedicated response queue.

        Returns:
            A :class:`_QueueInferenceClient` that can be passed to a Monarch
            actor.
        """
        return super().client()