@@ -29,96 +29,58 @@ def _validate_signature(handler: _HandlerWithRequestArg):
2929
3030
3131#
32- # cancel_on_disconnect/disconnect_poller based
33- # on https://github.com/RedRoserade/fastapi-disconnect-example/blob/main/app.py
32+ # cancel_on_disconnect based on TaskGroup
3433#
3534_POLL_INTERVAL_S : float = 0.01
3635
3736
38- async def _disconnect_poller (close_event : asyncio .Event , request : Request , result : Any ):
37+ class _ClientDisconnectedError (Exception ):
38+ """Internal exception raised by the poller task when the client disconnects."""
39+
40+
41+ async def _disconnect_poller_for_task_group (request : Request ):
3942 """
40- Poll for a disconnect.
41- If the request disconnects, stop polling and return.
43+ Polls for client disconnection and raises _ClientDisconnectedError if it occurs.
4244 """
4345 while not await request .is_disconnected ():
4446 await asyncio .sleep (_POLL_INTERVAL_S )
45- if close_event .is_set ():
46- break
47- return result
47+ raise _ClientDisconnectedError ()
4848
4949
5050def cancel_on_disconnect (handler : _HandlerWithRequestArg ):
5151 """
52- After client disconnects, handler gets cancelled in ~<3 secs
52+ Decorator that cancels the request handler if the client disconnects.
53+
54+ Uses a TaskGroup to manage the handler and a poller task concurrently.
55+ If the client disconnects, the poller raises an exception, which is
56+ caught and translated into a 503 Service Unavailable response.
5357 """
5458
5559 _validate_signature (handler )
5660
5761 @wraps (handler )
5862 async def wrapper (request : Request , * args , ** kwargs ):
59- sentinel = object ()
60-
61- # Create two tasks:
62- # one to poll the request and check if the client disconnected
63- # sometimes canceling a task doesn't cancel the task immediately. If the poller task is not "killed" immediately, the client doesn't
64- # get a response, and the request "hangs". For this reason, we use an event to signal the poller task to stop.
65- # See: https://github.com/ITISFoundation/osparc-issues/issues/1922
66- kill_poller_event = asyncio .Event ()
67- poller_task = asyncio .create_task (
68- _disconnect_poller (kill_poller_event , request , sentinel ),
69- name = f"cancel_on_disconnect/poller/{ handler .__name__ } /{ id (sentinel )} " ,
70- )
71- # , and another which is the request handler
72- handler_task = asyncio .create_task (
73- handler (request , * args , ** kwargs ),
74- name = f"cancel_on_disconnect/handler/{ handler .__name__ } /{ id (sentinel )} " ,
75- )
76-
77- done , pending = await asyncio .wait (
78- [poller_task , handler_task ], return_when = asyncio .FIRST_COMPLETED
79- )
80- kill_poller_event .set ()
81-
82- # One has completed, cancel the other
83- for t in pending :
84- t .cancel ()
85-
86- try :
87- await asyncio .wait_for (t , timeout = 3 )
88-
89- except asyncio .CancelledError :
90- pass
91- except Exception : # pylint: disable=broad-except
92- if t is handler_task :
93- raise
94- finally :
95- assert t .done () # nosec
96-
97- # Return the result if the handler finished first
98- if handler_task in done :
99- assert poller_task .done () # nosec
100- return await handler_task
101-
102- # Otherwise, raise an exception. This is not exactly needed,
103- # but it will prevent validation errors if your request handler
104- # is supposed to return something.
105- logger .warning (
106- "Request %s %s cancelled since client %s disconnected:\n - %s\n - %s" ,
107- request .method ,
108- request .url ,
109- request .client ,
110- f"{ poller_task = } " ,
111- f"{ handler_task = } " ,
112- )
113-
114- assert poller_task .done () # nosec
115- assert handler_task .done () # nosec
116-
117- # NOTE: uvicorn server fails with 499
118- raise HTTPException (
119- status_code = status .HTTP_503_SERVICE_UNAVAILABLE ,
120- detail = f"client disconnected from { request = } " ,
121- )
63+ try :
64+ async with asyncio .TaskGroup () as tg :
65+ handler_task = tg .create_task (handler (request , * args , ** kwargs ))
66+ tg .create_task (_disconnect_poller_for_task_group (request ))
67+
68+ return handler_task .result ()
69+
70+ except* _ClientDisconnectedError as eg :
71+ logger .info (
72+ "Request %s %s cancelled since client %s disconnected." ,
73+ request .method ,
74+ request .url ,
75+ request .client ,
76+ )
77+ raise HTTPException (
78+ status_code = status .HTTP_503_SERVICE_UNAVAILABLE ,
79+ detail = "Client disconnected" ,
80+ ) from eg
81+
82+ except* Exception as eg :
83+ raise eg .exceptions [0 ]
12284
12385 return wrapper
12486
0 commit comments