Skip to content

Commit 579feae

Browse files
committed
Update
[ghstack-poisoned]
1 parent 3b86bdc commit 579feae

File tree

1 file changed

+34
-21
lines changed

1 file changed

+34
-21
lines changed

torchrl/modules/inference_server/_server.py

Lines changed: 34 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -121,29 +121,42 @@ def is_alive(self) -> bool:
121 121

122 122
@torch.no_grad()
123 123
def _run(self) -> None:
124-
while not self._shutdown_event.is_set():
125-
self.transport.wait_for_work(timeout=self.timeout)
126-
124+
try:
125+
while not self._shutdown_event.is_set():
126+
self.transport.wait_for_work(timeout=self.timeout)
127+
128+
items, callbacks = self.transport.drain(self.max_batch_size)
129+
if not items:
130+
continue
131+
132+
batch = self.collate_fn(items)
133+
if self.device is not None:
134+
batch = batch.to(self.device)
135+
136+
try:
137+
results = self.model(batch).unbind(0)
138+
if len(results) != len(callbacks):
139+
raise RuntimeError(
140+
f"Model returned {len(results)} results for a "
141+
f"batch of {len(callbacks)} inputs."
142+
)
143+
for cb, res in zip(callbacks, results):
144+
self.transport.resolve(cb, res)
145+
except Exception as exc:
146+
for cb in callbacks:
147+
self.transport.resolve_exception(cb, exc)
148+
finally:
149+
self._drain_pending_on_shutdown()
150+
151+
def _drain_pending_on_shutdown(self) -> None:
152+
"""Resolve all pending requests with an error during shutdown."""
153+
shutdown_exc = RuntimeError("InferenceServer is shutting down.")
154+
while True:
127 155
items, callbacks = self.transport.drain(self.max_batch_size)
128 156
if not items:
129-
continue
130-
131-
batch = self.collate_fn(items)
132-
if self.device is not None:
133-
batch = batch.to(self.device)
134-
135-
try:
136-
results = self.model(batch).unbind(0)
137-
if len(results) != len(callbacks):
138-
raise RuntimeError(
139-
f"Model returned {len(results)} results for a "
140-
f"batch of {len(callbacks)} inputs."
141-
)
142-
for cb, res in zip(callbacks, results):
143-
self.transport.resolve(cb, res)
144-
except Exception as exc:
145-
for cb in callbacks:
146-
self.transport.resolve_exception(cb, exc)
157+
break
158+
for cb in callbacks:
159+
self.transport.resolve_exception(cb, shutdown_exc)
147 160

148 161
# -- context manager ------------------------------------------------------
149 162

0 commit comments

Comments
 (0)