Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
958f64f
readd cancel on disconnect decorator
bisgaard-itis Jun 30, 2025
176775f
add asyncio.Event for killing poller task
bisgaard-itis Jun 30, 2025
c376bab
add comment
bisgaard-itis Jun 30, 2025
8812a8a
add test for request handler `cancel_on_disconnect` @pcrespov
bisgaard-itis Jul 1, 2025
a80f4e1
use taskgroup for error handling
bisgaard-itis Jul 1, 2025
60e99d0
add old names to tasks
bisgaard-itis Jul 1, 2025
e6c42f7
use event to kill poller task to be sure it terminates
bisgaard-itis Jul 1, 2025
7311ef1
Merge branch 'master' into fix-hang-in-poller-task
bisgaard-itis Jul 1, 2025
e7c97de
Merge branch 'master' into fix-hang-in-poller-task
bisgaard-itis Jul 2, 2025
1229fa3
pylint
bisgaard-itis Jul 2, 2025
4cbc9dc
fix test
bisgaard-itis Jul 2, 2025
df8d9c7
Merge branch 'master' into fix-hang-in-poller-task
bisgaard-itis Jul 2, 2025
09a0f85
Merge branch 'master' into fix-hang-in-poller-task
bisgaard-itis Jul 2, 2025
321fa62
Merge branch 'master' into fix-hang-in-poller-task
bisgaard-itis Jul 4, 2025
e22808a
@pcrespov readd comments
bisgaard-itis Jul 4, 2025
5399294
readd old tests
bisgaard-itis Jul 4, 2025
97afe88
readd
bisgaard-itis Jul 4, 2025
373ad5c
factor out core funcionality into run_until_cancelled
bisgaard-itis Jul 4, 2025
ef3911e
fix tests
bisgaard-itis Jul 4, 2025
a626a29
ensure function decorator works on local deployment
bisgaard-itis Jul 4, 2025
edd826e
migrate middleware to new implementation
bisgaard-itis Jul 4, 2025
241a0e9
improve types
bisgaard-itis Jul 4, 2025
1c091a1
Merge branch 'master' into fix-hang-in-poller-task
bisgaard-itis Jul 4, 2025
431df54
fix request cancellation test
bisgaard-itis Jul 7, 2025
fe5b4ca
improve naming
bisgaard-itis Jul 7, 2025
62808a8
Revert "improve naming"
bisgaard-itis Jul 7, 2025
8c2e369
Merge branch 'master' into fix-hang-in-poller-task
bisgaard-itis Jul 7, 2025
6f6d755
Merge branch 'master' into fix-hang-in-poller-task
bisgaard-itis Jul 16, 2025
f0cfe6b
Merge branch 'master' into fix-hang-in-poller-task
bisgaard-itis Jul 17, 2025
b8e20d2
Merge branch 'master' into fix-hang-in-poller-task
bisgaard-itis Aug 7, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 65 additions & 1 deletion packages/service-library/src/servicelib/async_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
import logging
from collections import deque
from collections.abc import Awaitable, Callable
from collections.abc import Awaitable, Callable, Coroutine
from contextlib import suppress
from dataclasses import dataclass
from functools import wraps
Expand Down Expand Up @@ -210,3 +210,67 @@ async def worker(in_q: Queue[QueueElement], out_q: Queue) -> None:
return wrapper

return decorator


TaskCancelCallback = Callable[[], Awaitable[bool]]
_POLL_INTERVAL_S: float = 0.01


class TaskCancelled(Exception):
"""Internal exception raised by the poller task when the client disconnects."""


async def _poller_for_task_group(
close_event: asyncio.Event,
cancel_awaitable: TaskCancelCallback,
poll_interval: float,
):
"""
Polls for cancellation via the callback and raises TaskCancelled if it occurs.
"""
while not await cancel_awaitable():
await asyncio.sleep(poll_interval)
if close_event.is_set():
return
raise TaskCancelled


async def run_until_cancelled(
*,
coro: Coroutine,
cancel_callback: TaskCancelCallback,
poll_interval: float = _POLL_INTERVAL_S,
) -> Any:
"""
Runs the given coroutine until it completes or cancellation is requested.

This function executes the provided coroutine and periodically checks the given
cancel_callback. If cancellation is requested (i.e., cancel_callback returns True),
the coroutine is cancelled and a TaskCancelled exception is raised. If the coroutine
completes first, its result (or exception) is returned/reraised.
"""
sentinel = object()
close_poller_event = asyncio.Event()
try:
# Create two tasks in a TaskGroup
async with asyncio.TaskGroup() as tg:

# One to poll for cancellation
tg.create_task(
_poller_for_task_group(
close_poller_event, cancel_callback, poll_interval
),
name=f"run_until_cancelled/poller/{coro.__name__}/{id(sentinel)}",
)
# The other to run the actual coroutine
coro_task = tg.create_task(
coro,
name=f"run_until_cancelled/coroutine/{coro.__name__}/{id(sentinel)}",
)
await coro_task
close_poller_event.set()

return coro_task.result()

except* Exception as eg:
raise eg.exceptions[0] # pylint: disable=no-member
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import asyncio
import logging
from typing import NoReturn
from functools import partial

from servicelib.async_utils import TaskCancelled, run_until_cancelled
from starlette.requests import Request
from starlette.types import ASGIApp, Message, Receive, Scope, Send

Expand All @@ -10,29 +11,23 @@
_logger = logging.getLogger(__name__)


class _ClientDisconnectedError(Exception):
pass


async def _message_poller(
request: Request, queue: asyncio.Queue, receive: Receive
) -> NoReturn:
while True:
message = await receive()
if message["type"] == "http.disconnect":
_logger.debug(
"client disconnected the request to %s!", request.url, stacklevel=2
)
raise _ClientDisconnectedError

# Puts the message in the queue
await queue.put(message)


async def _handler(
app: ASGIApp, scope: Scope, queue: asyncio.Queue[Message], send: Send
) -> None:
return await app(scope, queue.get, send)
await app(scope, queue.get, send)


async def _is_client_disconnected(
receive: Receive, queue: asyncio.Queue[Message], request: Request
) -> bool:
message = await receive()
if message["type"] == "http.disconnect":
_logger.debug("client disconnected, terminating request to %s!", request.url)
return True

# Puts the message in the queue
await queue.put(message)
return False


class RequestCancellationMiddleware:
Expand All @@ -58,23 +53,22 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
return

# Let's make a shared queue for the request messages
queue: asyncio.Queue[Message] = asyncio.Queue()
request = Request(scope)
queue: asyncio.Queue[Message] = asyncio.Queue()

with log_context(_logger, logging.DEBUG, f"cancellable request {request.url}"):
try:
async with asyncio.TaskGroup() as tg:
handler_task = tg.create_task(
_handler(self.app, scope, queue, send)
)
poller_task = tg.create_task(
_message_poller(request, queue, receive)
)
await handler_task
poller_task.cancel()
except* _ClientDisconnectedError:
if not handler_task.done():
_logger.info(
"The client disconnected. The request to %s was cancelled.",
request.url,
)
await run_until_cancelled(
coro=_handler(self.app, scope, queue, send),
cancel_callback=partial(
_is_client_disconnected, receive, queue, request
),
poll_interval=0.0,
)
return

except TaskCancelled:
_logger.info(
"The client disconnected. request to %s was cancelled.",
request.url,
)
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
import asyncio
import inspect
import logging
from functools import wraps
from functools import partial, wraps
from typing import Any, Protocol

from fastapi import Request, status
from fastapi.exceptions import HTTPException

from ..async_utils import TaskCancelled, run_until_cancelled

logger = logging.getLogger(__name__)


class _HandlerWithRequestArg(Protocol):
__name__: str

async def __call__(self, request: Request, *args: Any, **kwargs: Any) -> Any:
...
async def __call__(self, request: Request, *args: Any, **kwargs: Any) -> Any: ...


def _validate_signature(handler: _HandlerWithRequestArg):
Expand All @@ -30,89 +30,46 @@ def _validate_signature(handler: _HandlerWithRequestArg):


#
# cancel_on_disconnect/disconnect_poller based
# on https://github.com/RedRoserade/fastapi-disconnect-example/blob/main/app.py
# cancel_on_disconnect based on TaskGroup
#
_POLL_INTERVAL_S: float = 0.01


async def _disconnect_poller(request: Request, result: Any):
"""
Poll for a disconnect.
If the request disconnects, stop polling and return.
"""
while not await request.is_disconnected():
await asyncio.sleep(_POLL_INTERVAL_S)
return result
async def _is_client_disconnected(request: Request):
return await request.is_disconnected()


def cancel_on_disconnect(handler: _HandlerWithRequestArg):
"""
After client disconnects, handler gets cancelled in ~<3 secs
Decorator that cancels the request handler if the client disconnects.

Uses a TaskGroup to manage the handler and a poller task concurrently.
If the client disconnects, the poller raises an exception, which is
caught and translated into a 503 Service Unavailable response.
"""

_validate_signature(handler)

@wraps(handler)
async def wrapper(request: Request, *args, **kwargs):
sentinel = object()

# Create two tasks:
# one to poll the request and check if the client disconnected
poller_task = asyncio.create_task(
_disconnect_poller(request, sentinel),
name=f"cancel_on_disconnect/poller/{handler.__name__}/{id(sentinel)}",
)
# , and another which is the request handler
handler_task = asyncio.create_task(
handler(request, *args, **kwargs),
name=f"cancel_on_disconnect/handler/{handler.__name__}/{id(sentinel)}",
)

done, pending = await asyncio.wait(
[poller_task, handler_task], return_when=asyncio.FIRST_COMPLETED
)

# One has completed, cancel the other
for t in pending:
t.cancel()

try:
await asyncio.wait_for(t, timeout=3)

except asyncio.CancelledError:
pass
except Exception: # pylint: disable=broad-except
if t is handler_task:
raise
finally:
assert t.done() # nosec

# Return the result if the handler finished first
if handler_task in done:
assert poller_task.done() # nosec
return await handler_task

# Otherwise, raise an exception. This is not exactly needed,
# but it will prevent validation errors if your request handler
# is supposed to return something.
logger.warning(
"Request %s %s cancelled since client %s disconnected:\n - %s\n - %s",
request.method,
request.url,
request.client,
f"{poller_task=}",
f"{handler_task=}",
)

assert poller_task.done() # nosec
assert handler_task.done() # nosec

# NOTE: uvicorn server fails with 499
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail=f"client disconnected from {request=}",
)

try:
return await run_until_cancelled(
coro=handler(request, *args, **kwargs),
cancel_callback=partial(_is_client_disconnected, request),
)

except TaskCancelled as exc:

logger.info(
"Request %s %s cancelled since client %s disconnected.",
request.method,
request.url,
request.client,
)
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Client disconnected",
) from exc

return wrapper

Expand Down
Loading
Loading