1313class _HandlerWithRequestArg (Protocol ):
1414 __name__ : str
1515
16- async def __call__ (self , request : Request , * args : Any , ** kwargs : Any ) -> Any :
17- ...
16+ async def __call__ (self , request : Request , * args : Any , ** kwargs : Any ) -> Any : ...
1817
1918
2019def _validate_signature (handler : _HandlerWithRequestArg ):
@@ -36,13 +35,15 @@ def _validate_signature(handler: _HandlerWithRequestArg):
3635_POLL_INTERVAL_S : float = 0.01
3736
3837
39- async def _disconnect_poller (request : Request , result : Any ):
38+ async def _disconnect_poller (close_event : asyncio . Event , request : Request , result : Any ):
4039 """
4140 Poll for a disconnect.
4241 If the request disconnects, stop polling and return.
4342 """
4443 while not await request .is_disconnected ():
4544 await asyncio .sleep (_POLL_INTERVAL_S )
45+ if close_event .is_set ():
46+ break
4647 return result
4748
4849
@@ -59,8 +60,9 @@ async def wrapper(request: Request, *args, **kwargs):
5960
6061 # Create two tasks:
6162 # one to poll the request and check if the client disconnected
63+ kill_poller_event = asyncio .Event ()
6264 poller_task = asyncio .create_task (
63- _disconnect_poller (request , sentinel ),
65+ _disconnect_poller (kill_poller_event , request , sentinel ),
6466 name = f"cancel_on_disconnect/poller/{ handler .__name__ } /{ id (sentinel )} " ,
6567 )
6668 # , and another which is the request handler
@@ -72,6 +74,7 @@ async def wrapper(request: Request, *args, **kwargs):
7274 done , pending = await asyncio .wait (
7375 [poller_task , handler_task ], return_when = asyncio .FIRST_COMPLETED
7476 )
77+ kill_poller_event .set ()
7578
7679 # One has completed, cancel the other
7780 for t in pending :
0 commit comments