Skip to content

Commit cc82682

Browse files
committed
mypy
1 parent 9548404 commit cc82682

File tree

1 file changed

+13
-8
lines changed

1 file changed

+13
-8
lines changed

packages/service-library/src/servicelib/fastapi/cancellation_middleware.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
import asyncio
22
import logging
3+
from typing import NoReturn
34

4-
from servicelib.logging_utils import log_context
55
from starlette.requests import Request
6-
from starlette.types import ASGIApp, Receive, Scope, Send
6+
from starlette.types import ASGIApp, Message, Receive, Scope, Send
7+
8+
from .logging_utils import log_context
79

810
_logger = logging.getLogger(__name__)
911

@@ -12,7 +14,9 @@ class _TerminateTaskGroupError(Exception):
1214
pass
1315

1416

15-
async def _message_poller(request: Request, queue: asyncio.Queue, receive: Receive):
17+
async def _message_poller(
18+
request: Request, queue: asyncio.Queue, receive: Receive
19+
) -> NoReturn:
1620
while True:
1721
message = await receive()
1822
if message["type"] == "http.disconnect":
@@ -23,7 +27,9 @@ async def _message_poller(request: Request, queue: asyncio.Queue, receive: Recei
2327
await queue.put(message)
2428

2529

26-
async def _handler(app: ASGIApp, scope: Scope, queue: asyncio.Queue, send: Send):
30+
async def _handler(
31+
app: ASGIApp, scope: Scope, queue: asyncio.Queue[Message], send: Send
32+
) -> None:
2733
return await app(scope, queue.get, send)
2834

2935

@@ -47,10 +53,10 @@ def __init__(self, app: ASGIApp) -> None:
4753
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
4854
if scope["type"] != "http":
4955
await self.app(scope, receive, send)
50-
return None
56+
return
5157

5258
# Let's make a shared queue for the request messages
53-
queue = asyncio.Queue()
59+
queue: asyncio.Queue[Message] = asyncio.Queue()
5460

5561
request = Request(scope)
5662

@@ -63,9 +69,8 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
6369
poller_task = tg.create_task(
6470
_message_poller(request, queue, receive)
6571
)
66-
response = await handler_task
72+
await handler_task
6773
poller_task.cancel()
68-
return response
6974
except* _TerminateTaskGroupError:
7075
_logger.info(
7176
"The client disconnected. request to %s was cancelled.", request.url

0 commit comments

Comments
 (0)