Skip to content

Commit db3d441

Browse files
committed
Update (base update)
[ghstack-poisoned]
1 parent 34c13c7 commit db3d441

File tree

2 files changed: +19 additions, −34 deletions

torchrl/modules/inference_server/_mp.py

Lines changed: 1 addition & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -5,36 +5,13 @@
55
from __future__ import annotations
66

77
import multiprocessing as mp
8-
import threading
98

109
from torchrl.modules.inference_server._queue_transport import (
1110
_QueueInferenceClient,
1211
QueueBasedTransport,
1312
)
1413

1514

16-
class _MPRequestQueue:
17-
"""Wrapper around ``mp.Queue`` that signals a :class:`threading.Event` on put.
18-
19-
This avoids the get-then-put anti-pattern in ``wait_for_work``: instead of
20-
consuming an item just to peek, callers wait on the event.
21-
"""
22-
23-
def __init__(self, ctx: mp.context.BaseContext, has_work: threading.Event):
24-
self._queue: mp.Queue = ctx.Queue()
25-
self._has_work = has_work
26-
27-
def put(self, item):
28-
self._queue.put(item)
29-
self._has_work.set()
30-
31-
def get(self, timeout=None):
32-
return self._queue.get(timeout=timeout)
33-
34-
def get_nowait(self):
35-
return self._queue.get_nowait()
36-
37-
3815
class MPTransport(QueueBasedTransport):
3916
"""Cross-process transport using :mod:`multiprocessing` queues.
4017
@@ -58,8 +35,7 @@ class MPTransport(QueueBasedTransport):
5835
def __init__(self, ctx: mp.context.BaseContext | None = None):
5936
super().__init__()
6037
self._ctx = ctx if ctx is not None else mp.get_context("spawn")
61-
self._has_work = threading.Event()
62-
self._request_queue = _MPRequestQueue(self._ctx, self._has_work)
38+
self._request_queue: mp.Queue = self._ctx.Queue()
6339
self._response_queues: dict[int, mp.Queue] = {}
6440

6541
def _make_response_queue(self) -> mp.Queue:

torchrl/modules/inference_server/_queue_transport.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ class _QueueInferenceClient:
7474
request-id.
7575
7676
Args:
77-
request_queue: the shared request queue.
77+
request_queue: the shared request queue (any object with ``.put()``).
7878
response_queue: this client's dedicated response queue.
7979
actor_id: the unique identifier assigned by the transport.
8080
"""
@@ -122,18 +122,23 @@ def _get_result(self, req_id: int, timeout: float | None = None) -> Any:
122122
class QueueBasedTransport(InferenceTransport):
123123
"""Base class for transports that use a request queue and per-actor response queues.
124124
125-
Subclasses must set the following attributes before calling ``super().__init__()``:
125+
Subclasses must set the following attributes in ``__init__`` (before or
126+
after calling ``super().__init__()``):
126127
127-
* ``_request_queue`` -- the shared request queue (any object with ``.put()``,
128-
``.get(timeout=...)``, and ``.get_nowait()`` / ``.get(block=False)``).
128+
* ``_request_queue`` -- the shared request queue (any object with
129+
``.put()``, ``.get(timeout=...)``, and ``.get(block=False)``).
129130
* ``_response_queues`` -- a ``dict[int, <queue>]`` mapping actor ids to
130131
per-actor response queues.
131-
* ``_has_work`` -- a :class:`threading.Event` that is set whenever a new
132-
request is enqueued (used for non-blocking ``wait_for_work``).
133132
134133
Subclasses must implement:
135134
136135
* :meth:`_make_response_queue` -- factory for creating a new response queue.
136+
137+
.. note::
138+
``wait_for_work`` uses a blocking ``get`` followed by ``put`` to peek
139+
at the request queue. This is safe because a single server thread
140+
calls both ``wait_for_work`` and ``drain`` sequentially -- there is no
141+
concurrent consumer that could miss the re-enqueued item.
137142
"""
138143

139144
def __init__(self):
@@ -178,7 +183,7 @@ def drain(
178183
callbacks: list[tuple[int, int]] = []
179184
for _ in range(max_items):
180185
try:
181-
actor_id, req_id, td = self._request_queue.get_nowait()
186+
actor_id, req_id, td = self._request_queue.get(block=False)
182187
except Exception:
183188
break
184189
items.append(td)
@@ -187,8 +192,12 @@ def drain(
187192

188193
def wait_for_work(self, timeout: float) -> None:
189194
"""Block until at least one request is available or *timeout* elapses."""
190-
self._has_work.wait(timeout=timeout)
191-
self._has_work.clear()
195+
try:
196+
item = self._request_queue.get(timeout=timeout)
197+
# Put it back so drain() can consume it.
198+
self._request_queue.put(item)
199+
except Exception:
200+
pass
192201

193202
def resolve(self, callback: tuple[int, int], result: TensorDictBase) -> None:
194203
"""Route the result to the correct actor's response queue."""

0 commit comments

Comments (0)