Skip to content

Commit c930e3e

Browse files
committed
Update (base update)
[ghstack-poisoned]
1 parent 2a43135 commit c930e3e

File tree

3 files changed: 62 additions and 40 deletions.

torchrl/modules/inference_server/_queue_transport.py

Lines changed: 18 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -112,8 +112,10 @@ def _get_result(self, req_id: int, timeout: float | None = None) -> Any:
112112
raise queue.Empty(f"Timeout waiting for result of request {req_id}")
113113
try:
114114
rid, result = self._response_queue.get(timeout=remaining)
115-
except Exception:
116-
raise queue.Empty(f"Timeout waiting for result of request {req_id}")
115+
except Exception as e:
116+
raise queue.Empty(
117+
f"Timeout waiting for result of request {req_id}"
118+
) from e
117119
if rid == req_id:
118120
return result
119121
self._buffered[rid] = result
@@ -135,15 +137,15 @@ class QueueBasedTransport(InferenceTransport):
135137
* :meth:`_make_response_queue` -- factory for creating a new response queue.
136138
137139
.. note::
138-
``wait_for_work`` uses a blocking ``get`` followed by ``put`` to peek
139-
at the request queue. This is safe because a single server thread
140-
calls both ``wait_for_work`` and ``drain`` sequentially -- there is no
141-
concurrent consumer that could miss the re-enqueued item.
140+
``wait_for_work`` uses a blocking ``get`` to detect new work. The
141+
retrieved item is stored in ``_peeked`` and consumed by the next
142+
``drain`` call, preserving FIFO ordering.
142143
"""
143144

144145
def __init__(self):
145146
self._lock = threading.Lock()
146147
self._next_actor_id = 0
148+
self._peeked = None
147149

148150
# -- to be implemented by subclasses --------------------------------------
149151

@@ -181,7 +183,13 @@ def drain(
181183
"""Dequeue up to *max_items* pending requests (non-blocking)."""
182184
items: list[TensorDictBase] = []
183185
callbacks: list[tuple[int, int]] = []
184-
for _ in range(max_items):
186+
peeked = self._peeked
187+
if peeked is not None:
188+
self._peeked = None
189+
actor_id, req_id, td = peeked
190+
items.append(td)
191+
callbacks.append((actor_id, req_id))
192+
for _ in range(max_items - len(items)):
185193
try:
186194
actor_id, req_id, td = self._request_queue.get(block=False)
187195
except Exception:
@@ -192,10 +200,10 @@ def drain(
192200

193201
def wait_for_work(self, timeout: float) -> None:
194202
"""Block until at least one request is available or *timeout* elapses."""
203+
if self._peeked is not None:
204+
return
195205
try:
196-
item = self._request_queue.get(timeout=timeout)
197-
# Put it back so drain() can consume it.
198-
self._request_queue.put(item)
206+
self._peeked = self._request_queue.get(timeout=timeout)
199207
except Exception:
200208
pass
201209

torchrl/modules/inference_server/_server.py

Lines changed: 34 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -121,29 +121,42 @@ def is_alive(self) -> bool:
121121

122122
@torch.no_grad()
123123
def _run(self) -> None:
124-
while not self._shutdown_event.is_set():
125-
self.transport.wait_for_work(timeout=self.timeout)
126-
124+
try:
125+
while not self._shutdown_event.is_set():
126+
self.transport.wait_for_work(timeout=self.timeout)
127+
128+
items, callbacks = self.transport.drain(self.max_batch_size)
129+
if not items:
130+
continue
131+
132+
batch = self.collate_fn(items)
133+
if self.device is not None:
134+
batch = batch.to(self.device)
135+
136+
try:
137+
results = self.model(batch).unbind(0)
138+
if len(results) != len(callbacks):
139+
raise RuntimeError(
140+
f"Model returned {len(results)} results for a "
141+
f"batch of {len(callbacks)} inputs."
142+
)
143+
for cb, res in zip(callbacks, results):
144+
self.transport.resolve(cb, res)
145+
except Exception as exc:
146+
for cb in callbacks:
147+
self.transport.resolve_exception(cb, exc)
148+
finally:
149+
self._drain_pending_on_shutdown()
150+
151+
def _drain_pending_on_shutdown(self) -> None:
152+
"""Resolve all pending requests with an error during shutdown."""
153+
shutdown_exc = RuntimeError("InferenceServer is shutting down.")
154+
while True:
127155
items, callbacks = self.transport.drain(self.max_batch_size)
128156
if not items:
129-
continue
130-
131-
batch = self.collate_fn(items)
132-
if self.device is not None:
133-
batch = batch.to(self.device)
134-
135-
try:
136-
results = self.model(batch).unbind(0)
137-
if len(results) != len(callbacks):
138-
raise RuntimeError(
139-
f"Model returned {len(results)} results for a "
140-
f"batch of {len(callbacks)} inputs."
141-
)
142-
for cb, res in zip(callbacks, results):
143-
self.transport.resolve(cb, res)
144-
except Exception as exc:
145-
for cb in callbacks:
146-
self.transport.resolve_exception(cb, exc)
157+
break
158+
for cb in callbacks:
159+
self.transport.resolve_exception(cb, shutdown_exc)
147160

148161
# -- context manager ------------------------------------------------------
149162

torchrl/modules/inference_server/_threading.py

Lines changed: 10 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -15,8 +15,9 @@
1515
class ThreadingTransport(InferenceTransport):
1616
"""In-process transport for actors that are threads.
1717
18-
Uses a shared list protected by a :class:`threading.Lock` as the request
19-
queue and :class:`~concurrent.futures.Future` objects for response routing.
18+
Uses a shared list protected by a :class:`threading.Condition` as the
19+
request queue and :class:`~concurrent.futures.Future` objects for response
20+
routing.
2021
2122
This is the simplest backend and is appropriate when all actors live in the
2223
same process (e.g. running in a :class:`~concurrent.futures.ThreadPoolExecutor`).
@@ -25,21 +26,20 @@ class ThreadingTransport(InferenceTransport):
2526
def __init__(self):
2627
self._queue: list[TensorDictBase] = []
2728
self._futures: list[Future] = []
28-
self._lock = threading.Lock()
29-
self._event = threading.Event()
29+
self._cond = threading.Condition(threading.Lock())
3030

3131
def submit(self, td: TensorDictBase) -> Future[TensorDictBase]:
3232
"""Enqueue a request and return a Future for the result."""
3333
fut: Future[TensorDictBase] = Future()
34-
with self._lock:
34+
with self._cond:
3535
self._queue.append(td)
3636
self._futures.append(fut)
37-
self._event.set()
37+
self._cond.notify()
3838
return fut
3939

4040
def drain(self, max_items: int) -> tuple[list[TensorDictBase], list[Future]]:
4141
"""Dequeue up to *max_items* pending requests."""
42-
with self._lock:
42+
with self._cond:
4343
n = min(len(self._queue), max_items)
4444
items = self._queue[:n]
4545
futs = self._futures[:n]
@@ -49,8 +49,9 @@ def drain(self, max_items: int) -> tuple[list[TensorDictBase], list[Future]]:
4949

5050
def wait_for_work(self, timeout: float) -> None:
5151
"""Block until at least one request is enqueued or *timeout* elapses."""
52-
self._event.wait(timeout=timeout)
53-
self._event.clear()
52+
with self._cond:
53+
if not self._queue:
54+
self._cond.wait(timeout=timeout)
5455

5556
def resolve(self, callback: Future, result: TensorDictBase) -> None:
5657
"""Set the result on the actor's Future."""

0 commit comments

Comments
 (0)