Skip to content

Commit aed99b7

Browse files
authored
🐛 Is638/fixes cancel on disconnect (ITISFoundation#3164)
1 parent 00a4f25 commit aed99b7

File tree

25 files changed

+371
-391
lines changed

25 files changed

+371
-391
lines changed

packages/service-library/requirements/_fastapi.in

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
async-asgi-testclient # replacement for fastapi.testclient.TestClient [see b) below]
1010
fastapi
1111
fastapi_contrib[jaegertracing]
12-
12+
uvicorn
1313

1414
# NOTE: What test client to use for fastapi-based apps?
1515
#

packages/service-library/requirements/_fastapi.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,16 @@ certifi==2022.6.15
1212
# via requests
1313
charset-normalizer==2.0.12
1414
# via requests
15+
click==8.1.3
16+
# via uvicorn
1517
fastapi==0.76.0
1618
# via
1719
# -r requirements/_fastapi.in
1820
# fastapi-contrib
1921
fastapi-contrib==0.2.11
2022
# via -r requirements/_fastapi.in
23+
h11==0.13.0
24+
# via uvicorn
2125
idna==3.3
2226
# via
2327
# anyio
@@ -59,3 +63,5 @@ urllib3==1.26.9
5963
# via
6064
# -c requirements/./../../../requirements/constraints.txt
6165
# requests
66+
uvicorn==0.18.2
67+
# via -r requirements/_fastapi.in

packages/service-library/setup.cfg

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,6 @@ universal = 1
1212

1313
[aliases]
1414
test = pytest
15+
16+
[tool:pytest]
17+
asyncio_mode = auto
Lines changed: 99 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -1,110 +1,124 @@
11
import asyncio
22
import inspect
33
import logging
4-
from asyncio import CancelledError
5-
from contextlib import suppress
64
from functools import wraps
7-
from typing import Any, Callable, Coroutine, Optional
5+
from typing import Any, Protocol
86

9-
from fastapi import Request, Response
7+
from fastapi import Request, status
8+
from fastapi.exceptions import HTTPException
109

1110
logger = logging.getLogger(__name__)
1211

1312

14-
_DEFAULT_CHECK_INTERVAL_S: float = 0.5
13+
class _HandlerWithRequestArg(Protocol):
14+
__name__: str
1515

16-
HTTP_499_CLIENT_CLOSED_REQUEST = 499
17-
# A non-standard status code introduced by nginx for the case when a client
18-
# closes the connection while nginx is processing the request.
19-
# SEE https://www.webfx.com/web-development/glossary/http-status-codes/what-is-a-499-status-code/
16+
async def __call__(self, request: Request, *args: Any) -> Any:
17+
...
2018

21-
TASK_NAME_PREFIX = "cancellable_request"
2219

23-
_FastAPIHandlerCallable = Callable[..., Coroutine[Any, Any, Optional[Any]]]
20+
def _validate_signature(handler: _HandlerWithRequestArg):
21+
"""Raises ValueError if handler does not have expected signature"""
22+
try:
23+
p = next(iter(inspect.signature(handler).parameters.values()))
24+
if p.kind != inspect.Parameter.POSITIONAL_OR_KEYWORD or p.annotation != Request:
25+
raise TypeError(
26+
f"Invalid handler {handler.__name__} signature: first parameter must be a Request, got {p.annotation}"
27+
)
28+
except StopIteration as e:
29+
raise TypeError(
30+
f"Invalid handler {handler.__name__} signature: first parameter must be a Request, got none"
31+
) from e
2432

2533

26-
async def _cancel_task_if_client_disconnected(
27-
request: Request, task: asyncio.Task, interval: float = _DEFAULT_CHECK_INTERVAL_S
28-
) -> None:
29-
try:
30-
while True:
31-
if task.done():
32-
logger.debug("task %s is done", task)
33-
break
34-
if await request.is_disconnected():
35-
logger.warning(
36-
"client %s disconnected! Cancelling handler for %s",
37-
request.client,
38-
f"{request.url=}",
39-
)
40-
task.cancel()
41-
break
42-
await asyncio.sleep(interval)
43-
except CancelledError:
44-
logger.debug("task monitoring %s handler was cancelled", f"{request.url=}")
45-
raise
46-
finally:
47-
logger.debug("task monitoring %s handler completed", f"{request.url}")
48-
49-
50-
def cancellable_request(handler_fun: _FastAPIHandlerCallable):
51-
"""This decorator periodically checks if the client disconnected and
52-
then will cancel the request and return a HTTP_499_CLIENT_CLOSED_REQUEST code (a la nginx).
53-
54-
Usage: decorate the cancellable route and add request: Request as an argument
55-
56-
@cancellable_request
57-
async def route(
58-
_request: Request,
59-
...
60-
)
34+
#
35+
# cancel_on_disconnect/disconnect_poller based
36+
# on https://github.com/RedRoserade/fastapi-disconnect-example/blob/main/app.py
37+
#
38+
_POLL_INTERVAL_S: float = 0.01
39+
40+
41+
async def disconnect_poller(request: Request, result: Any):
6142
"""
62-
# CHECK: Early check that will raise upon import
63-
# IMPROVEMENT: inject this parameter to handler_fun here before it returned in the wrapper and consumed by fastapi.router?
64-
found_required_arg = any(
65-
parameter.name == "_request" and parameter.annotation == Request
66-
for parameter in inspect.signature(handler_fun).parameters.values()
43+
Poll for a disconnect.
44+
If the request disconnects, stop polling and return.
45+
"""
46+
while not await request.is_disconnected():
47+
await asyncio.sleep(_POLL_INTERVAL_S)
48+
49+
logger.debug(
50+
"client %s disconnected! Cancelling handler for request %s %s",
51+
request.client,
52+
request.method,
53+
request.url,
6754
)
68-
if not found_required_arg:
69-
raise ValueError(
70-
f"Invalid handler {handler_fun.__name__} signature: missing required parameter _request: Request"
71-
)
55+
return result
56+
7257

73-
# WRAPPER ----
74-
@wraps(handler_fun)
75-
async def wrapper(*args, **kwargs) -> Optional[Any]:
76-
request: Request = kwargs["_request"]
58+
def cancel_on_disconnect(handler: _HandlerWithRequestArg):
59+
"""
60+
After client dicsonnects, handler gets cancelled in ~<3 secs
61+
"""
7762

78-
# Intercepts handler call and creates a task out of it
63+
_validate_signature(handler)
64+
65+
@wraps(handler)
66+
async def wrapper(request: Request, *args, **kwargs):
67+
sentinel = object()
68+
69+
# Create two tasks:
70+
# one to poll the request and check if the client disconnected
71+
poller_task = asyncio.create_task(
72+
disconnect_poller(request, sentinel),
73+
name=f"cancel_on_disconnect/poller/{handler.__name__}/{id(sentinel)}",
74+
)
75+
# , and another which is the request handler
7976
handler_task = asyncio.create_task(
80-
handler_fun(*args, **kwargs),
81-
name=f"{TASK_NAME_PREFIX}/handler/{handler_fun.__name__}",
77+
handler(request, *args, **kwargs),
78+
name=f"cancel_on_disconnect/handler/{handler.__name__}/{id(sentinel)}",
8279
)
83-
# An extra task to monitor when the client disconnects so it can
84-
# cancel 'handler_task'
85-
auto_cancel_task = asyncio.create_task(
86-
_cancel_task_if_client_disconnected(request, handler_task),
87-
name=f"{TASK_NAME_PREFIX}/auto_cancel/{handler_fun.__name__}",
80+
81+
done, pending = await asyncio.wait(
82+
[poller_task, handler_task], return_when=asyncio.FIRST_COMPLETED
8883
)
8984

90-
try:
85+
# One has completed, cancel the other
86+
for t in pending:
87+
t.cancel()
88+
try:
89+
await asyncio.wait_for(t, timeout=3)
90+
except asyncio.CancelledError:
91+
logger.debug("%s was cancelled", t)
92+
except Exception as exc: # pylint: disable=broad-except
93+
if t is handler_task:
94+
logger.warning(
95+
"%s raised %s when being cancelled.", t, exc, exc_info=True
96+
)
97+
raise
98+
finally:
99+
assert t.done() # nosec
100+
101+
# Return the result if the handler finished first
102+
if handler_task in done:
103+
assert poller_task.done() # nosec
91104
return await handler_task
92-
except CancelledError:
93-
# TODO: check that 'auto_cancel_task' actually executed this cancellation
94-
# E.g. app shutdown might cancel all pending tasks
95-
logger.warning(
96-
"Request %s was cancelled since client %s disconnected !",
97-
f"{request.url}",
98-
request.client,
99-
)
100-
return Response(
101-
"Request cancelled because client disconnected",
102-
status_code=HTTP_499_CLIENT_CLOSED_REQUEST,
103-
)
104-
finally:
105-
# NOTE: This is ALSO called 'await handler_task' returns
106-
auto_cancel_task.cancel()
107-
with suppress(CancelledError):
108-
await auto_cancel_task
105+
106+
# Otherwise, raise an exception. This is not exactly needed, but it will prevent
107+
# validation errors if your request handler is supposed to return something.
108+
logger.debug(
109+
"Request %s %s cancelled:\n - %s\n - %s",
110+
request.method,
111+
request.url,
112+
f"{poller_task=}",
113+
f"{handler_task=}",
114+
)
115+
assert poller_task.done() # nosec
116+
assert handler_task.done() # nosec
117+
118+
# NOTE: uvicorn server fails with 499
119+
raise HTTPException(
120+
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
121+
detail=f"client disconnected from {request=}",
122+
)
109123

110124
return wrapper

0 commit comments

Comments
 (0)