Skip to content

Commit a80f4e1

Browse files
committed
use taskgroup for error handling
1 parent 8812a8a commit a80f4e1

File tree

1 file changed

+34
-72
lines changed

1 file changed

+34
-72
lines changed

packages/service-library/src/servicelib/fastapi/requests_decorators.py

Lines changed: 34 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -29,96 +29,58 @@ def _validate_signature(handler: _HandlerWithRequestArg):
2929

3030

3131
#
32-
# cancel_on_disconnect/disconnect_poller based
33-
# on https://github.com/RedRoserade/fastapi-disconnect-example/blob/main/app.py
32+
# cancel_on_disconnect based on TaskGroup
3433
#
3534
_POLL_INTERVAL_S: float = 0.01
3635

3736

38-
async def _disconnect_poller(close_event: asyncio.Event, request: Request, result: Any):
37+
class _ClientDisconnectedError(Exception):
38+
"""Internal exception raised by the poller task when the client disconnects."""
39+
40+
41+
async def _disconnect_poller_for_task_group(request: Request):
3942
"""
40-
Poll for a disconnect.
41-
If the request disconnects, stop polling and return.
43+
Polls for client disconnection and raises _ClientDisconnectedError if it occurs.
4244
"""
4345
while not await request.is_disconnected():
4446
await asyncio.sleep(_POLL_INTERVAL_S)
45-
if close_event.is_set():
46-
break
47-
return result
47+
raise _ClientDisconnectedError()
4848

4949

5050
def cancel_on_disconnect(handler: _HandlerWithRequestArg):
5151
"""
52-
After client disconnects, handler gets cancelled in ~<3 secs
52+
Decorator that cancels the request handler if the client disconnects.
53+
54+
Uses a TaskGroup to manage the handler and a poller task concurrently.
55+
If the client disconnects, the poller raises an exception, which is
56+
caught and translated into a 503 Service Unavailable response.
5357
"""
5458

5559
_validate_signature(handler)
5660

5761
@wraps(handler)
5862
async def wrapper(request: Request, *args, **kwargs):
59-
sentinel = object()
60-
61-
# Create two tasks:
62-
# one to poll the request and check if the client disconnected
63-
# sometimes canceling a task doesn't cancel the task immediately. If the poller task is not "killed" immediately, the client doesn't
64-
# get a response, and the request "hangs". For this reason, we use an event to signal the poller task to stop.
65-
# See: https://github.com/ITISFoundation/osparc-issues/issues/1922
66-
kill_poller_event = asyncio.Event()
67-
poller_task = asyncio.create_task(
68-
_disconnect_poller(kill_poller_event, request, sentinel),
69-
name=f"cancel_on_disconnect/poller/{handler.__name__}/{id(sentinel)}",
70-
)
71-
# , and another which is the request handler
72-
handler_task = asyncio.create_task(
73-
handler(request, *args, **kwargs),
74-
name=f"cancel_on_disconnect/handler/{handler.__name__}/{id(sentinel)}",
75-
)
76-
77-
done, pending = await asyncio.wait(
78-
[poller_task, handler_task], return_when=asyncio.FIRST_COMPLETED
79-
)
80-
kill_poller_event.set()
81-
82-
# One has completed, cancel the other
83-
for t in pending:
84-
t.cancel()
85-
86-
try:
87-
await asyncio.wait_for(t, timeout=3)
88-
89-
except asyncio.CancelledError:
90-
pass
91-
except Exception: # pylint: disable=broad-except
92-
if t is handler_task:
93-
raise
94-
finally:
95-
assert t.done() # nosec
96-
97-
# Return the result if the handler finished first
98-
if handler_task in done:
99-
assert poller_task.done() # nosec
100-
return await handler_task
101-
102-
# Otherwise, raise an exception. This is not exactly needed,
103-
# but it will prevent validation errors if your request handler
104-
# is supposed to return something.
105-
logger.warning(
106-
"Request %s %s cancelled since client %s disconnected:\n - %s\n - %s",
107-
request.method,
108-
request.url,
109-
request.client,
110-
f"{poller_task=}",
111-
f"{handler_task=}",
112-
)
113-
114-
assert poller_task.done() # nosec
115-
assert handler_task.done() # nosec
116-
117-
# NOTE: uvicorn server fails with 499
118-
raise HTTPException(
119-
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
120-
detail=f"client disconnected from {request=}",
121-
)
63+
try:
64+
async with asyncio.TaskGroup() as tg:
65+
handler_task = tg.create_task(handler(request, *args, **kwargs))
66+
tg.create_task(_disconnect_poller_for_task_group(request))
67+
68+
return handler_task.result()
69+
70+
except* _ClientDisconnectedError as eg:
71+
logger.info(
72+
"Request %s %s cancelled since client %s disconnected.",
73+
request.method,
74+
request.url,
75+
request.client,
76+
)
77+
raise HTTPException(
78+
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
79+
detail="Client disconnected",
80+
) from eg
81+
82+
except* Exception as eg:
83+
raise eg.exceptions[0]
12284

12385
return wrapper
12486

0 commit comments

Comments
 (0)