Skip to content

Commit 5044d53

Browse files
bdracoDreamsorcererpre-commit-ci[bot]
authored
[PR #9732/1e911ea backport][3.12] Add Client Middleware Support (#10879)
Co-authored-by: Sam Bull <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 6ea542e commit 5044d53

File tree

9 files changed

+1512
-28
lines changed

9 files changed

+1512
-28
lines changed

CHANGES/9732.feature.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
Added client middleware support -- by :user:`bdraco` and :user:`Dreamsorcerer`.
2+
3+
This change allows users to add middleware to the client session and requests, enabling features like
4+
authentication, logging, and request/response modification without modifying the core
5+
request logic. Additionally, the ``session`` attribute was added to ``ClientRequest``,
6+
allowing middleware to access the session for making additional requests.

aiohttp/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
WSServerHandshakeError,
4848
request,
4949
)
50+
from .client_middlewares import ClientHandlerType, ClientMiddlewareType
5051
from .compression_utils import set_zlib_backend
5152
from .connector import (
5253
AddrInfoType as AddrInfoType,
@@ -175,6 +176,9 @@
175176
"NamedPipeConnector",
176177
"WSServerHandshakeError",
177178
"request",
179+
# client_middleware
180+
"ClientMiddlewareType",
181+
"ClientHandlerType",
178182
# cookiejar
179183
"CookieJar",
180184
"DummyCookieJar",

aiohttp/client.py

Lines changed: 56 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@
7070
WSMessageTypeError,
7171
WSServerHandshakeError,
7272
)
73+
from .client_middlewares import ClientMiddlewareType, build_client_middlewares
7374
from .client_reqrep import (
7475
ClientRequest as ClientRequest,
7576
ClientResponse as ClientResponse,
@@ -191,6 +192,7 @@ class _RequestOptions(TypedDict, total=False):
191192
auto_decompress: Union[bool, None]
192193
max_line_size: Union[int, None]
193194
max_field_size: Union[int, None]
195+
middlewares: Optional[Tuple[ClientMiddlewareType, ...]]
194196

195197

196198
@attr.s(auto_attribs=True, frozen=True, slots=True)
@@ -258,6 +260,7 @@ class ClientSession:
258260
"_default_proxy",
259261
"_default_proxy_auth",
260262
"_retry_connection",
263+
"_middlewares",
261264
"requote_redirect_url",
262265
]
263266
)
@@ -298,6 +301,7 @@ def __init__(
298301
max_line_size: int = 8190,
299302
max_field_size: int = 8190,
300303
fallback_charset_resolver: _CharsetResolver = lambda r, b: "utf-8",
304+
middlewares: Optional[Tuple[ClientMiddlewareType, ...]] = None,
301305
) -> None:
302306
# We initialise _connector to None immediately, as it's referenced in __del__()
303307
# and could cause issues if an exception occurs during initialisation.
@@ -410,6 +414,7 @@ def __init__(
410414
self._default_proxy = proxy
411415
self._default_proxy_auth = proxy_auth
412416
self._retry_connection: bool = True
417+
self._middlewares = middlewares
413418

414419
def __init_subclass__(cls: Type["ClientSession"]) -> None:
415420
warnings.warn(
@@ -500,6 +505,7 @@ async def _request(
500505
auto_decompress: Optional[bool] = None,
501506
max_line_size: Optional[int] = None,
502507
max_field_size: Optional[int] = None,
508+
middlewares: Optional[Tuple[ClientMiddlewareType, ...]] = None,
503509
) -> ClientResponse:
504510

505511
# NOTE: timeout clamps existing connect and read timeouts. We cannot
@@ -699,32 +705,33 @@ async def _request(
699705
trust_env=self.trust_env,
700706
)
701707

702-
# connection timeout
703-
try:
704-
conn = await self._connector.connect(
705-
req, traces=traces, timeout=real_timeout
708+
# Core request handler - now includes connection logic
709+
async def _connect_and_send_request(
710+
req: ClientRequest,
711+
) -> ClientResponse:
712+
# connection timeout
713+
assert self._connector is not None
714+
try:
715+
conn = await self._connector.connect(
716+
req, traces=traces, timeout=real_timeout
717+
)
718+
except asyncio.TimeoutError as exc:
719+
raise ConnectionTimeoutError(
720+
f"Connection timeout to host {req.url}"
721+
) from exc
722+
723+
assert conn.protocol is not None
724+
conn.protocol.set_response_params(
725+
timer=timer,
726+
skip_payload=req.method in EMPTY_BODY_METHODS,
727+
read_until_eof=read_until_eof,
728+
auto_decompress=auto_decompress,
729+
read_timeout=real_timeout.sock_read,
730+
read_bufsize=read_bufsize,
731+
timeout_ceil_threshold=self._connector._timeout_ceil_threshold,
732+
max_line_size=max_line_size,
733+
max_field_size=max_field_size,
706734
)
707-
except asyncio.TimeoutError as exc:
708-
raise ConnectionTimeoutError(
709-
f"Connection timeout to host {url}"
710-
) from exc
711-
712-
assert conn.transport is not None
713-
714-
assert conn.protocol is not None
715-
conn.protocol.set_response_params(
716-
timer=timer,
717-
skip_payload=method in EMPTY_BODY_METHODS,
718-
read_until_eof=read_until_eof,
719-
auto_decompress=auto_decompress,
720-
read_timeout=real_timeout.sock_read,
721-
read_bufsize=read_bufsize,
722-
timeout_ceil_threshold=self._connector._timeout_ceil_threshold,
723-
max_line_size=max_line_size,
724-
max_field_size=max_field_size,
725-
)
726-
727-
try:
728735
try:
729736
resp = await req.send(conn)
730737
try:
@@ -735,6 +742,30 @@ async def _request(
735742
except BaseException:
736743
conn.close()
737744
raise
745+
return resp
746+
747+
# Apply middleware (if any) - per-request middleware overrides session middleware
748+
effective_middlewares = (
749+
self._middlewares if middlewares is None else middlewares
750+
)
751+
752+
if effective_middlewares:
753+
handler = build_client_middlewares(
754+
_connect_and_send_request, effective_middlewares
755+
)
756+
else:
757+
handler = _connect_and_send_request
758+
759+
try:
760+
resp = await handler(req)
761+
# Client connector errors should not be retried
762+
except (
763+
ConnectionTimeoutError,
764+
ClientConnectorError,
765+
ClientConnectorCertificateError,
766+
ClientConnectorSSLError,
767+
):
768+
raise
738769
except (ClientOSError, ServerDisconnectedError):
739770
if retry_persistent_connection:
740771
retry_persistent_connection = False

aiohttp/client_middlewares.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
"""Client middleware support."""
2+
3+
from collections.abc import Awaitable, Callable
4+
5+
from .client_reqrep import ClientRequest, ClientResponse
6+
7+
__all__ = ("ClientMiddlewareType", "ClientHandlerType", "build_client_middlewares")
8+
9+
# Type alias for client request handlers - functions that process requests and return responses
10+
ClientHandlerType = Callable[[ClientRequest], Awaitable[ClientResponse]]
11+
12+
# Type for client middleware - similar to server but uses ClientRequest/ClientResponse
13+
ClientMiddlewareType = Callable[
14+
[ClientRequest, ClientHandlerType], Awaitable[ClientResponse]
15+
]
16+
17+
18+
def build_client_middlewares(
19+
handler: ClientHandlerType,
20+
middlewares: tuple[ClientMiddlewareType, ...],
21+
) -> ClientHandlerType:
22+
"""
23+
Apply middlewares to request handler.
24+
25+
The middlewares are applied in reverse order, so the first middleware
26+
in the list wraps all subsequent middlewares and the handler.
27+
28+
This implementation avoids using partial/update_wrapper to minimize overhead
29+
and doesn't cache to avoid holding references to stateful middleware.
30+
"""
31+
if not middlewares:
32+
return handler
33+
34+
# Optimize for single middleware case
35+
if len(middlewares) == 1:
36+
middleware = middlewares[0]
37+
38+
async def single_middleware_handler(req: ClientRequest) -> ClientResponse:
39+
return await middleware(req, handler)
40+
41+
return single_middleware_handler
42+
43+
# Build the chain for multiple middlewares
44+
current_handler = handler
45+
46+
for middleware in reversed(middlewares):
47+
# Create a new closure that captures the current state
48+
def make_wrapper(
49+
mw: ClientMiddlewareType, next_h: ClientHandlerType
50+
) -> ClientHandlerType:
51+
async def wrapped(req: ClientRequest) -> ClientResponse:
52+
return await mw(req, next_h)
53+
54+
return wrapped
55+
56+
current_handler = make_wrapper(middleware, current_handler)
57+
58+
return current_handler

aiohttp/client_reqrep.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,13 @@ class ClientRequest:
272272
auth = None
273273
response = None
274274

275-
__writer = None # async task for streaming data
275+
__writer: Optional["asyncio.Task[None]"] = None # async task for streaming data
276+
277+
# These class defaults help create_autospec() work correctly.
278+
# If autospec is improved in future, maybe these can be removed.
279+
url = URL()
280+
method = "GET"
281+
276282
_continue = None # waiter future for '100 Continue' response
277283

278284
_skip_auto_headers: Optional["CIMultiDict[None]"] = None
@@ -427,6 +433,16 @@ def request_info(self) -> RequestInfo:
427433
RequestInfo, (self.url, self.method, headers, self.original_url)
428434
)
429435

436+
@property
437+
def session(self) -> "ClientSession":
438+
"""Return the ClientSession instance.
439+
440+
This property provides access to the ClientSession that initiated
441+
this request, allowing middleware to make additional requests
442+
using the same session.
443+
"""
444+
return self._session
445+
430446
def update_host(self, url: URL) -> None:
431447
"""Update destination host, port and connection type (ssl)."""
432448
# get host/port

0 commit comments

Comments
 (0)