Skip to content

Commit a3620be

Browse files
njhilldtrifiro
andauthored
Fail startup with root-cause exception (#156)
If the gRPC server startup fails, the http server task can also fail for some other reason when cancelled. The current logic looks for an arbitrary failed task after this and raises an exception based on that. We want to do this based on the root cause exception not the secondary one from the other task's cancellation. So that the root cause is not lost. Co-authored-by: Daniele <[email protected]>
1 parent 9fc458a commit a3620be

File tree

2 files changed

+16
-13
lines changed

2 files changed

+16
-13
lines changed

src/vllm_tgis_adapter/__main__.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,19 +55,26 @@ async def start_servers(args: argparse.Namespace) -> None:
5555
# is detected, with task done and exception handled
5656
# here we just notify of that error and let servers be
5757
runtime_error = RuntimeError(
58-
"AsyncEngineClient error detected,this may be caused by an \
58+
"AsyncEngineClient error detected, this may be caused by an \
5959
unexpected error in serving a request. \
6060
Please check the logs for more details."
6161
)
6262

63+
failed_task = check_for_failed_tasks(tasks)
64+
6365
# Once either server shuts down, cancel the other
6466
for task in tasks:
6567
task.cancel()
6668

6769
# Final wait for both servers to finish
6870
await asyncio.wait(tasks)
6971

70-
check_for_failed_tasks(tasks)
72+
# Raise originally-failed task if applicable
73+
if failed_task:
74+
name, coro_name = failed_task.get_name(), failed_task.get_coro().__name__
75+
exception = failed_task.exception()
76+
raise RuntimeError(f"Failed task={name} ({coro_name})") from exception
77+
7178
if runtime_error:
7279
raise runtime_error
7380

src/vllm_tgis_adapter/utils.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,19 @@
11
import asyncio
22
from collections.abc import Iterable, Sequence
3+
from typing import Optional
34

45

5-
def check_for_failed_tasks(tasks: Iterable[asyncio.Task]) -> None:
6+
def check_for_failed_tasks(tasks: Iterable[asyncio.Task]) -> Optional[asyncio.Task]: # noqa: FA100
67
"""Check a sequence of tasks exceptions and raise the exception."""
78
for task in tasks:
89
try:
9-
exc = task.exception()
10-
except asyncio.InvalidStateError:
10+
if task.exception():
11+
return task
12+
except (asyncio.InvalidStateError, asyncio.CancelledError): # noqa: PERF203
1113
# no exception is set
12-
continue
14+
pass
1315

14-
if not exc:
15-
continue
16-
17-
name = task.get_name()
18-
coro_name = task.get_coro().__name__
19-
20-
raise RuntimeError(f"task={name} ({coro_name}) exception={exc!s}") from exc
16+
return None
2117

2218

2319
def write_termination_log(msg: str, file: str = "/dev/termination-log") -> None:

0 commit comments

Comments
 (0)