Skip to content

Commit 60f61bd

Browse files
committed
Update
[ghstack-poisoned]
1 parent 98662eb commit 60f61bd

File tree

3 files changed

+280
-0
lines changed

3 files changed

+280
-0
lines changed

test/test_inference_server.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
InferenceClient,
2020
InferenceServer,
2121
InferenceTransport,
22+
MPTransport,
2223
ThreadingTransport,
2324
)
2425

@@ -316,3 +317,84 @@ def bad_model(td):
316317
td = TensorDict({"observation": torch.randn(4)})
317318
with pytest.raises(ValueError, match="model error"):
318319
client(td)
320+
321+
322+
# =============================================================================
323+
# Tests: MPTransport (Commit 3)
324+
# =============================================================================
325+
326+
327+
def _mp_actor_fn(client, obs_size, act_size, n_requests, result_queue):
    """Issue *n_requests* inference calls from a child process and report success.

    Runs as the target of a ``multiprocessing.Process``: each iteration builds
    a random observation, sends it through *client*, and checks the reply
    shape. A single ``True`` is pushed onto *result_queue* once all requests
    have succeeded; any failed assert makes the child exit non-zero instead.
    """
    request_idx = 0
    while request_idx < n_requests:
        query = TensorDict({"observation": torch.randn(obs_size)})
        reply = client(query)
        assert "action" in reply.keys()
        assert reply["action"].shape == (act_size,)
        request_idx += 1
    result_queue.put(True)
335+
336+
337+
class TestMPTransport:
    """End-to-end tests for :class:`MPTransport` (spawn-based multiprocessing)."""

    @pytest.mark.slow
    def test_single_request_in_process(self):
        """A client created by the transport works from the parent process."""
        import multiprocessing as mp

        spawn_ctx = mp.get_context("spawn")
        transport = MPTransport(ctx=spawn_ctx)
        client = transport.client()
        policy = _make_policy()
        with InferenceServer(policy, transport, max_batch_size=4):
            query = TensorDict({"observation": torch.randn(4)})
            reply = client(query)
            assert "action" in reply.keys()
            assert reply["action"].shape == (2,)

    @pytest.mark.slow
    def test_cross_process_actors(self):
        """Actors running in separate processes each get correct results."""
        import multiprocessing as mp

        spawn_ctx = mp.get_context("spawn")
        transport = MPTransport(ctx=spawn_ctx)
        policy = _make_policy()
        n_actors = 2
        n_requests = 10

        outcome_queue = spawn_ctx.Queue()
        # Clients must exist before the children are spawned so the
        # underlying queues are inherited by the child processes.
        actor_clients = [transport.client() for _ in range(n_actors)]

        with InferenceServer(policy, transport, max_batch_size=8):
            workers = [
                spawn_ctx.Process(
                    target=_mp_actor_fn,
                    args=(actor_client, 4, 2, n_requests, outcome_queue),
                )
                for actor_client in actor_clients
            ]
            for worker in workers:
                worker.start()

            for worker in workers:
                worker.join(timeout=30.0)
                assert worker.exitcode == 0

            # Every actor should have reported success exactly once.
            for _ in range(n_actors):
                assert outcome_queue.get(timeout=1.0) is True

    @pytest.mark.slow
    def test_mp_exception_propagates(self):
        """A model exception raised server-side reaches the calling client."""
        import multiprocessing as mp

        def bad_model(td):
            raise ValueError("mp model error")

        spawn_ctx = mp.get_context("spawn")
        transport = MPTransport(ctx=spawn_ctx)
        client = transport.client()
        with InferenceServer(bad_model, transport, max_batch_size=4):
            query = TensorDict({"observation": torch.randn(4)})
            with pytest.raises(ValueError, match="mp model error"):
                client(query)

torchrl/modules/inference_server/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
55

6+
from torchrl.modules.inference_server._mp import MPTransport
67
from torchrl.modules.inference_server._server import InferenceClient, InferenceServer
78
from torchrl.modules.inference_server._threading import ThreadingTransport
89
from torchrl.modules.inference_server._transport import InferenceTransport
@@ -11,5 +12,6 @@
1112
"InferenceClient",
1213
"InferenceServer",
1314
"InferenceTransport",
15+
"MPTransport",
1416
"ThreadingTransport",
1517
]
Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
from __future__ import annotations
6+
7+
import multiprocessing as mp
8+
import queue
9+
import threading
10+
import time
11+
from typing import Any
12+
13+
from tensordict.base import TensorDictBase
14+
15+
from torchrl.modules.inference_server._transport import InferenceTransport
16+
17+
# Unique marker meaning "no result fetched yet" on an _MPFuture; compared
# with ``is`` so any real payload (including ``None``) is distinguishable.
_SENTINEL = object()
18+
19+
20+
class _MPFuture:
    """Future-like object backed by a per-actor response queue.

    The future retrieves its result by request-id so that out-of-order
    ``result()`` calls work correctly. Both results and server-side
    exceptions are cached once delivered, so repeated ``result()`` calls
    are idempotent (an exception is re-raised on every call, mirroring
    :meth:`concurrent.futures.Future.result`).

    Args:
        client: the :class:`_MPInferenceClient` that created this future.
        req_id: the unique request identifier within that client.
    """

    def __init__(self, client: _MPInferenceClient, req_id: int):
        self._client = client
        self._req_id = req_id
        # _SENTINEL marks "not fetched yet"; any other value (result or
        # exception instance) is the cached outcome of the request.
        self._outcome: Any = _SENTINEL

    def result(self, timeout: float | None = None) -> TensorDictBase:
        """Block until the result is available.

        Args:
            timeout: seconds to wait. ``None`` waits indefinitely.

        Raises:
            queue.Empty: if *timeout* expires before a result arrives
                (the future stays pending, so the call may be retried).
            Exception: if the server set an exception instead of a result.
        """
        if self._outcome is _SENTINEL:
            # A queue.Empty timeout propagates out of _get_result and leaves
            # the future pending; a delivered outcome is cached permanently.
            self._outcome = self._client._get_result(
                self._req_id, timeout=timeout
            )
        if isinstance(self._outcome, BaseException):
            # Fix: cache exceptions as well as results. Previously only a
            # successful result was stored, so a second result() call after
            # a server error re-entered _get_result for a response that had
            # already been consumed and blocked forever (or raised Empty).
            raise self._outcome
        return self._outcome
52+
53+
54+
class _MPInferenceClient:
    """Actor-side client for :class:`MPTransport`.

    Each client owns a dedicated response queue and routes results by
    request-id. Instances are created by :meth:`MPTransport.client` and
    must exist **before** child processes are spawned so that the
    underlying queues are inherited.

    Args:
        request_queue: the shared request queue.
        response_queue: this client's dedicated response queue.
        actor_id: the unique identifier assigned by the transport.
    """

    def __init__(
        self,
        request_queue: mp.Queue,
        response_queue: mp.Queue,
        actor_id: int,
    ):
        self._request_queue = request_queue
        self._response_queue = response_queue
        self._actor_id = actor_id
        self._req_counter = 0
        # Responses that arrived for a different request-id than the one
        # currently being waited on, keyed by request-id.
        self._out_of_order: dict[int, Any] = {}

    def __call__(self, td: TensorDictBase) -> TensorDictBase:
        """Submit a request and block until the result is ready."""
        future = self.submit(td)
        return future.result()

    def submit(self, td: TensorDictBase) -> _MPFuture:
        """Submit a request and return an :class:`_MPFuture`."""
        req_id = self._req_counter
        self._req_counter = req_id + 1
        self._request_queue.put((self._actor_id, req_id, td))
        return _MPFuture(self, req_id)

    # -- internal -------------------------------------------------------------

    def _get_result(self, req_id: int, timeout: float | None = None) -> Any:
        """Return the result for *req_id*, buffering any earlier arrivals."""
        try:
            return self._out_of_order.pop(req_id)
        except KeyError:
            pass
        deadline = time.monotonic() + timeout if timeout is not None else None
        while True:
            if deadline is None:
                remaining = None
            else:
                remaining = deadline - time.monotonic()
                if remaining <= 0:
                    raise queue.Empty(f"Timeout waiting for result of request {req_id}")
            rid, payload = self._response_queue.get(timeout=remaining)
            if rid == req_id:
                return payload
            # Not ours yet -- stash it for a later _get_result call.
            self._out_of_order[rid] = payload
108+
109+
110+
class MPTransport(InferenceTransport):
    """Cross-process transport using :mod:`multiprocessing` queues.

    Response routing uses per-actor queues (one per :meth:`client` call) so
    that no ``mp.Queue`` object is ever serialised through another queue.
    Clients must be created with :meth:`client` **before** spawning child
    processes.

    Args:
        ctx: a multiprocessing context (e.g. ``mp.get_context("spawn")``).
            Defaults to ``mp.get_context("spawn")``.

    Example:
        >>> import multiprocessing as mp
        >>> transport = MPTransport()
        >>> client = transport.client()  # creates response queue
        >>> p = mp.Process(target=actor_fn, args=(client,))
        >>> p.start()  # queue inherited
    """

    def __init__(self, ctx: mp.context.BaseContext | None = None):
        self._ctx = ctx if ctx is not None else mp.get_context("spawn")
        self._request_queue: mp.Queue = self._ctx.Queue()
        self._response_queues: dict[int, mp.Queue] = {}
        self._lock = threading.Lock()
        self._next_actor_id = 0
        # Server-local staging area for requests dequeued by wait_for_work()
        # and not yet consumed by drain().
        # NOTE(review): assumes wait_for_work()/drain() are only called from
        # the server process -- this list is not shared across processes.
        self._staged: list[tuple[int, int, TensorDictBase]] = []

    # -- actor API (called before fork) ---------------------------------------

    def client(self) -> _MPInferenceClient:
        """Create an actor-side client with a dedicated response queue.

        Must be called in the parent process **before** spawning children.

        Returns:
            An :class:`_MPInferenceClient` that can be passed to a child
            process as an argument to :class:`multiprocessing.Process`.
        """
        with self._lock:
            actor_id = self._next_actor_id
            self._next_actor_id += 1
            response_queue: mp.Queue = self._ctx.Queue()
            self._response_queues[actor_id] = response_queue
        return _MPInferenceClient(self._request_queue, response_queue, actor_id)

    def submit(self, td: TensorDictBase):
        """Not supported -- use :meth:`client` to obtain an actor handle."""
        raise RuntimeError(
            "MPTransport.submit() is not supported. "
            "Call transport.client() to create an _MPInferenceClient."
        )

    # -- server API -----------------------------------------------------------

    def drain(
        self, max_items: int
    ) -> tuple[list[TensorDictBase], list[tuple[int, int]]]:
        """Dequeue up to *max_items* pending ``(actor_id, req_id, td)`` tuples."""
        items: list[TensorDictBase] = []
        callbacks: list[tuple[int, int]] = []
        # Serve requests staged by wait_for_work() first so they keep their
        # original arrival order.
        while self._staged and len(items) < max_items:
            actor_id, req_id, td = self._staged.pop(0)
            items.append(td)
            callbacks.append((actor_id, req_id))
        while len(items) < max_items:
            try:
                actor_id, req_id, td = self._request_queue.get_nowait()
            except queue.Empty:
                break
            items.append(td)
            callbacks.append((actor_id, req_id))
        return items, callbacks

    def wait_for_work(self, timeout: float) -> None:
        """Block until at least one request is available or *timeout* elapses.

        The dequeued request is staged locally for the next :meth:`drain`
        call instead of being put back on the queue. Fix: ``mp.Queue.put``
        hands the item to a background feeder thread, so re-queuing it could
        make an immediately-following ``get_nowait()`` in :meth:`drain` miss
        it (spurious empty batch), and would also reorder the request behind
        later arrivals.
        """
        if self._staged:
            # Work is already available; don't block.
            return
        try:
            self._staged.append(self._request_queue.get(timeout=timeout))
        except queue.Empty:
            pass

    def resolve(self, callback: tuple[int, int], result: TensorDictBase) -> None:
        """Route the result to the correct actor's response queue."""
        actor_id, req_id = callback
        self._response_queues[actor_id].put((req_id, result))

    def resolve_exception(self, callback: tuple[int, int], exc: BaseException) -> None:
        """Route an exception to the correct actor's response queue."""
        actor_id, req_id = callback
        self._response_queues[actor_id].put((req_id, exc))

0 commit comments

Comments
 (0)