@@ -38,12 +38,16 @@ class _ClientDisconnectedError(Exception):
3838 """Internal exception raised by the poller task when the client disconnects."""
3939
4040
41- async def _disconnect_poller_for_task_group (request : Request ):
41+ async def _disconnect_poller_for_task_group (
42+ close_event : asyncio .Event , request : Request
43+ ):
4244 """
4345 Polls for client disconnection and raises _ClientDisconnectedError if it occurs.
4446 """
4547 while not await request .is_disconnected ():
4648 await asyncio .sleep (_POLL_INTERVAL_S )
49+ if close_event .is_set ():
50+ return
4751 raise _ClientDisconnectedError ()
4852
4953
@@ -61,16 +65,20 @@ def cancel_on_disconnect(handler: _HandlerWithRequestArg):
6165 @wraps (handler )
6266 async def wrapper (request : Request , * args , ** kwargs ):
6367 sentinel = object ()
68+ kill_poller_task_event = asyncio .Event ()
6469 try :
6570 async with asyncio .TaskGroup () as tg :
71+
72+ tg .create_task (
73+ _disconnect_poller_for_task_group (kill_poller_task_event , request ),
74+ name = f"cancel_on_disconnect/poller/{ handler .__name__ } /{ id (sentinel )} " ,
75+ )
6676 handler_task = tg .create_task (
6777 handler (request , * args , ** kwargs ),
6878 name = f"cancel_on_disconnect/handler/{ handler .__name__ } /{ id (sentinel )} " ,
6979 )
70- tg .create_task (
71- _disconnect_poller_for_task_group (request ),
72- name = f"cancel_on_disconnect/poller/{ handler .__name__ } /{ id (sentinel )} " ,
73- )
80+ await handler_task
81+ kill_poller_task_event .set ()
7482
7583 return handler_task .result ()
7684
0 commit comments