Skip to content

Commit 1e911ea

Browse files
bdracoDreamsorcererpre-commit-ci[bot]
authored
Add Client Middleware Support (#9732)
Co-authored-by: Sam Bull <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent c5405bc commit 1e911ea

File tree

9 files changed

+1507
-25
lines changed

9 files changed

+1507
-25
lines changed

CHANGES/9732.feature.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
Added client middleware support -- by :user:`bdraco` and :user:`Dreamsorcerer`.
2+
3+
This change allows users to add middleware to the client session and requests, enabling features like
4+
authentication, logging, and request/response modification without modifying the core
5+
request logic. Additionally, the ``session`` attribute was added to ``ClientRequest``,
6+
allowing middleware to access the session for making additional requests.

aiohttp/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
WSServerHandshakeError,
4848
request,
4949
)
50+
from .client_middlewares import ClientHandlerType, ClientMiddlewareType
5051
from .compression_utils import set_zlib_backend
5152
from .connector import AddrInfoType, SocketFactoryType
5253
from .cookiejar import CookieJar, DummyCookieJar
@@ -157,6 +158,9 @@
157158
"NamedPipeConnector",
158159
"WSServerHandshakeError",
159160
"request",
161+
# client_middleware
162+
"ClientMiddlewareType",
163+
"ClientHandlerType",
160164
# cookiejar
161165
"CookieJar",
162166
"DummyCookieJar",

aiohttp/client.py

Lines changed: 56 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@
7272
WSMessageTypeError,
7373
WSServerHandshakeError,
7474
)
75+
from .client_middlewares import ClientMiddlewareType, build_client_middlewares
7576
from .client_reqrep import (
7677
SSL_ALLOWED_TYPES,
7778
ClientRequest,
@@ -193,6 +194,7 @@ class _RequestOptions(TypedDict, total=False):
193194
auto_decompress: Union[bool, None]
194195
max_line_size: Union[int, None]
195196
max_field_size: Union[int, None]
197+
middlewares: Optional[Tuple[ClientMiddlewareType, ...]]
196198

197199

198200
@frozen_dataclass_decorator
@@ -260,6 +262,7 @@ class ClientSession:
260262
"_default_proxy",
261263
"_default_proxy_auth",
262264
"_retry_connection",
265+
"_middlewares",
263266
)
264267

265268
def __init__(
@@ -292,6 +295,7 @@ def __init__(
292295
max_line_size: int = 8190,
293296
max_field_size: int = 8190,
294297
fallback_charset_resolver: _CharsetResolver = lambda r, b: "utf-8",
298+
middlewares: Optional[Tuple[ClientMiddlewareType, ...]] = None,
295299
) -> None:
296300
# We initialise _connector to None immediately, as it's referenced in __del__()
297301
# and could cause issues if an exception occurs during initialisation.
@@ -376,6 +380,7 @@ def __init__(
376380
self._default_proxy = proxy
377381
self._default_proxy_auth = proxy_auth
378382
self._retry_connection: bool = True
383+
self._middlewares = middlewares
379384

380385
def __init_subclass__(cls: Type["ClientSession"]) -> None:
381386
raise TypeError(
@@ -450,6 +455,7 @@ async def _request(
450455
auto_decompress: Optional[bool] = None,
451456
max_line_size: Optional[int] = None,
452457
max_field_size: Optional[int] = None,
458+
middlewares: Optional[Tuple[ClientMiddlewareType, ...]] = None,
453459
) -> ClientResponse:
454460
# NOTE: timeout clamps existing connect and read timeouts. We cannot
455461
# set the default to None because we need to detect if the user wants
@@ -642,32 +648,33 @@ async def _request(
642648
trust_env=self.trust_env,
643649
)
644650

645-
# connection timeout
646-
try:
647-
conn = await self._connector.connect(
648-
req, traces=traces, timeout=real_timeout
651+
# Core request handler - now includes connection logic
652+
async def _connect_and_send_request(
653+
req: ClientRequest,
654+
) -> ClientResponse:
655+
# connection timeout
656+
assert self._connector is not None
657+
try:
658+
conn = await self._connector.connect(
659+
req, traces=traces, timeout=real_timeout
660+
)
661+
except asyncio.TimeoutError as exc:
662+
raise ConnectionTimeoutError(
663+
f"Connection timeout to host {req.url}"
664+
) from exc
665+
666+
assert conn.protocol is not None
667+
conn.protocol.set_response_params(
668+
timer=timer,
669+
skip_payload=req.method in EMPTY_BODY_METHODS,
670+
read_until_eof=read_until_eof,
671+
auto_decompress=auto_decompress,
672+
read_timeout=real_timeout.sock_read,
673+
read_bufsize=read_bufsize,
674+
timeout_ceil_threshold=self._connector._timeout_ceil_threshold,
675+
max_line_size=max_line_size,
676+
max_field_size=max_field_size,
649677
)
650-
except asyncio.TimeoutError as exc:
651-
raise ConnectionTimeoutError(
652-
f"Connection timeout to host {url}"
653-
) from exc
654-
655-
assert conn.transport is not None
656-
657-
assert conn.protocol is not None
658-
conn.protocol.set_response_params(
659-
timer=timer,
660-
skip_payload=method in EMPTY_BODY_METHODS,
661-
read_until_eof=read_until_eof,
662-
auto_decompress=auto_decompress,
663-
read_timeout=real_timeout.sock_read,
664-
read_bufsize=read_bufsize,
665-
timeout_ceil_threshold=self._connector._timeout_ceil_threshold,
666-
max_line_size=max_line_size,
667-
max_field_size=max_field_size,
668-
)
669-
670-
try:
671678
try:
672679
resp = await req.send(conn)
673680
try:
@@ -678,6 +685,30 @@ async def _request(
678685
except BaseException:
679686
conn.close()
680687
raise
688+
return resp
689+
690+
# Apply middleware (if any) - per-request middleware overrides session middleware
691+
effective_middlewares = (
692+
self._middlewares if middlewares is None else middlewares
693+
)
694+
695+
if effective_middlewares:
696+
handler = build_client_middlewares(
697+
_connect_and_send_request, effective_middlewares
698+
)
699+
else:
700+
handler = _connect_and_send_request
701+
702+
try:
703+
resp = await handler(req)
704+
# Client connector errors should not be retried
705+
except (
706+
ConnectionTimeoutError,
707+
ClientConnectorError,
708+
ClientConnectorCertificateError,
709+
ClientConnectorSSLError,
710+
):
711+
raise
681712
except (ClientOSError, ServerDisconnectedError):
682713
if retry_persistent_connection:
683714
retry_persistent_connection = False

aiohttp/client_middlewares.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
"""Client middleware support."""
2+
3+
from collections.abc import Awaitable, Callable
4+
5+
from .client_reqrep import ClientRequest, ClientResponse
6+
7+
__all__ = ("ClientMiddlewareType", "ClientHandlerType", "build_client_middlewares")
8+
9+
# Type alias for client request handlers - functions that process requests and return responses
10+
ClientHandlerType = Callable[[ClientRequest], Awaitable[ClientResponse]]
11+
12+
# Type for client middleware - similar to server but uses ClientRequest/ClientResponse
13+
ClientMiddlewareType = Callable[
14+
[ClientRequest, ClientHandlerType], Awaitable[ClientResponse]
15+
]
16+
17+
18+
def build_client_middlewares(
19+
handler: ClientHandlerType,
20+
middlewares: tuple[ClientMiddlewareType, ...],
21+
) -> ClientHandlerType:
22+
"""
23+
Apply middlewares to request handler.
24+
25+
The middlewares are applied in reverse order, so the first middleware
26+
in the list wraps all subsequent middlewares and the handler.
27+
28+
This implementation avoids using partial/update_wrapper to minimize overhead
29+
and doesn't cache to avoid holding references to stateful middleware.
30+
"""
31+
if not middlewares:
32+
return handler
33+
34+
# Optimize for single middleware case
35+
if len(middlewares) == 1:
36+
middleware = middlewares[0]
37+
38+
async def single_middleware_handler(req: ClientRequest) -> ClientResponse:
39+
return await middleware(req, handler)
40+
41+
return single_middleware_handler
42+
43+
# Build the chain for multiple middlewares
44+
current_handler = handler
45+
46+
for middleware in reversed(middlewares):
47+
# Create a new closure that captures the current state
48+
def make_wrapper(
49+
mw: ClientMiddlewareType, next_h: ClientHandlerType
50+
) -> ClientHandlerType:
51+
async def wrapped(req: ClientRequest) -> ClientResponse:
52+
return await mw(req, next_h)
53+
54+
return wrapped
55+
56+
current_handler = make_wrapper(middleware, current_handler)
57+
58+
return current_handler

aiohttp/client_reqrep.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,11 @@ class ClientRequest:
210210
auth = None
211211
response = None
212212

213+
# These class defaults help create_autospec() work correctly.
214+
# If autospec is improved in future, maybe these can be removed.
215+
url = URL()
216+
method = "GET"
217+
213218
__writer: Optional["asyncio.Task[None]"] = None # async task for streaming data
214219
_continue = None # waiter future for '100 Continue' response
215220

@@ -362,6 +367,16 @@ def request_info(self) -> RequestInfo:
362367
RequestInfo, (self.url, self.method, headers, self.original_url)
363368
)
364369

370+
@property
371+
def session(self) -> "ClientSession":
372+
"""Return the ClientSession instance.
373+
374+
This property provides access to the ClientSession that initiated
375+
this request, allowing middleware to make additional requests
376+
using the same session.
377+
"""
378+
return self._session
379+
365380
def update_host(self, url: URL) -> None:
366381
"""Update destination host, port and connection type (ssl)."""
367382
# get host/port

0 commit comments

Comments
 (0)