Skip to content
Open
Show file tree
Hide file tree
Changes from 13 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
958f64f
readd cancel on disconnect decorator
bisgaard-itis Jun 30, 2025
176775f
add asyncio.Event for killing poller task
bisgaard-itis Jun 30, 2025
c376bab
add comment
bisgaard-itis Jun 30, 2025
8812a8a
add test for request handler `cancel_on_disconnect` @pcrespov
bisgaard-itis Jul 1, 2025
a80f4e1
use taskgroup for error handling
bisgaard-itis Jul 1, 2025
60e99d0
add old names to tasks
bisgaard-itis Jul 1, 2025
e6c42f7
use event to kill poller task to be sure it terminates
bisgaard-itis Jul 1, 2025
7311ef1
Merge branch 'master' into fix-hang-in-poller-task
bisgaard-itis Jul 1, 2025
e7c97de
Merge branch 'master' into fix-hang-in-poller-task
bisgaard-itis Jul 2, 2025
1229fa3
pylint
bisgaard-itis Jul 2, 2025
4cbc9dc
fix test
bisgaard-itis Jul 2, 2025
df8d9c7
Merge branch 'master' into fix-hang-in-poller-task
bisgaard-itis Jul 2, 2025
09a0f85
Merge branch 'master' into fix-hang-in-poller-task
bisgaard-itis Jul 2, 2025
321fa62
Merge branch 'master' into fix-hang-in-poller-task
bisgaard-itis Jul 4, 2025
e22808a
@pcrespov readd comments
bisgaard-itis Jul 4, 2025
5399294
readd old tests
bisgaard-itis Jul 4, 2025
97afe88
readd
bisgaard-itis Jul 4, 2025
373ad5c
factor out core funcionality into run_until_cancelled
bisgaard-itis Jul 4, 2025
ef3911e
fix tests
bisgaard-itis Jul 4, 2025
a626a29
ensure function decorator works on local deployment
bisgaard-itis Jul 4, 2025
edd826e
migrate middleware to new implementation
bisgaard-itis Jul 4, 2025
241a0e9
improve types
bisgaard-itis Jul 4, 2025
1c091a1
Merge branch 'master' into fix-hang-in-poller-task
bisgaard-itis Jul 4, 2025
431df54
fix request cancellation test
bisgaard-itis Jul 7, 2025
fe5b4ca
improve naming
bisgaard-itis Jul 7, 2025
62808a8
Revert "improve naming"
bisgaard-itis Jul 7, 2025
8c2e369
Merge branch 'master' into fix-hang-in-poller-task
bisgaard-itis Jul 7, 2025
6f6d755
Merge branch 'master' into fix-hang-in-poller-task
bisgaard-itis Jul 16, 2025
f0cfe6b
Merge branch 'master' into fix-hang-in-poller-task
bisgaard-itis Jul 17, 2025
b8e20d2
Merge branch 'master' into fix-hang-in-poller-task
bisgaard-itis Aug 7, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@
class _HandlerWithRequestArg(Protocol):
__name__: str

async def __call__(self, request: Request, *args: Any, **kwargs: Any) -> Any:
...
async def __call__(self, request: Request, *args: Any, **kwargs: Any) -> Any: ...


def _validate_signature(handler: _HandlerWithRequestArg):
Expand All @@ -30,89 +29,73 @@ def _validate_signature(handler: _HandlerWithRequestArg):


#
# cancel_on_disconnect/disconnect_poller based
# on https://github.com/RedRoserade/fastapi-disconnect-example/blob/main/app.py
# cancel_on_disconnect based on TaskGroup
#
_POLL_INTERVAL_S: float = 0.01


async def _disconnect_poller(request: Request, result: Any):
class _ClientDisconnectedError(Exception):
"""Internal exception raised by the poller task when the client disconnects."""


async def _disconnect_poller_for_task_group(
close_event: asyncio.Event, request: Request
):
"""
Poll for a disconnect.
If the request disconnects, stop polling and return.
Polls for client disconnection and raises _ClientDisconnectedError if it occurs.
"""
while not await request.is_disconnected():
await asyncio.sleep(_POLL_INTERVAL_S)
return result
if close_event.is_set():
return
raise _ClientDisconnectedError()


def cancel_on_disconnect(handler: _HandlerWithRequestArg):
"""
After client disconnects, handler gets cancelled in ~<3 secs
Decorator that cancels the request handler if the client disconnects.

Uses a TaskGroup to manage the handler and a poller task concurrently.
If the client disconnects, the poller raises an exception, which is
caught and translated into a 503 Service Unavailable response.
"""

_validate_signature(handler)

@wraps(handler)
async def wrapper(request: Request, *args, **kwargs):
sentinel = object()

# Create two tasks:
# one to poll the request and check if the client disconnected
poller_task = asyncio.create_task(
_disconnect_poller(request, sentinel),
name=f"cancel_on_disconnect/poller/{handler.__name__}/{id(sentinel)}",
)
# , and another which is the request handler
handler_task = asyncio.create_task(
handler(request, *args, **kwargs),
name=f"cancel_on_disconnect/handler/{handler.__name__}/{id(sentinel)}",
)

done, pending = await asyncio.wait(
[poller_task, handler_task], return_when=asyncio.FIRST_COMPLETED
)

# One has completed, cancel the other
for t in pending:
t.cancel()

try:
await asyncio.wait_for(t, timeout=3)

except asyncio.CancelledError:
pass
except Exception: # pylint: disable=broad-except
if t is handler_task:
raise
finally:
assert t.done() # nosec

# Return the result if the handler finished first
if handler_task in done:
assert poller_task.done() # nosec
return await handler_task

# Otherwise, raise an exception. This is not exactly needed,
# but it will prevent validation errors if your request handler
# is supposed to return something.
logger.warning(
"Request %s %s cancelled since client %s disconnected:\n - %s\n - %s",
request.method,
request.url,
request.client,
f"{poller_task=}",
f"{handler_task=}",
)

assert poller_task.done() # nosec
assert handler_task.done() # nosec

# NOTE: uvicorn server fails with 499
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail=f"client disconnected from {request=}",
)
kill_poller_task_event = asyncio.Event()
try:
async with asyncio.TaskGroup() as tg:

tg.create_task(
_disconnect_poller_for_task_group(kill_poller_task_event, request),
name=f"cancel_on_disconnect/poller/{handler.__name__}/{id(sentinel)}",
)
handler_task = tg.create_task(
handler(request, *args, **kwargs),
name=f"cancel_on_disconnect/handler/{handler.__name__}/{id(sentinel)}",
)
await handler_task
kill_poller_task_event.set()

return handler_task.result()

except* _ClientDisconnectedError as eg:
logger.info(
"Request %s %s cancelled since client %s disconnected.",
request.method,
request.url,
request.client,
)
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Client disconnected",
) from eg

except* Exception as eg:
raise eg.exceptions[0] # pylint: disable=no-member

return wrapper

Expand Down
162 changes: 55 additions & 107 deletions packages/service-library/tests/fastapi/test_request_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,121 +2,69 @@
# pylint: disable=unused-argument
# pylint: disable=unused-variable


import asyncio
import subprocess
import sys
import time
from collections.abc import Callable, Iterator
from contextlib import contextmanager
from pathlib import Path
from typing import NamedTuple
from collections.abc import Awaitable, Callable
from unittest.mock import AsyncMock

import pytest
import requests
from fastapi import FastAPI, Query, Request
from fastapi import Request
from servicelib.fastapi.requests_decorators import cancel_on_disconnect

CURRENT_FILE = Path(sys.argv[0] if __name__ == "__main__" else __file__).resolve()
CURRENT_DIR = CURRENT_FILE.parent
POLLER_CLEANUP_DELAY_S = 100.0


@pytest.fixture
def long_running_poller_mock(
monkeypatch: pytest.MonkeyPatch,
) -> Callable[[asyncio.Event, Request], Awaitable]:

mock_app = FastAPI(title="Disconnect example")
async def _mock_disconnect_poller(close_event: asyncio.Event, request: Request):
_mock_disconnect_poller.called = True
while not await request.is_disconnected():
await asyncio.sleep(2)
if close_event.is_set():
break

MESSAGE_ON_HANDLER_CANCELLATION = "Request was cancelled!!"
monkeypatch.setattr(
"servicelib.fastapi.requests_decorators._disconnect_poller_for_task_group",
_mock_disconnect_poller,
)
return _mock_disconnect_poller


@mock_app.get("/example")
@cancel_on_disconnect
async def example(
request: Request,
wait: float = Query(..., description="Time to wait, in seconds"),
async def test_decorator_waits_for_poller_cleanup(
long_running_poller_mock: Callable[[asyncio.Event, Request], Awaitable],
):
try:
print(f"Sleeping for {wait:.2f}")
await asyncio.sleep(wait)
print("Sleep not cancelled")
return f"I waited for {wait:.2f}s and now this is the result"
except asyncio.CancelledError:
print(MESSAGE_ON_HANDLER_CANCELLATION)
raise


class ServerInfo(NamedTuple):
url: str
proc: subprocess.Popen


@contextmanager
def server_lifetime(port: int) -> Iterator[ServerInfo]:
with subprocess.Popen(
[
"uvicorn",
f"{CURRENT_FILE.stem}:mock_app",
"--port",
f"{port}",
],
cwd=f"{CURRENT_DIR}",
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
) as proc:

url = f"http://127.0.0.1:{port}"
print("\nStarted", proc.args)

# some time to start
time.sleep(2)

# checks started successfully
assert proc.stdout
assert not proc.poll(), proc.stdout.read().decode("utf-8")
print("server is up and waiting for requests...")
yield ServerInfo(url, proc)
print("server is closing...")
proc.terminate()
print("server terminated")


def test_cancel_on_disconnect(get_unused_port: Callable[[], int]):

with server_lifetime(port=get_unused_port()) as server:
url, proc = server
print("--> testing server")
response = requests.get(f"{server.url}/example?wait=0", timeout=2)
print(response.url, "->", response.text)
response.raise_for_status()
print("<-- server responds")

print("--> testing server correctly cancels")
with pytest.raises(requests.exceptions.ReadTimeout):
response = requests.get(f"{server.url}/example?wait=2", timeout=0.5)
print("<-- testing server correctly cancels done")

print("--> testing server again")
# NOTE: the timeout here appears to be sensitive. if it is set <5 the test hangs from time to time
response = requests.get(f"{server.url}/example?wait=1", timeout=5)
print(response.url, "->", response.text)
response.raise_for_status()
print("<-- testing server again done")

# kill service
server.proc.terminate()
assert server.proc.stdout
server_log = server.proc.stdout.read().decode("utf-8")
print(
f"{server.url=} stdout",
"-" * 10,
"\n",
server_log,
"-" * 30,
)
# server.url=http://127.0.0.1:44077 stdout ----------
# Sleeping for 0.00
# Sleep not cancelled
# INFO: 127.0.0.1:35114 - "GET /example?wait=0 HTTP/1.1" 200 OK
# Sleeping for 2.00
# Exiting on cancellation
# Sleeping for 1.00
# Sleep not cancelled
# INFO: 127.0.0.1:35134 - "GET /example?wait=1 HTTP/1.1" 200 OK

assert MESSAGE_ON_HANDLER_CANCELLATION in server_log
"""
Tests that the decorator's wrapper waits for the poller task to finish
its cleanup, even if the handler finishes first, without needing a full server.
"""
long_running_poller_mock.called = False
handler_was_called = False

@cancel_on_disconnect
async def my_handler(request: Request):
nonlocal handler_was_called
handler_was_called = True
await asyncio.sleep(0.1) # Simulate quick work
return "Success"

# Mock a fastapi.Request object
mock_request = AsyncMock(spec=Request)
mock_request.is_disconnected.return_value = False

# ---
tasks_before = asyncio.all_tasks()

# Call the decorated handler
_ = await my_handler(mock_request)

tasks_after = asyncio.all_tasks()
# ---

assert handler_was_called
assert long_running_poller_mock.called == True

# Check that no background tasks were left orphaned
assert tasks_before == tasks_after
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,7 @@ async def upload_files(files: list[UploadFile] = FileParam(...)):
response_model=ClientFileUploadData,
responses=_FILE_STATUS_CODES,
)
@cancel_on_disconnect
async def get_upload_links(
request: Request,
client_file: UserFileToProgramJob | UserFile,
Expand Down Expand Up @@ -421,6 +422,7 @@ async def abort_multipart_upload(
response_model=OutputFile,
responses=_FILE_STATUS_CODES,
)
@cancel_on_disconnect
async def complete_multipart_upload(
request: Request,
file_id: UUID,
Expand Down
Loading