Skip to content

Commit 8c2309c

Browse files
committed
Update
[ghstack-poisoned]
2 parents 88ff619 + 079767e commit 8c2309c

File tree

4 files changed

+27
-122
lines changed

4 files changed

+27
-122
lines changed

torchrl/modules/inference_server/_monarch.py

Lines changed: 4 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -4,49 +4,12 @@
44
# LICENSE file in the root directory of this source tree.
55
from __future__ import annotations
66

7-
import threading
8-
97
from torchrl.modules.inference_server._queue_transport import (
108
_QueueInferenceClient,
119
QueueBasedTransport,
1210
)
1311

1412

15-
class _MonarchRequestQueue:
16-
"""Wrapper around ``MonarchQueue`` that signals a :class:`threading.Event` on put.
17-
18-
Also adapts the Monarch queue API (``get(block=False)``) to the standard
19-
``get_nowait()`` expected by :class:`QueueBasedTransport`.
20-
"""
21-
22-
def __init__(self, monarch_queue, has_work: threading.Event):
23-
self._queue = monarch_queue
24-
self._has_work = has_work
25-
26-
def put(self, item):
27-
self._queue.put(item)
28-
self._has_work.set()
29-
30-
def get(self, timeout=None):
31-
return self._queue.get(timeout=timeout)
32-
33-
def get_nowait(self):
34-
return self._queue.get(block=False)
35-
36-
37-
class _MonarchResponseQueue:
38-
"""Thin wrapper adapting the MonarchQueue get API."""
39-
40-
def __init__(self, monarch_queue):
41-
self._queue = monarch_queue
42-
43-
def put(self, item):
44-
self._queue.put(item)
45-
46-
def get(self, timeout=None):
47-
return self._queue.get(timeout=timeout)
48-
49-
5013
class MonarchTransport(QueueBasedTransport):
5114
"""Transport using Monarch for distributed inference on GPU clusters.
5215
@@ -74,15 +37,12 @@ def __init__(self, *, max_queue_size: int = 1000):
7437
"Monarch is required for MonarchTransport. "
7538
"Install it following the Monarch documentation."
7639
)
77-
self._has_work = threading.Event()
78-
self._request_queue = _MonarchRequestQueue(
79-
MonarchQueue(maxsize=max_queue_size), self._has_work
80-
)
81-
self._response_queues: dict[int, _MonarchResponseQueue] = {}
40+
self._request_queue = MonarchQueue(maxsize=max_queue_size)
41+
self._response_queues: dict[int, MonarchQueue] = {}
8242
self._MonarchQueue = MonarchQueue
8343

84-
def _make_response_queue(self) -> _MonarchResponseQueue:
85-
return _MonarchResponseQueue(self._MonarchQueue(maxsize=1000))
44+
def _make_response_queue(self):
45+
return self._MonarchQueue(maxsize=1000)
8646

8747
def client(self) -> _QueueInferenceClient:
8848
"""Create an actor-side client with a dedicated response queue.

torchrl/modules/inference_server/_mp.py

Lines changed: 1 addition & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -5,36 +5,13 @@
55
from __future__ import annotations
66

77
import multiprocessing as mp
8-
import threading
98

109
from torchrl.modules.inference_server._queue_transport import (
1110
_QueueInferenceClient,
1211
QueueBasedTransport,
1312
)
1413

1514

16-
class _MPRequestQueue:
17-
"""Wrapper around ``mp.Queue`` that signals a :class:`threading.Event` on put.
18-
19-
This avoids the get-then-put anti-pattern in ``wait_for_work``: instead of
20-
consuming an item just to peek, callers wait on the event.
21-
"""
22-
23-
def __init__(self, ctx: mp.context.BaseContext, has_work: threading.Event):
24-
self._queue: mp.Queue = ctx.Queue()
25-
self._has_work = has_work
26-
27-
def put(self, item):
28-
self._queue.put(item)
29-
self._has_work.set()
30-
31-
def get(self, timeout=None):
32-
return self._queue.get(timeout=timeout)
33-
34-
def get_nowait(self):
35-
return self._queue.get_nowait()
36-
37-
3815
class MPTransport(QueueBasedTransport):
3916
"""Cross-process transport using :mod:`multiprocessing` queues.
4017
@@ -58,8 +35,7 @@ class MPTransport(QueueBasedTransport):
5835
def __init__(self, ctx: mp.context.BaseContext | None = None):
5936
super().__init__()
6037
self._ctx = ctx if ctx is not None else mp.get_context("spawn")
61-
self._has_work = threading.Event()
62-
self._request_queue = _MPRequestQueue(self._ctx, self._has_work)
38+
self._request_queue: mp.Queue = self._ctx.Queue()
6339
self._response_queues: dict[int, mp.Queue] = {}
6440

6541
def _make_response_queue(self) -> mp.Queue:

torchrl/modules/inference_server/_queue_transport.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ class _QueueInferenceClient:
7474
request-id.
7575
7676
Args:
77-
request_queue: the shared request queue.
77+
request_queue: the shared request queue (any object with ``.put()``).
7878
response_queue: this client's dedicated response queue.
7979
actor_id: the unique identifier assigned by the transport.
8080
"""
@@ -122,18 +122,23 @@ def _get_result(self, req_id: int, timeout: float | None = None) -> Any:
122122
class QueueBasedTransport(InferenceTransport):
123123
"""Base class for transports that use a request queue and per-actor response queues.
124124
125-
Subclasses must set the following attributes before calling ``super().__init__()``:
125+
Subclasses must set the following attributes in ``__init__`` (before or
126+
after calling ``super().__init__()``):
126127
127-
* ``_request_queue`` -- the shared request queue (any object with ``.put()``,
128-
``.get(timeout=...)``, and ``.get_nowait()`` / ``.get(block=False)``).
128+
* ``_request_queue`` -- the shared request queue (any object with
129+
``.put()``, ``.get(timeout=...)``, and ``.get(block=False)``).
129130
* ``_response_queues`` -- a ``dict[int, <queue>]`` mapping actor ids to
130131
per-actor response queues.
131-
* ``_has_work`` -- a :class:`threading.Event` that is set whenever a new
132-
request is enqueued (used for non-blocking ``wait_for_work``).
133132
134133
Subclasses must implement:
135134
136135
* :meth:`_make_response_queue` -- factory for creating a new response queue.
136+
137+
.. note::
138+
``wait_for_work`` uses a blocking ``get`` followed by ``put`` to peek
139+
at the request queue. This is safe because a single server thread
140+
calls both ``wait_for_work`` and ``drain`` sequentially -- there is no
141+
concurrent consumer that could miss the re-enqueued item.
137142
"""
138143

139144
def __init__(self):
@@ -178,7 +183,7 @@ def drain(
178183
callbacks: list[tuple[int, int]] = []
179184
for _ in range(max_items):
180185
try:
181-
actor_id, req_id, td = self._request_queue.get_nowait()
186+
actor_id, req_id, td = self._request_queue.get(block=False)
182187
except Exception:
183188
break
184189
items.append(td)
@@ -187,8 +192,12 @@ def drain(
187192

188193
def wait_for_work(self, timeout: float) -> None:
189194
"""Block until at least one request is available or *timeout* elapses."""
190-
self._has_work.wait(timeout=timeout)
191-
self._has_work.clear()
195+
try:
196+
item = self._request_queue.get(timeout=timeout)
197+
# Put it back so drain() can consume it.
198+
self._request_queue.put(item)
199+
except Exception:
200+
pass
192201

193202
def resolve(self, callback: tuple[int, int], result: TensorDictBase) -> None:
194203
"""Route the result to the correct actor's response queue."""

torchrl/modules/inference_server/_ray.py

Lines changed: 4 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -4,49 +4,12 @@
44
# LICENSE file in the root directory of this source tree.
55
from __future__ import annotations
66

7-
import threading
8-
97
from torchrl.modules.inference_server._queue_transport import (
108
_QueueInferenceClient,
119
QueueBasedTransport,
1210
)
1311

1412

15-
class _RayRequestQueue:
16-
"""Wrapper around ``ray.util.queue.Queue`` that signals a :class:`threading.Event` on put.
17-
18-
Also adapts the Ray queue API (``get(block=False)``) to the standard
19-
``get_nowait()`` expected by :class:`QueueBasedTransport`.
20-
"""
21-
22-
def __init__(self, ray_queue, has_work: threading.Event):
23-
self._queue = ray_queue
24-
self._has_work = has_work
25-
26-
def put(self, item):
27-
self._queue.put(item)
28-
self._has_work.set()
29-
30-
def get(self, timeout=None):
31-
return self._queue.get(timeout=timeout)
32-
33-
def get_nowait(self):
34-
return self._queue.get(block=False)
35-
36-
37-
class _RayResponseQueue:
38-
"""Thin wrapper around ``ray.util.queue.Queue`` that adapts the get API."""
39-
40-
def __init__(self, ray_queue):
41-
self._queue = ray_queue
42-
43-
def put(self, item):
44-
self._queue.put(item)
45-
46-
def get(self, timeout=None):
47-
return self._queue.get(timeout=timeout)
48-
49-
5013
class RayTransport(QueueBasedTransport):
5114
"""Transport using Ray queues for distributed inference.
5215
@@ -77,15 +40,12 @@ def __init__(self, *, max_queue_size: int = 1000):
7740
raise ImportError(
7841
"Ray is required for RayTransport. Install it with: pip install ray"
7942
)
80-
self._has_work = threading.Event()
81-
self._request_queue = _RayRequestQueue(
82-
ray.util.queue.Queue(maxsize=max_queue_size), self._has_work
83-
)
84-
self._response_queues: dict[int, _RayResponseQueue] = {}
43+
self._request_queue = ray.util.queue.Queue(maxsize=max_queue_size)
44+
self._response_queues: dict[int, ray.util.queue.Queue] = {}
8545
self._ray_queue_module = ray.util.queue
8646

87-
def _make_response_queue(self) -> _RayResponseQueue:
88-
return _RayResponseQueue(self._ray_queue_module.Queue(maxsize=1000))
47+
def _make_response_queue(self):
48+
return self._ray_queue_module.Queue(maxsize=1000)
8949

9050
def client(self) -> _QueueInferenceClient:
9151
"""Create an actor-side client with a dedicated Ray response queue.

0 commit comments

Comments
 (0)