Skip to content

Commit 2d201bf

Browse files
committed
Update
[ghstack-poisoned]
2 parents 8c2309c + 6c9c982 commit 2d201bf

File tree

6 files changed

+157
-60
lines changed

6 files changed

+157
-60
lines changed

docs/source/reference/modules_inference_server.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ Transport Backends
2828
:template: rl_template_noinherit.rst
2929

3030
ThreadingTransport
31+
SlotTransport
3132
MPTransport
3233
RayTransport
3334
MonarchTransport

test/test_inference_server.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -981,3 +981,68 @@ def tracking_collate(items):
981981

982982
# At least one batch should have >= min_batch_size items
983983
assert any(s >= min_bs for s in seen_sizes)
984+
985+
986+
# =============================================================================
987+
# Tests: bugfix regressions
988+
# =============================================================================
989+
990+
991+
class TestShutdownPendingFutures:
    """Regression test: shutting the server down must not strand futures."""

    def test_shutdown_resolves_pending_futures(self):
        """Every submitted future is resolved (result or error) by shutdown.

        A future may legitimately complete with a result (the batch was
        served before shutdown) or with an exception (it was rejected while
        shutting down). What it must never do is stay pending forever.
        """
        transport = ThreadingTransport()
        policy = _make_policy()
        server = InferenceServer(policy, transport, max_batch_size=1024)
        server.start()
        futures = [
            transport.submit(TensorDict({"observation": torch.randn(4)}))
            for _ in range(5)
        ]
        time.sleep(0.05)
        server.shutdown(timeout=5.0)
        for f in futures:
            try:
                f.result(timeout=2.0)
            except concurrent.futures.TimeoutError:
                # The timeout is itself an ``Exception`` subclass, so it has
                # to be caught before the generic handler below; otherwise a
                # hung future would be reported as a pass.
                pytest.fail("Future was left unresolved after server shutdown.")
            except Exception:
                pass  # an error delivered through the future is acceptable
1009+
1010+
1011+
class TestThreadingTransportNoLostSignals:
    """Regression test for wake-up signalling in ``ThreadingTransport``."""

    def test_rapid_submit_no_lost_signals(self):
        """Each request fired concurrently from a thread pool gets an answer."""
        transport = ThreadingTransport()
        policy = _make_policy()
        num_requests = 100

        with InferenceServer(policy, transport, max_batch_size=4, timeout=0.001):
            client = transport.client()

            def _query():
                # One blocking round-trip through the inference server.
                return client(TensorDict({"observation": torch.randn(4)}))

            with concurrent.futures.ThreadPoolExecutor(max_workers=8) as pool:
                pending = [pool.submit(_query) for _ in range(num_requests)]
                results = [fut.result(timeout=10.0) for fut in pending]

            assert len(results) == num_requests
            for reply in results:
                assert "action" in reply.keys()
1030+
1031+
1032+
class TestWorkerCrashPropagation:
    """Regression test: a failing policy must surface in the consumer."""

    def test_worker_crash_propagates(self):
        """If the model always fails, the collector propagates the error."""

        def bad_model(td):
            raise RuntimeError("model crash")

        collector = AsyncBatchedCollector(
            create_env_fn=[_counting_env_factory] * 2,
            policy=bad_model,
            frames_per_batch=10,
            total_frames=100,
        )
        try:
            with pytest.raises(RuntimeError, match="worker thread"):
                for _ in collector:
                    pass
        finally:
            # Always shut the collector down, even when the expectation
            # above fails, so the collector is never left running after a
            # failed test.
            collector.shutdown()

torchrl/collectors/_async_batched.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,9 +92,9 @@ def _env_loop(
9292
if shutdown_event.is_set():
9393
break
9494
action_td = client(next_obs)
95-
except Exception:
95+
except Exception as exc:
9696
if not shutdown_event.is_set():
97-
raise
97+
result_queue.put(exc)
9898

9999

100100
class AsyncBatchedCollector(BaseCollector):
@@ -367,6 +367,14 @@ def policy(self) -> Callable:
367367
# Rollout: drain the result queue
368368
# ------------------------------------------------------------------
369369

370+
@staticmethod
def _check_worker_result(item):
    """Raise if *item* is an exception object taken from the result queue.

    Any other item is accepted silently. An exception instance is wrapped
    in a ``RuntimeError`` with the original attached as ``__cause__``.
    """
    if not isinstance(item, BaseException):
        return
    raise RuntimeError("A collector worker thread raised an exception.") from item
377+
370378
def _rollout_frames(self) -> TensorDictBase:
371379
"""Drain ``frames_per_batch`` transitions from the workers."""
372380
rq = self._result_queue
@@ -376,6 +384,7 @@ def _rollout_frames(self) -> TensorDictBase:
376384
while collected < self.frames_per_batch:
377385
# Block for at least one transition
378386
td = rq.get()
387+
self._check_worker_result(td)
379388
transitions.append(td)
380389
collected += td.numel()
381390
# Batch-drain any additional items already in the queue
@@ -399,6 +408,7 @@ def _rollout_yield_trajs(self) -> TensorDictBase:
399408

400409
while not self._trajectory_queue:
401410
td = rq.get()
411+
self._check_worker_result(td)
402412
env_id = 0
403413
eid = td.get(_ENV_IDX_KEY, default=None)
404414
if eid is not None:

torchrl/modules/inference_server/_queue_transport.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -112,8 +112,10 @@ def _get_result(self, req_id: int, timeout: float | None = None) -> Any:
112112
raise queue.Empty(f"Timeout waiting for result of request {req_id}")
113113
try:
114114
rid, result = self._response_queue.get(timeout=remaining)
115-
except Exception:
116-
raise queue.Empty(f"Timeout waiting for result of request {req_id}")
115+
except Exception as e:
116+
raise queue.Empty(
117+
f"Timeout waiting for result of request {req_id}"
118+
) from e
117119
if rid == req_id:
118120
return result
119121
self._buffered[rid] = result
@@ -135,15 +137,15 @@ class QueueBasedTransport(InferenceTransport):
135137
* :meth:`_make_response_queue` -- factory for creating a new response queue.
136138
137139
.. note::
138-
``wait_for_work`` uses a blocking ``get`` followed by ``put`` to peek
139-
at the request queue. This is safe because a single server thread
140-
calls both ``wait_for_work`` and ``drain`` sequentially -- there is no
141-
concurrent consumer that could miss the re-enqueued item.
140+
``wait_for_work`` uses a blocking ``get`` to detect new work. The
141+
retrieved item is stored in ``_peeked`` and consumed by the next
142+
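``drain`` call, preserving FIFO ordering. ``_peeked`` is read and written without holding a lock, so a single server thread must call both ``wait_for_work`` and ``drain``.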
142143
"""
143144

144145
def __init__(self):
145146
self._lock = threading.Lock()
146147
self._next_actor_id = 0
148+
self._peeked = None
147149

148150
# -- to be implemented by subclasses --------------------------------------
149151

@@ -181,7 +183,13 @@ def drain(
181183
"""Dequeue up to *max_items* pending requests (non-blocking)."""
182184
items: list[TensorDictBase] = []
183185
callbacks: list[tuple[int, int]] = []
184-
for _ in range(max_items):
186+
peeked = self._peeked
187+
if peeked is not None:
188+
self._peeked = None
189+
actor_id, req_id, td = peeked
190+
items.append(td)
191+
callbacks.append((actor_id, req_id))
192+
for _ in range(max_items - len(items)):
185193
try:
186194
actor_id, req_id, td = self._request_queue.get(block=False)
187195
except Exception:
@@ -192,10 +200,10 @@ def drain(
192200

193201
def wait_for_work(self, timeout: float) -> None:
194202
"""Block until at least one request is available or *timeout* elapses."""
203+
if self._peeked is not None:
204+
return
195205
try:
196-
item = self._request_queue.get(timeout=timeout)
197-
# Put it back so drain() can consume it.
198-
self._request_queue.put(item)
206+
self._peeked = self._request_queue.get(timeout=timeout)
199207
except Exception:
200208
pass
201209

torchrl/modules/inference_server/_server.py

Lines changed: 51 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -160,47 +160,59 @@ def _poll_weight_update(self) -> None:
160160
def _run(self) -> None:
161161
self._init_weight_sync()
162162

163-
while not self._shutdown_event.is_set():
164-
# Poll for weight updates between batches (non-blocking)
165-
self._poll_weight_update()
166-
167-
self.transport.wait_for_work(timeout=self.timeout)
168-
163+
try:
164+
while not self._shutdown_event.is_set():
165+
self._poll_weight_update()
166+
167+
self.transport.wait_for_work(timeout=self.timeout)
168+
169+
items, callbacks = self.transport.drain(self.max_batch_size)
170+
if not items:
171+
continue
172+
173+
# Accumulate up to min_batch_size (or until timeout expires)
174+
if len(items) < self.min_batch_size:
175+
deadline = time.monotonic() + self.timeout
176+
while len(items) < self.min_batch_size:
177+
remaining = deadline - time.monotonic()
178+
if remaining <= 0:
179+
break
180+
self.transport.wait_for_work(timeout=remaining)
181+
more_items, more_cbs = self.transport.drain(
182+
self.max_batch_size - len(items)
183+
)
184+
items.extend(more_items)
185+
callbacks.extend(more_cbs)
186+
187+
batch = self.collate_fn(items)
188+
if self.device is not None:
189+
batch = batch.to(self.device)
190+
191+
try:
192+
with self._model_lock:
193+
results = self.model(batch).unbind(0)
194+
if len(results) != len(callbacks):
195+
raise RuntimeError(
196+
f"Model returned {len(results)} results for a "
197+
f"batch of {len(callbacks)} inputs."
198+
)
199+
for cb, res in zip(callbacks, results):
200+
self.transport.resolve(cb, res)
201+
except Exception as exc:
202+
for cb in callbacks:
203+
self.transport.resolve_exception(cb, exc)
204+
finally:
205+
self._drain_pending_on_shutdown()
206+
207+
def _drain_pending_on_shutdown(self) -> None:
208+
"""Resolve all pending requests with an error during shutdown."""
209+
shutdown_exc = RuntimeError("InferenceServer is shutting down.")
210+
while True:
169211
items, callbacks = self.transport.drain(self.max_batch_size)
170212
if not items:
171-
continue
172-
173-
# Accumulate up to min_batch_size (or until timeout expires)
174-
if len(items) < self.min_batch_size:
175-
deadline = time.monotonic() + self.timeout
176-
while len(items) < self.min_batch_size:
177-
remaining = deadline - time.monotonic()
178-
if remaining <= 0:
179-
break
180-
self.transport.wait_for_work(timeout=remaining)
181-
more_items, more_cbs = self.transport.drain(
182-
self.max_batch_size - len(items)
183-
)
184-
items.extend(more_items)
185-
callbacks.extend(more_cbs)
186-
187-
batch = self.collate_fn(items)
188-
if self.device is not None:
189-
batch = batch.to(self.device)
190-
191-
try:
192-
with self._model_lock:
193-
results = self.model(batch).unbind(0)
194-
if len(results) != len(callbacks):
195-
raise RuntimeError(
196-
f"Model returned {len(results)} results for a "
197-
f"batch of {len(callbacks)} inputs."
198-
)
199-
for cb, res in zip(callbacks, results):
200-
self.transport.resolve(cb, res)
201-
except Exception as exc:
202-
for cb in callbacks:
203-
self.transport.resolve_exception(cb, exc)
213+
break
214+
for cb in callbacks:
215+
self.transport.resolve_exception(cb, shutdown_exc)
204216

205217
# -- context manager ------------------------------------------------------
206218

torchrl/modules/inference_server/_threading.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,9 @@
1515
class ThreadingTransport(InferenceTransport):
1616
"""In-process transport for actors that are threads.
1717
18-
Uses a shared list protected by a :class:`threading.Lock` as the request
19-
queue and :class:`~concurrent.futures.Future` objects for response routing.
18+
Uses a shared list protected by a :class:`threading.Condition` as the
19+
request queue and :class:`~concurrent.futures.Future` objects for response
20+
routing.
2021
2122
This is the simplest backend and is appropriate when all actors live in the
2223
same process (e.g. running in a :class:`~concurrent.futures.ThreadPoolExecutor`).
@@ -25,21 +26,20 @@ class ThreadingTransport(InferenceTransport):
2526
def __init__(self):
2627
self._queue: list[TensorDictBase] = []
2728
self._futures: list[Future] = []
28-
self._lock = threading.Lock()
29-
self._event = threading.Event()
29+
self._cond = threading.Condition(threading.Lock())
3030

3131
def submit(self, td: TensorDictBase) -> Future[TensorDictBase]:
3232
"""Enqueue a request and return a Future for the result."""
3333
fut: Future[TensorDictBase] = Future()
34-
with self._lock:
34+
with self._cond:
3535
self._queue.append(td)
3636
self._futures.append(fut)
37-
self._event.set()
37+
self._cond.notify()
3838
return fut
3939

4040
def drain(self, max_items: int) -> tuple[list[TensorDictBase], list[Future]]:
4141
"""Dequeue up to *max_items* pending requests."""
42-
with self._lock:
42+
with self._cond:
4343
n = min(len(self._queue), max_items)
4444
items = self._queue[:n]
4545
futs = self._futures[:n]
@@ -49,8 +49,9 @@ def drain(self, max_items: int) -> tuple[list[TensorDictBase], list[Future]]:
4949

5050
def wait_for_work(self, timeout: float) -> None:
5151
"""Block until at least one request is enqueued or *timeout* elapses."""
52-
self._event.wait(timeout=timeout)
53-
self._event.clear()
52+
with self._cond:
53+
if not self._queue:
54+
self._cond.wait(timeout=timeout)
5455

5556
def resolve(self, callback: Future, result: TensorDictBase) -> None:
5657
"""Set the result on the actor's Future."""

0 commit comments

Comments
 (0)