Skip to content

Commit 813e3b6

Browse files
sandereggmrnicegyu11
authored andcommitted
Storage: Add cancellation middleware (ITISFoundation#7279)
1 parent 0acbe9b commit 813e3b6

File tree

4 files changed

+232
-4
lines changed

4 files changed

+232
-4
lines changed

.pre-commit-config.yaml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ default_language_version:
55
python: python3.11
66
repos:
77
- repo: https://github.com/pre-commit/pre-commit-hooks
8-
rev: v4.2.0
8+
rev: v5.0.0
99
hooks:
1010
- id: check-added-large-files
1111
args: ["--maxkb=1024"]
@@ -22,7 +22,7 @@ repos:
2222
- id: no-commit-to-branch
2323
# NOTE: Keep order as pyupgrade (will update code) then pycln (remove unused imports), then isort (sort them) and black (final formatting)
2424
- repo: https://github.com/asottile/pyupgrade
25-
rev: v2.34.0
25+
rev: v3.19.1
2626
hooks:
2727
- id: pyupgrade
2828
args:
@@ -36,13 +36,13 @@ repos:
3636
args: [--all, --expand-stars]
3737
name: prune imports
3838
- repo: https://github.com/PyCQA/isort
39-
rev: 5.12.0
39+
rev: 6.0.0
4040
hooks:
4141
- id: isort
4242
args: ["--profile", "black"]
4343
name: sort imports
4444
- repo: https://github.com/psf/black
45-
rev: 22.3.0
45+
rev: 25.1.0
4646
hooks:
4747
- id: black
4848
name: black format code
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
import asyncio
2+
import logging
3+
from typing import NoReturn
4+
5+
from starlette.requests import Request
6+
from starlette.types import ASGIApp, Message, Receive, Scope, Send
7+
8+
from ..logging_utils import log_context
9+
10+
_logger = logging.getLogger(__name__)
11+
12+
13+
class _TerminateTaskGroupError(Exception):
14+
pass
15+
16+
17+
async def _message_poller(
18+
request: Request, queue: asyncio.Queue, receive: Receive
19+
) -> NoReturn:
20+
while True:
21+
message = await receive()
22+
if message["type"] == "http.disconnect":
23+
_logger.info("client disconnected, terminating request to %s!", request.url)
24+
raise _TerminateTaskGroupError
25+
26+
# Puts the message in the queue
27+
await queue.put(message)
28+
29+
30+
async def _handler(
31+
app: ASGIApp, scope: Scope, queue: asyncio.Queue[Message], send: Send
32+
) -> None:
33+
return await app(scope, queue.get, send)
34+
35+
36+
class RequestCancellationMiddleware:
37+
"""ASGI Middleware to cancel server requests in case of client disconnection.
38+
Reason: FastAPI-based (e.g. starlette) servers do not automatically cancel
39+
server requests in case of client disconnection. This middleware will cancel
40+
the server request in case of client disconnection via asyncio.CancelledError.
41+
42+
WARNING: FastAPI BackgroundTasks will also get cancelled. Use with care.
43+
TIP: use asyncio.Task in that case
44+
"""
45+
46+
def __init__(self, app: ASGIApp) -> None:
47+
self.app = app
48+
_logger.warning(
49+
"CancellationMiddleware is in use, in case of client disconection, "
50+
"FastAPI BackgroundTasks will be cancelled too!",
51+
)
52+
53+
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
54+
if scope["type"] != "http":
55+
await self.app(scope, receive, send)
56+
return
57+
58+
# Let's make a shared queue for the request messages
59+
queue: asyncio.Queue[Message] = asyncio.Queue()
60+
61+
request = Request(scope)
62+
63+
with log_context(_logger, logging.DEBUG, f"cancellable request {request.url}"):
64+
try:
65+
async with asyncio.TaskGroup() as tg:
66+
handler_task = tg.create_task(
67+
_handler(self.app, scope, queue, send)
68+
)
69+
poller_task = tg.create_task(
70+
_message_poller(request, queue, receive)
71+
)
72+
await handler_task
73+
poller_task.cancel()
74+
except* _TerminateTaskGroupError:
75+
_logger.info(
76+
"The client disconnected. request to %s was cancelled.", request.url
77+
)
Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
# pylint: disable=redefined-outer-name
2+
3+
import asyncio
4+
import logging
5+
from collections.abc import Iterator
6+
from threading import Thread
7+
from unittest.mock import AsyncMock
8+
9+
import httpx
10+
import pytest
11+
import uvicorn
12+
from fastapi import APIRouter, BackgroundTasks, FastAPI
13+
from pytest_simcore.helpers.logging_tools import log_context
14+
from servicelib.fastapi.cancellation_middleware import RequestCancellationMiddleware
15+
from servicelib.utils import unused_port
16+
from yarl import URL
17+
18+
19+
@pytest.fixture
20+
def server_done_event() -> asyncio.Event:
21+
return asyncio.Event()
22+
23+
24+
@pytest.fixture
25+
def server_cancelled_mock() -> AsyncMock:
26+
return AsyncMock()
27+
28+
29+
@pytest.fixture
30+
def fastapi_router(
31+
server_done_event: asyncio.Event, server_cancelled_mock: AsyncMock
32+
) -> APIRouter:
33+
router = APIRouter()
34+
35+
@router.get("/sleep")
36+
async def sleep(sleep_time: float) -> dict[str, str]:
37+
with log_context(logging.INFO, msg="sleeper") as ctx:
38+
try:
39+
await asyncio.sleep(sleep_time)
40+
return {"message": f"Slept for {sleep_time} seconds"}
41+
except asyncio.CancelledError:
42+
ctx.logger.info("sleeper cancelled!")
43+
await server_cancelled_mock()
44+
return {"message": "Cancelled"}
45+
finally:
46+
server_done_event.set()
47+
48+
async def _sleep_in_the_back(sleep_time: float) -> None:
49+
with log_context(logging.INFO, msg="sleeper in the back") as ctx:
50+
try:
51+
await asyncio.sleep(sleep_time)
52+
except asyncio.CancelledError:
53+
ctx.logger.info("sleeper in the back cancelled!")
54+
await server_cancelled_mock()
55+
finally:
56+
server_done_event.set()
57+
58+
@router.get("/sleep-with-background-task")
59+
async def sleep_with_background_task(
60+
sleep_time: float, background_tasks: BackgroundTasks
61+
) -> dict[str, str]:
62+
with log_context(logging.INFO, msg="sleeper with background task"):
63+
background_tasks.add_task(_sleep_in_the_back, sleep_time)
64+
return {"message": "Sleeping in the back"}
65+
66+
return router
67+
68+
69+
@pytest.fixture
70+
def fastapi_app(fastapi_router: APIRouter) -> FastAPI:
71+
app = FastAPI()
72+
app.include_router(fastapi_router)
73+
app.add_middleware(RequestCancellationMiddleware)
74+
return app
75+
76+
77+
@pytest.fixture
78+
def uvicorn_server(fastapi_app: FastAPI) -> Iterator[URL]:
79+
random_port = unused_port()
80+
with log_context(
81+
logging.INFO,
82+
msg=f"with uvicorn server on 127.0.0.1:{random_port}",
83+
) as ctx:
84+
config = uvicorn.Config(
85+
fastapi_app,
86+
host="127.0.0.1",
87+
port=random_port,
88+
log_level="error",
89+
)
90+
server = uvicorn.Server(config)
91+
92+
thread = Thread(target=server.run)
93+
thread.daemon = True
94+
thread.start()
95+
96+
ctx.logger.info(
97+
"server ready at: %s",
98+
f"http://127.0.0.1:{random_port}",
99+
)
100+
101+
yield URL(f"http://127.0.0.1:{random_port}")
102+
103+
server.should_exit = True
104+
thread.join(timeout=10)
105+
106+
107+
async def test_server_cancels_when_client_disconnects(
108+
uvicorn_server: URL,
109+
server_done_event: asyncio.Event,
110+
server_cancelled_mock: AsyncMock,
111+
):
112+
async with httpx.AsyncClient(base_url=f"{uvicorn_server}") as client:
113+
# check standard call still complete as expected
114+
with log_context(logging.INFO, msg="client calling endpoint"):
115+
response = await client.get("/sleep", params={"sleep_time": 0.1})
116+
assert response.status_code == 200
117+
assert response.json() == {"message": "Slept for 0.1 seconds"}
118+
async with asyncio.timeout(10):
119+
await server_done_event.wait()
120+
server_done_event.clear()
121+
122+
# check slow call get cancelled
123+
with log_context(
124+
logging.INFO, msg="client calling endpoint for cancellation"
125+
) as ctx:
126+
with pytest.raises(httpx.ReadTimeout):
127+
response = await client.get(
128+
"/sleep", params={"sleep_time": 10}, timeout=0.1
129+
)
130+
ctx.logger.info("client disconnected from server")
131+
132+
async with asyncio.timeout(5):
133+
await server_done_event.wait()
134+
server_cancelled_mock.assert_called_once()
135+
server_cancelled_mock.reset_mock()
136+
server_done_event.clear()
137+
138+
# NOTE: shows that FastAPI BackgroundTasks get cancelled too!
139+
# check background tasks get cancelled as well sadly
140+
with log_context(logging.INFO, msg="client calling endpoint for cancellation"):
141+
response = await client.get(
142+
"/sleep-with-background-task",
143+
params={"sleep_time": 2},
144+
)
145+
assert response.status_code == 200
146+
async with asyncio.timeout(5):
147+
await server_done_event.wait()
148+
server_cancelled_mock.assert_called_once()

services/storage/src/simcore_service_storage/core/application.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from fastapi.middleware.gzip import GZipMiddleware
1111
from fastapi_pagination import add_pagination
1212
from servicelib.fastapi import timing_middleware
13+
from servicelib.fastapi.cancellation_middleware import RequestCancellationMiddleware
1314
from servicelib.fastapi.client_session import setup_client_session
1415
from servicelib.fastapi.openapi import override_fastapi_openapi_method
1516
from servicelib.fastapi.profiler import ProfilerMiddleware
@@ -103,6 +104,8 @@ def create_app(settings: ApplicationSettings) -> FastAPI:
103104

104105
app.add_middleware(GZipMiddleware)
105106

107+
app.add_middleware(RequestCancellationMiddleware)
108+
106109
if settings.STORAGE_TRACING:
107110
initialize_tracing(app, settings.STORAGE_TRACING, APP_NAME)
108111
if settings.STORAGE_MONITORING_ENABLED:

0 commit comments

Comments
 (0)