Skip to content

Commit 3f08418

Browse files
committed
Update (base update)
[ghstack-poisoned]
1 parent f2402c1 commit 3f08418

File tree

3 files changed

+273
-305
lines changed

3 files changed

+273
-305
lines changed

torchrl/modules/inference_server/_mp.py

Lines changed: 27 additions & 155 deletions
Original file line number | Diff line number | Diff line change
@@ -5,118 +5,37 @@
55
from __future__ import annotations
66

77
import multiprocessing as mp
8-
import queue
98
import threading
10-
import time
11-
from typing import Any
129

13-
from tensordict.base import TensorDictBase
10+
from torchrl.modules.inference_server._queue_transport import (
11+
_QueueInferenceClient,
12+
QueueBasedTransport,
13+
)
1414

15-
from torchrl.modules.inference_server._transport import InferenceTransport
1615

17-
_SENTINEL = object()
16+
class _MPRequestQueue:
17+
"""Wrapper around ``mp.Queue`` that signals a :class:`threading.Event` on put.
1818
19-
20-
class _MPFuture:
21-
"""Future-like object backed by a per-actor response queue.
22-
23-
The future retrieves its result by request-id so that out-of-order
24-
``result()`` calls work correctly.
25-
26-
Args:
27-
client: the :class:`_MPInferenceClient` that created this future.
28-
req_id: the unique request identifier within that client.
19+
This avoids the get-then-put anti-pattern in ``wait_for_work``: instead of
20+
consuming an item just to peek, callers wait on the event.
2921
"""
3022

31-
def __init__(self, client: _MPInferenceClient, req_id: int):
32-
self._client = client
33-
self._req_id = req_id
34-
self._result: Any = _SENTINEL
35-
36-
def done(self) -> bool:
37-
"""Return ``True`` if the result is available without blocking."""
38-
if self._result is not _SENTINEL:
39-
return True
40-
try:
41-
self._result = self._client._get_result(self._req_id, timeout=0)
42-
except queue.Empty:
43-
return False
44-
return True
45-
46-
def result(self, timeout: float | None = None) -> TensorDictBase:
47-
"""Block until the result is available.
48-
49-
Args:
50-
timeout: seconds to wait. ``None`` waits indefinitely.
51-
52-
Raises:
53-
queue.Empty: if *timeout* expires before a result arrives.
54-
Exception: if the server set an exception instead of a result.
55-
"""
56-
if self._result is _SENTINEL:
57-
self._result = self._client._get_result(self._req_id, timeout=timeout)
58-
if isinstance(self._result, BaseException):
59-
raise self._result
60-
return self._result
23+
def __init__(self, ctx: mp.context.BaseContext, has_work: threading.Event):
24+
self._queue: mp.Queue = ctx.Queue()
25+
self._has_work = has_work
6126

27+
def put(self, item):
28+
self._queue.put(item)
29+
self._has_work.set()
6230

63-
class _MPInferenceClient:
64-
"""Actor-side client for :class:`MPTransport`.
31+
def get(self, timeout=None):
32+
return self._queue.get(timeout=timeout)
6533

66-
Each client owns a dedicated response queue and routes results by
67-
request-id. Instances are created by :meth:`MPTransport.client` and
68-
must be created **before** spawning child processes so that the
69-
underlying queues are inherited.
34+
def get_nowait(self):
35+
return self._queue.get_nowait()
7036

71-
Args:
72-
request_queue: the shared request queue.
73-
response_queue: this client's dedicated response queue.
74-
actor_id: the unique identifier assigned by the transport.
75-
"""
7637

77-
def __init__(
78-
self,
79-
request_queue: mp.Queue,
80-
response_queue: mp.Queue,
81-
actor_id: int,
82-
):
83-
self._request_queue = request_queue
84-
self._response_queue = response_queue
85-
self._actor_id = actor_id
86-
self._next_req_id = 0
87-
self._buffered: dict[int, Any] = {}
88-
89-
def __call__(self, td: TensorDictBase) -> TensorDictBase:
90-
"""Submit a request and block until the result is ready."""
91-
return self.submit(td).result()
92-
93-
def submit(self, td: TensorDictBase) -> _MPFuture:
94-
"""Submit a request and return an :class:`_MPFuture`."""
95-
req_id = self._next_req_id
96-
self._next_req_id += 1
97-
self._request_queue.put((self._actor_id, req_id, td))
98-
return _MPFuture(self, req_id)
99-
100-
# -- internal -------------------------------------------------------------
101-
102-
def _get_result(self, req_id: int, timeout: float | None = None) -> Any:
103-
"""Return the result for *req_id*, buffering any earlier arrivals."""
104-
if req_id in self._buffered:
105-
return self._buffered.pop(req_id)
106-
deadline = None if timeout is None else time.monotonic() + timeout
107-
while True:
108-
remaining = None
109-
if deadline is not None:
110-
remaining = deadline - time.monotonic()
111-
if remaining <= 0:
112-
raise queue.Empty(f"Timeout waiting for result of request {req_id}")
113-
rid, result = self._response_queue.get(timeout=remaining)
114-
if rid == req_id:
115-
return result
116-
self._buffered[rid] = result
117-
118-
119-
class MPTransport(InferenceTransport):
38+
class MPTransport(QueueBasedTransport):
12039
"""Cross-process transport using :mod:`multiprocessing` queues.
12140
12241
Response routing uses per-actor queues (one per :meth:`client` call) so
@@ -137,69 +56,22 @@ class MPTransport(InferenceTransport):
13756
"""
13857

13958
def __init__(self, ctx: mp.context.BaseContext | None = None):
59+
super().__init__()
14060
self._ctx = ctx if ctx is not None else mp.get_context("spawn")
141-
self._request_queue: mp.Queue = self._ctx.Queue()
61+
self._has_work = threading.Event()
62+
self._request_queue = _MPRequestQueue(self._ctx, self._has_work)
14263
self._response_queues: dict[int, mp.Queue] = {}
143-
self._lock = threading.Lock()
144-
self._next_actor_id = 0
14564

146-
# -- actor API (called before fork) ---------------------------------------
65+
def _make_response_queue(self) -> mp.Queue:
66+
return self._ctx.Queue()
14767

148-
def client(self) -> _MPInferenceClient:
68+
def client(self) -> _QueueInferenceClient:
14969
"""Create an actor-side client with a dedicated response queue.
15070
15171
Must be called in the parent process **before** spawning children.
15272
15373
Returns:
154-
An :class:`_MPInferenceClient` that can be passed to a child
74+
A :class:`_QueueInferenceClient` that can be passed to a child
15575
process as an argument to :class:`multiprocessing.Process`.
15676
"""
157-
with self._lock:
158-
actor_id = self._next_actor_id
159-
self._next_actor_id += 1
160-
response_queue: mp.Queue = self._ctx.Queue()
161-
self._response_queues[actor_id] = response_queue
162-
return _MPInferenceClient(self._request_queue, response_queue, actor_id)
163-
164-
def submit(self, td: TensorDictBase):
165-
"""Not supported -- use :meth:`client` to obtain an actor handle."""
166-
raise RuntimeError(
167-
"MPTransport.submit() is not supported. "
168-
"Call transport.client() to create an _MPInferenceClient."
169-
)
170-
171-
# -- server API -----------------------------------------------------------
172-
173-
def drain(
174-
self, max_items: int
175-
) -> tuple[list[TensorDictBase], list[tuple[int, int]]]:
176-
"""Dequeue up to *max_items* pending ``(actor_id, req_id, td)`` tuples."""
177-
items: list[TensorDictBase] = []
178-
callbacks: list[tuple[int, int]] = []
179-
for _ in range(max_items):
180-
try:
181-
actor_id, req_id, td = self._request_queue.get_nowait()
182-
items.append(td)
183-
callbacks.append((actor_id, req_id))
184-
except queue.Empty:
185-
break
186-
return items, callbacks
187-
188-
def wait_for_work(self, timeout: float) -> None:
189-
"""Block until at least one request is available or *timeout* elapses."""
190-
try:
191-
item = self._request_queue.get(timeout=timeout)
192-
# Put it back so drain() can consume it.
193-
self._request_queue.put(item)
194-
except queue.Empty:
195-
pass
196-
197-
def resolve(self, callback: tuple[int, int], result: TensorDictBase) -> None:
198-
"""Route the result to the correct actor's response queue."""
199-
actor_id, req_id = callback
200-
self._response_queues[actor_id].put((req_id, result))
201-
202-
def resolve_exception(self, callback: tuple[int, int], exc: BaseException) -> None:
203-
"""Route an exception to the correct actor's response queue."""
204-
actor_id, req_id = callback
205-
self._response_queues[actor_id].put((req_id, exc))
77+
return super().client()

0 commit comments

Comments
 (0)