Skip to content

Commit 7abafc5

Browse files
vmoens and cursoragent committed
[Feature] Auto-batching inference server: multiprocessing transport (#3494)
Adds MPTransport using per-actor response queues for cross-process communication. Clients must be created before spawning child processes so that mp.Queue objects are inherited, not serialised. Co-authored-by: Cursor <cursoragent@cursor.com> ghstack-source-id: b22e6e7 Pull-Request: #3494 Co-authored-by: Cursor <cursoragent@cursor.com>
1 parent de70554 commit 7abafc5

File tree

4 files changed

+355
-0
lines changed

4 files changed

+355
-0
lines changed

test/test_inference_server.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
InferenceClient,
2020
InferenceServer,
2121
InferenceTransport,
22+
MPTransport,
2223
ThreadingTransport,
2324
)
2425

@@ -316,3 +317,84 @@ def bad_model(td):
316317
td = TensorDict({"observation": torch.randn(4)})
317318
with pytest.raises(ValueError, match="model error"):
318319
client(td)
320+
321+
322+
# =============================================================================
323+
# Tests: MPTransport (Commit 3)
324+
# =============================================================================
325+
326+
327+
def _mp_actor_fn(client, obs_size, act_size, n_requests, result_queue):
    """Child-process actor: issue ``n_requests`` inference calls, then report success.

    Args:
        client: a transport client created in the parent process (queues inherited).
        obs_size: size of the random observation vector sent per request.
        act_size: expected size of the returned ``"action"`` entry.
        n_requests: number of round-trips to perform.
        result_queue: multiprocessing queue; ``True`` is put once on success.
    """
    for _ in range(n_requests):
        request = TensorDict({"observation": torch.randn(obs_size)})
        response = client(request)
        assert "action" in response.keys()
        assert response["action"].shape == (act_size,)
    # One success token per actor; the parent collects n_actors of them.
    result_queue.put(True)
335+
336+
337+
class TestMPTransport:
    """Integration tests for :class:`MPTransport` (cross-process inference)."""

    @pytest.mark.slow
    def test_single_request_in_process(self):
        """MPTransport client works from the parent process."""
        import multiprocessing as mp

        ctx = mp.get_context("spawn")
        transport = MPTransport(ctx=ctx)
        client = transport.client()
        policy = _make_policy()
        with InferenceServer(policy, transport, max_batch_size=4):
            td = TensorDict({"observation": torch.randn(4)})
            result = client(td)
            assert "action" in result.keys()
            assert result["action"].shape == (2,)

    @pytest.mark.slow
    def test_cross_process_actors(self):
        """Actors in separate processes get correct results."""
        import multiprocessing as mp

        ctx = mp.get_context("spawn")
        transport = MPTransport(ctx=ctx)
        policy = _make_policy()
        n_actors = 2
        n_requests = 10

        result_queue = ctx.Queue()
        # Clients must be created before spawning so mp.Queue objects are
        # inherited by the children rather than serialised.
        clients = [transport.client() for _ in range(n_actors)]

        procs = []
        try:
            with InferenceServer(policy, transport, max_batch_size=8):
                for i in range(n_actors):
                    p = ctx.Process(
                        target=_mp_actor_fn,
                        args=(clients[i], 4, 2, n_requests, result_queue),
                    )
                    p.start()
                    procs.append(p)

                for p in procs:
                    p.join(timeout=30.0)
                    assert p.exitcode == 0

                # All actors reported success
                for _ in range(n_actors):
                    assert result_queue.get(timeout=1.0) is True
        finally:
            # Robustness: if a join timed out or an assert failed, do not
            # leak live children (they would hang the test session).
            for p in procs:
                if p.is_alive():
                    p.terminate()
                    p.join(timeout=5.0)

    @pytest.mark.slow
    def test_mp_exception_propagates(self):
        """Model exceptions propagate through MPTransport."""
        import multiprocessing as mp

        def bad_model(td):
            raise ValueError("mp model error")

        ctx = mp.get_context("spawn")
        transport = MPTransport(ctx=ctx)
        client = transport.client()
        with InferenceServer(bad_model, transport, max_batch_size=4):
            td = TensorDict({"observation": torch.randn(4)})
            with pytest.raises(ValueError, match="mp model error"):
                client(td)

torchrl/modules/inference_server/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
55

6+
from torchrl.modules.inference_server._mp import MPTransport
67
from torchrl.modules.inference_server._server import InferenceClient, InferenceServer
78
from torchrl.modules.inference_server._threading import ThreadingTransport
89
from torchrl.modules.inference_server._transport import InferenceTransport
@@ -11,5 +12,6 @@
1112
"InferenceClient",
1213
"InferenceServer",
1314
"InferenceTransport",
15+
"MPTransport",
1416
"ThreadingTransport",
1517
]
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
from __future__ import annotations
6+
7+
import multiprocessing as mp
8+
9+
from torchrl.modules.inference_server._queue_transport import (
10+
_QueueInferenceClient,
11+
QueueBasedTransport,
12+
)
13+
14+
15+
class MPTransport(QueueBasedTransport):
    """Cross-process transport built on :mod:`multiprocessing` queues.

    Results are routed through per-actor response queues — one is allocated
    for each :meth:`client` call — so an ``mp.Queue`` never has to travel
    through another queue. Because the queues are inherited rather than
    serialised, every client must be created **before** child processes are
    spawned.

    Args:
        ctx: a multiprocessing context (e.g. ``mp.get_context("spawn")``).
            Defaults to ``mp.get_context("spawn")``.

    Example:
        >>> import multiprocessing as mp
        >>> transport = MPTransport()
        >>> client = transport.client()  # creates response queue
        >>> p = mp.Process(target=actor_fn, args=(client,))
        >>> p.start()  # queue inherited
    """

    def __init__(self, ctx: mp.context.BaseContext | None = None):
        super().__init__()
        if ctx is None:
            ctx = mp.get_context("spawn")
        self._ctx = ctx
        # Shared request queue consumed by the server process/thread.
        self._request_queue: mp.Queue = self._ctx.Queue()
        # actor_id -> dedicated response queue, populated by client().
        self._response_queues: dict[int, mp.Queue] = {}

    def _make_response_queue(self) -> mp.Queue:
        # One fresh mp.Queue per actor; see class docstring for why.
        return self._ctx.Queue()

    def client(self) -> _QueueInferenceClient:
        """Create an actor-side client with a dedicated response queue.

        Must be called in the parent process **before** spawning children.

        Returns:
            A :class:`_QueueInferenceClient` that can be passed to a child
            process as an argument to :class:`multiprocessing.Process`.
        """
        return super().client()
Lines changed: 218 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,218 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
"""Shared base classes for queue-based inference transports.
6+
7+
:class:`QueueBasedTransport` factors out the common submit/drain/resolve logic
8+
used by :class:`~torchrl.modules.inference_server.MPTransport`,
9+
:class:`~torchrl.modules.inference_server.RayTransport`, and
10+
:class:`~torchrl.modules.inference_server.MonarchTransport`. Each concrete
11+
subclass only needs to supply the queue objects (request + per-actor response).
12+
"""
13+
from __future__ import annotations
14+
15+
import queue
16+
import threading
17+
import time
18+
from typing import Any
19+
20+
from tensordict.base import TensorDictBase
21+
22+
from torchrl.modules.inference_server._transport import InferenceTransport
23+
24+
_SENTINEL = object()
25+
26+
27+
class _QueueFuture:
    """Future-like object backed by a per-actor response queue.

    Results are fetched by request-id, so calling ``result()`` on futures in
    a different order than they were created still works.

    Args:
        client: the :class:`_QueueInferenceClient` that created this future.
        req_id: the unique request identifier within that client.
    """

    def __init__(self, client: _QueueInferenceClient, req_id: int):
        self._client = client
        self._req_id = req_id
        # _SENTINEL marks "not yet fetched"; the real result may legitimately
        # be any value, including None.
        self._result: Any = _SENTINEL

    def done(self) -> bool:
        """Return ``True`` if the result is available without blocking."""
        if self._result is _SENTINEL:
            try:
                self._result = self._client._get_result(self._req_id, timeout=0)
            except queue.Empty:
                return False
        return True

    def result(self, timeout: float | None = None) -> TensorDictBase:
        """Block until the result is available.

        Args:
            timeout: seconds to wait. ``None`` waits indefinitely.

        Raises:
            queue.Empty: if *timeout* expires before a result arrives.
            Exception: if the server set an exception instead of a result.
        """
        if self._result is _SENTINEL:
            self._result = self._client._get_result(self._req_id, timeout=timeout)
        outcome = self._result
        # Server-side failures travel through the queue as exception objects.
        if isinstance(outcome, BaseException):
            raise outcome
        return outcome
68+
69+
70+
class _QueueInferenceClient:
    """Actor-side client for :class:`QueueBasedTransport`.

    Each client owns a dedicated response queue and routes results by
    request-id.

    Args:
        request_queue: the shared request queue (any object with ``.put()``).
        response_queue: this client's dedicated response queue.
        actor_id: the unique identifier assigned by the transport.
    """

    def __init__(self, request_queue, response_queue, actor_id: int):
        self._request_queue = request_queue
        self._response_queue = response_queue
        self._actor_id = actor_id
        # Monotonically increasing request id, unique within this client.
        self._next_req_id = 0
        # Results that arrived out of order, keyed by request-id.
        self._buffered: dict[int, Any] = {}

    def __call__(self, td: TensorDictBase) -> TensorDictBase:
        """Submit a request and block until the result is ready."""
        return self.submit(td).result()

    def submit(self, td: TensorDictBase) -> _QueueFuture:
        """Submit a request and return a :class:`_QueueFuture`."""
        req_id = self._next_req_id
        self._next_req_id += 1
        self._request_queue.put((self._actor_id, req_id, td))
        return _QueueFuture(self, req_id)

    # -- internal -------------------------------------------------------------

    def _get_result(self, req_id: int, timeout: float | None = None) -> Any:
        """Return the result for *req_id*, buffering any earlier arrivals.

        Args:
            req_id: request identifier whose result to wait for.
            timeout: seconds to wait. ``None`` waits indefinitely; ``0``
                performs a single non-blocking poll of the response queue.

        Raises:
            queue.Empty: if *timeout* expires before the result arrives.
        """
        if req_id in self._buffered:
            return self._buffered.pop(req_id)
        deadline = None if timeout is None else time.monotonic() + timeout
        while True:
            if deadline is None:
                remaining = None
            else:
                # Clamp to zero rather than raising immediately once the
                # deadline passes: a zero timeout must still attempt one
                # non-blocking get, otherwise _QueueFuture.done() could
                # never observe a result already sitting in the queue.
                remaining = max(0.0, deadline - time.monotonic())
            try:
                rid, result = self._response_queue.get(timeout=remaining)
            except Exception as e:
                raise queue.Empty(
                    f"Timeout waiting for result of request {req_id}"
                ) from e
            if rid == req_id:
                return result
            # Out-of-order arrival for another request on this client:
            # stash it for the future that owns it.
            self._buffered[rid] = result
122+
123+
124+
class QueueBasedTransport(InferenceTransport):
    """Base class for transports that use a request queue and per-actor response queues.

    Subclasses must set the following attributes in ``__init__`` (before or
    after calling ``super().__init__()``):

    * ``_request_queue`` -- the shared request queue (any object with
      ``.put()``, ``.get(timeout=...)``, and ``.get(block=False)``).
    * ``_response_queues`` -- a ``dict[int, <queue>]`` mapping actor ids to
      per-actor response queues.

    Subclasses must implement:

    * :meth:`_make_response_queue` -- factory for creating a new response queue.

    .. note::
        ``wait_for_work`` uses a blocking ``get`` to detect new work. The
        retrieved item is stored in ``_peeked`` and consumed by the next
        ``drain`` call, preserving FIFO ordering.
    """

    def __init__(self):
        self._lock = threading.Lock()
        self._next_actor_id = 0
        # Holds a single (actor_id, req_id, td) tuple pulled out of the
        # request queue by wait_for_work(); drained first, keeping FIFO order.
        self._peeked = None

    # -- to be implemented by subclasses --------------------------------------

    def _make_response_queue(self):
        """Create a new response queue for an actor."""
        raise NotImplementedError

    # -- actor API ------------------------------------------------------------

    def client(self) -> _QueueInferenceClient:
        """Create an actor-side client with a dedicated response queue.

        Returns:
            A :class:`_QueueInferenceClient`.
        """
        with self._lock:
            actor_id = self._next_actor_id
            self._next_actor_id += 1
            response_queue = self._make_response_queue()
            self._response_queues[actor_id] = response_queue
        return _QueueInferenceClient(self._request_queue, response_queue, actor_id)

    def submit(self, td: TensorDictBase):
        """Not supported -- use :meth:`client` to obtain an actor handle."""
        raise RuntimeError(
            f"{type(self).__name__}.submit() is not supported. "
            "Call transport.client() to create a client."
        )

    # -- server API -----------------------------------------------------------

    def drain(
        self, max_items: int
    ) -> tuple[list[TensorDictBase], list[tuple[int, int]]]:
        """Dequeue up to *max_items* pending requests (non-blocking)."""
        items: list[TensorDictBase] = []
        callbacks: list[tuple[int, int]] = []
        # Consume the item buffered by wait_for_work() first.
        pending = self._peeked
        self._peeked = None
        if pending is not None:
            actor_id, req_id, td = pending
            items.append(td)
            callbacks.append((actor_id, req_id))
        while len(items) < max_items:
            try:
                actor_id, req_id, td = self._request_queue.get(block=False)
            except Exception:
                break
            items.append(td)
            callbacks.append((actor_id, req_id))
        return items, callbacks

    def wait_for_work(self, timeout: float) -> None:
        """Block until at least one request is available or *timeout* elapses."""
        if self._peeked is not None:
            return
        try:
            item = self._request_queue.get(timeout=timeout)
        except Exception:
            return
        self._peeked = item

    def resolve(self, callback: tuple[int, int], result: TensorDictBase) -> None:
        """Route the result to the correct actor's response queue."""
        self._route(callback, result)

    def resolve_exception(self, callback: tuple[int, int], exc: BaseException) -> None:
        """Route an exception to the correct actor's response queue."""
        self._route(callback, exc)

    def _route(self, callback: tuple[int, int], payload) -> None:
        """Put *payload* (result or exception) on the owning actor's queue."""
        actor_id, req_id = callback
        self._response_queues[actor_id].put((req_id, payload))

0 commit comments

Comments
 (0)