diff --git a/packages/service-library/src/servicelib/async_utils.py b/packages/service-library/src/servicelib/async_utils.py index d2c62ba55ffb..c8892fba010a 100644 --- a/packages/service-library/src/servicelib/async_utils.py +++ b/packages/service-library/src/servicelib/async_utils.py @@ -1,7 +1,7 @@ import asyncio import logging from collections import deque -from collections.abc import Awaitable, Callable +from collections.abc import Awaitable, Callable, Coroutine from contextlib import suppress from dataclasses import dataclass from functools import wraps @@ -210,3 +210,67 @@ async def worker(in_q: Queue[QueueElement], out_q: Queue) -> None: return wrapper return decorator + + +TaskCancelCallback = Callable[[], Awaitable[bool]] +_POLL_INTERVAL_S: float = 0.01 + + +class TaskCancelled(Exception): + """Internal exception raised by the poller task when the client disconnects.""" + + +async def _poller_for_task_group( + close_event: asyncio.Event, + cancel_awaitable: TaskCancelCallback, + poll_interval: float, +): + """ + Polls for cancellation via the callback and raises TaskCancelled if it occurs. + """ + while not await cancel_awaitable(): + await asyncio.sleep(poll_interval) + if close_event.is_set(): + return + raise TaskCancelled + + +async def run_until_cancelled( + *, + coro: Coroutine, + cancel_callback: TaskCancelCallback, + poll_interval: float = _POLL_INTERVAL_S, +) -> Any: + """ + Runs the given coroutine until it completes or cancellation is requested. + + This function executes the provided coroutine and periodically checks the given + cancel_callback. If cancellation is requested (i.e., cancel_callback returns True), + the coroutine is cancelled and a TaskCancelled exception is raised. If the coroutine + completes first, its result (or exception) is returned/reraised. + """ + sentinel = object() + close_poller_event = asyncio.Event() + try: + # Create two tasks in a TaskGroup + async with asyncio.TaskGroup() as tg: + + # One to poll for cancellation + tg.create_task( + _poller_for_task_group( + close_poller_event, cancel_callback, poll_interval + ), + name=f"run_until_cancelled/poller/{coro.__name__}/{id(sentinel)}", + ) + # The other to run the actual coroutine + coro_task = tg.create_task( + coro, + name=f"run_until_cancelled/coroutine/{coro.__name__}/{id(sentinel)}", + ) + await coro_task + close_poller_event.set() + + return coro_task.result() + + except* Exception as eg: + raise eg.exceptions[0] # pylint: disable=no-member diff --git a/packages/service-library/src/servicelib/fastapi/cancellation_middleware.py b/packages/service-library/src/servicelib/fastapi/cancellation_middleware.py index 0f7003137bbe..3d69c72d8ab1 100644 --- a/packages/service-library/src/servicelib/fastapi/cancellation_middleware.py +++ b/packages/service-library/src/servicelib/fastapi/cancellation_middleware.py @@ -1,7 +1,8 @@ import asyncio import logging -from typing import NoReturn +from functools import partial +from servicelib.async_utils import TaskCancelled, run_until_cancelled from starlette.requests import Request from starlette.types import ASGIApp, Message, Receive, Scope, Send @@ -10,29 +11,23 @@ _logger = logging.getLogger(__name__) -class _ClientDisconnectedError(Exception): - pass - - -async def _message_poller( - request: Request, queue: asyncio.Queue, receive: Receive -) -> NoReturn: - while True: - message = await receive() - if message["type"] == "http.disconnect": - _logger.debug( - "client disconnected the request to %s!", request.url, stacklevel=2 - ) - raise _ClientDisconnectedError - - # Puts the message in the queue - await queue.put(message) - - async def _handler( app: ASGIApp, scope: Scope, queue: asyncio.Queue[Message], send: Send ) -> None: - return await app(scope, queue.get, send) + await app(scope, queue.get, send) + + +async def _is_client_disconnected( + receive: Receive, queue: asyncio.Queue[Message], request: Request +) -> bool: + message = await receive() + if message["type"] == "http.disconnect": + _logger.debug("client disconnected, terminating request to %s!", request.url) + return True + + # Puts the message in the queue + await queue.put(message) + return False class RequestCancellationMiddleware: @@ -58,23 +53,22 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: return # Let's make a shared queue for the request messages - queue: asyncio.Queue[Message] = asyncio.Queue() request = Request(scope) + queue: asyncio.Queue[Message] = asyncio.Queue() with log_context(_logger, logging.DEBUG, f"cancellable request {request.url}"): try: - async with asyncio.TaskGroup() as tg: - handler_task = tg.create_task( - _handler(self.app, scope, queue, send) - ) - poller_task = tg.create_task( - _message_poller(request, queue, receive) - ) - await handler_task - poller_task.cancel() - except* _ClientDisconnectedError: - if not handler_task.done(): - _logger.info( - "The client disconnected. The request to %s was cancelled.", - request.url, - ) + await run_until_cancelled( + coro=_handler(self.app, scope, queue, send), + cancel_callback=partial( + _is_client_disconnected, receive, queue, request + ), + poll_interval=0.0, + ) + return + + except TaskCancelled: + _logger.info( + "The client disconnected. request to %s was cancelled.", + request.url, + ) diff --git a/packages/service-library/src/servicelib/fastapi/requests_decorators.py b/packages/service-library/src/servicelib/fastapi/requests_decorators.py index ae5f1ea047c6..b9c6f4941202 100644 --- a/packages/service-library/src/servicelib/fastapi/requests_decorators.py +++ b/packages/service-library/src/servicelib/fastapi/requests_decorators.py @@ -1,20 +1,20 @@ -import asyncio import inspect import logging -from functools import wraps +from functools import partial, wraps from typing import Any, Protocol from fastapi import Request, status from fastapi.exceptions import HTTPException +from ..async_utils import TaskCancelled, run_until_cancelled + logger = logging.getLogger(__name__) class _HandlerWithRequestArg(Protocol): __name__: str - async def __call__(self, request: Request, *args: Any, **kwargs: Any) -> Any: - ... + async def __call__(self, request: Request, *args: Any, **kwargs: Any) -> Any: ... def _validate_signature(handler: _HandlerWithRequestArg): @@ -30,89 +30,46 @@ def _validate_signature(handler: _HandlerWithRequestArg): # -# cancel_on_disconnect/disconnect_poller based -# on https://github.com/RedRoserade/fastapi-disconnect-example/blob/main/app.py +# cancel_on_disconnect based on TaskGroup # -_POLL_INTERVAL_S: float = 0.01 -async def _disconnect_poller(request: Request, result: Any): - """ - Poll for a disconnect. - If the request disconnects, stop polling and return. - """ - while not await request.is_disconnected(): - await asyncio.sleep(_POLL_INTERVAL_S) - return result +async def _is_client_disconnected(request: Request): + return await request.is_disconnected() def cancel_on_disconnect(handler: _HandlerWithRequestArg): """ - After client disconnects, handler gets cancelled in ~<3 secs + Decorator that cancels the request handler if the client disconnects. + + Uses a TaskGroup to manage the handler and a poller task concurrently. + If the client disconnects, the poller raises an exception, which is + caught and translated into a 503 Service Unavailable response. """ _validate_signature(handler) @wraps(handler) async def wrapper(request: Request, *args, **kwargs): - sentinel = object() - - # Create two tasks: - # one to poll the request and check if the client disconnected - poller_task = asyncio.create_task( - _disconnect_poller(request, sentinel), - name=f"cancel_on_disconnect/poller/{handler.__name__}/{id(sentinel)}", - ) - # , and another which is the request handler - handler_task = asyncio.create_task( - handler(request, *args, **kwargs), - name=f"cancel_on_disconnect/handler/{handler.__name__}/{id(sentinel)}", - ) - - done, pending = await asyncio.wait( - [poller_task, handler_task], return_when=asyncio.FIRST_COMPLETED - ) - - # One has completed, cancel the other - for t in pending: - t.cancel() - - try: - await asyncio.wait_for(t, timeout=3) - - except asyncio.CancelledError: - pass - except Exception: # pylint: disable=broad-except - if t is handler_task: - raise - finally: - assert t.done() # nosec - - # Return the result if the handler finished first - if handler_task in done: - assert poller_task.done() # nosec - return await handler_task - - # Otherwise, raise an exception. This is not exactly needed, - # but it will prevent validation errors if your request handler - # is supposed to return something. - logger.warning( - "Request %s %s cancelled since client %s disconnected:\n - %s\n - %s", - request.method, - request.url, - request.client, - f"{poller_task=}", - f"{handler_task=}", - ) - - assert poller_task.done() # nosec - assert handler_task.done() # nosec - - # NOTE: uvicorn server fails with 499 - raise HTTPException( - status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail=f"client disconnected from {request=}", - ) + + try: + return await run_until_cancelled( + coro=handler(request, *args, **kwargs), + cancel_callback=partial(_is_client_disconnected, request), + ) + + except TaskCancelled as exc: + + logger.info( + "Request %s %s cancelled since client %s disconnected.", + request.method, + request.url, + request.client, + ) + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail="Client disconnected", + ) from exc return wrapper diff --git a/packages/service-library/tests/fastapi/test_request_decorators.py b/packages/service-library/tests/fastapi/test_request_decorators.py index 18f6267cf33a..968534c36383 100644 --- a/packages/service-library/tests/fastapi/test_request_decorators.py +++ b/packages/service-library/tests/fastapi/test_request_decorators.py @@ -2,20 +2,26 @@ # pylint: disable=unused-argument # pylint: disable=unused-variable + import asyncio import subprocess import sys import time -from collections.abc import Callable, Iterator +from collections.abc import Awaitable, Callable, Iterator from contextlib import contextmanager from pathlib import Path from typing import NamedTuple +from unittest.mock import AsyncMock import pytest import requests from fastapi import FastAPI, Query, Request +from servicelib.async_utils import TaskCancelCallback from servicelib.fastapi.requests_decorators import cancel_on_disconnect +POLLER_CLEANUP_DELAY_S = 100.0 + + CURRENT_FILE = Path(sys.argv[0] if __name__ == "__main__" else __file__).resolve() CURRENT_DIR = CURRENT_FILE.parent @@ -120,3 +126,65 @@ def test_cancel_on_disconnect(get_unused_port: Callable[[], int]): # INFO: 127.0.0.1:35134 - "GET /example?wait=1 HTTP/1.1" 200 OK assert MESSAGE_ON_HANDLER_CANCELLATION in server_log + + +@pytest.fixture +def long_running_poller_mock( + monkeypatch: pytest.MonkeyPatch, +) -> Callable[[asyncio.Event, TaskCancelCallback, float], Awaitable]: + + async def _mock_disconnect_poller( + close_event: asyncio.Event, + cancel_awaitable: TaskCancelCallback, + poll_interval: float, + ): + _mock_disconnect_poller.called = True + while not await cancel_awaitable(): + await asyncio.sleep(2) + if close_event.is_set(): + break + + monkeypatch.setattr( + "servicelib.async_utils._poller_for_task_group", + _mock_disconnect_poller, + ) + return _mock_disconnect_poller + + +async def test_decorator_waits_for_poller_cleanup( + long_running_poller_mock: Callable[ + [asyncio.Event, TaskCancelCallback, float], Awaitable + ], +): + """ + Tests that the decorator's wrapper waits for the poller task to finish + its cleanup, even if the handler finishes first, without needing a full server. + """ + long_running_poller_mock.called = False + handler_was_called = False + + @cancel_on_disconnect + async def my_handler(request: Request): + nonlocal handler_was_called + handler_was_called = True + await asyncio.sleep(0.1) # Simulate quick work + return "Success" + + # Mock a fastapi.Request object + mock_request = AsyncMock(spec=Request) + mock_request.is_disconnected.return_value = False + + # --- + tasks_before = asyncio.all_tasks() + + # Call the decorated handler + _ = await my_handler(mock_request) + + tasks_after = asyncio.all_tasks() + # --- + + assert handler_was_called + assert long_running_poller_mock.called == True + + # Check that no background tasks were left orphaned + assert tasks_before == tasks_after diff --git a/services/api-server/src/simcore_service_api_server/api/routes/files.py b/services/api-server/src/simcore_service_api_server/api/routes/files.py index 69f779f13e19..a220edacb232 100644 --- a/services/api-server/src/simcore_service_api_server/api/routes/files.py +++ b/services/api-server/src/simcore_service_api_server/api/routes/files.py @@ -287,6 +287,7 @@ async def upload_files(files: list[UploadFile] = FileParam(...)): response_model=ClientFileUploadData, responses=_FILE_STATUS_CODES, ) +@cancel_on_disconnect async def get_upload_links( request: Request, client_file: UserFileToProgramJob | UserFile, @@ -421,6 +422,7 @@ async def abort_multipart_upload( response_model=OutputFile, responses=_FILE_STATUS_CODES, ) +@cancel_on_disconnect async def complete_multipart_upload( request: Request, file_id: UUID,