Skip to content

Commit d6068bb

Browse files
committed
add check_for_failed_tasks
1 parent 0153238 commit d6068bb

File tree

2 files changed

+30
-15
lines changed

2 files changed

+30
-15
lines changed

src/vllm_tgis_adapter/__main__.py

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import asyncio
4+
import contextlib
45
import signal
56
from concurrent.futures import FIRST_EXCEPTION
67
from typing import TYPE_CHECKING
@@ -17,6 +18,7 @@
1718
from .http import run_http_server
1819
from .logging import init_logger
1920
from .tgis_utils.args import EnvVarArgumentParser, add_tgis_args, postprocess_tgis_args
21+
from .utils import check_for_failed_tasks
2022

2123
if TYPE_CHECKING:
2224
import argparse
@@ -54,23 +56,16 @@ async def override_signal_handler() -> None:
5456

5557
tasks.append(loop.create_task(override_signal_handler()))
5658

57-
done, pending = await asyncio.wait(
58-
tasks,
59-
return_when=FIRST_EXCEPTION,
60-
)
61-
for task in pending:
62-
task.cancel()
59+
with contextlib.suppress(asyncio.CancelledError):
60+
await asyncio.wait(
61+
tasks,
62+
return_when=FIRST_EXCEPTION,
63+
)
6364

64-
while done:
65-
task = done.pop()
66-
exc = task.exception()
67-
if not exc:
68-
continue
69-
70-
name = task.get_name()
71-
coro_name = task.get_coro().__name__
65+
for task in tasks:
66+
task.cancel()
7267

73-
raise RuntimeError(f"task={name} ({coro_name})") from exc
68+
check_for_failed_tasks(tasks)
7469

7570

7671
if __name__ == "__main__":

src/vllm_tgis_adapter/utils.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import asyncio
2+
from collections.abc import Iterable
3+
4+
5+
def check_for_failed_tasks(tasks: Iterable[asyncio.Task]) -> None:
6+
"""Check a sequence of tasks exceptions and raise the exception."""
7+
for task in tasks:
8+
try:
9+
exc = task.exception()
10+
except asyncio.InvalidStateError:
11+
# no exception is set
12+
continue
13+
14+
if not exc:
15+
continue
16+
17+
name = task.get_name()
18+
coro_name = task.get_coro().__name__
19+
20+
raise RuntimeError(f"task={name} ({coro_name})") from exc

0 commit comments

Comments
 (0)