Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
164 changes: 48 additions & 116 deletions src/pook/interceptors/aiohttp.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
import asyncio
from http.client import responses as http_reasons
from typing import Callable, Optional
from typing import Optional
from unittest import mock
from urllib.parse import urlencode, urlunparse
from collections.abc import Mapping
from urllib.parse import urlunparse

import aiohttp
from aiohttp.helpers import TimerNoop
Expand All @@ -16,111 +14,29 @@
import multidict
import yarl

PATCHES = ("aiohttp.client.ClientSession._request",)

RESPONSE_CLASS = "ClientResponse"
RESPONSE_PATH = "aiohttp.client_reqrep"


class SimpleContent(EmptyStreamReader):
def __init__(self, content, *args, **kwargs):
super().__init__(*args, **kwargs)
self.content = content

async def read(self, n=-1):
return self.content


def HTTPResponse(session: aiohttp.ClientSession, *args, **kw):
return session._response_class(
*args,
request_info=mock.Mock(),
writer=None,
continue100=None,
timer=TimerNoop(),
traces=[],
loop=mock.Mock(),
session=mock.Mock(),
**kw,
)


class AIOHTTPInterceptor(BaseInterceptor):
"""
aiohttp HTTP client traffic interceptor.
"""

def _url(self, url) -> Optional[yarl.URL]:
return yarl.URL(url) if yarl else None

def set_headers(self, req, headers) -> None:
# aiohttp's interface allows various mappings, as well as an iterable of key/value tuples
# ``pook.request`` only allows a dict, so we need to map the iterable to the matchable interface
if headers:
if isinstance(headers, Mapping):
req.headers.update(**headers)
else:
# If it isn't a mapping, then its an Iterable[Tuple[Union[str, istr], str]]
for req_header, req_header_value in headers:
normalised_header = req_header.lower()
if normalised_header in req.headers:
req.headers[normalised_header] += f", {req_header_value}"
else:
req.headers[normalised_header] = req_header_value

async def _on_request(
self,
_request: Callable,
session: aiohttp.ClientSession,
method: str,
url: str,
data=None,
headers=None,
**kw,
# Implements aiohttp.ClientMiddlewareType
async def __call__(
self, request: aiohttp.ClientRequest, handler: aiohttp.ClientHandlerType
) -> aiohttp.ClientResponse:
# Create request contract based on incoming params
req = Request(method)

self.set_headers(req, headers)
self.set_headers(req, session.headers)

req.body = data

# Expose extra variadic arguments
req.extra = kw

full_url = session._build_url(url)
req = Request(
method=request.method,
headers=request.headers.items(),
body=request.body,
url=str(request.url),
)

# Compose URL
if not kw.get("params"):
req.url = str(full_url)
else:
# Transform params as a list of tuple
params = kw["params"]
if isinstance(params, dict):
params = [(x, y) for x, y in kw["params"].items()]
req.url = str(full_url) + "?" + urlencode(params)

# If a json payload is provided, serialize it for JSONMatcher support
if json_body := kw.get("json"):
req.json = json_body
if "Content-Type" not in req.headers:
req.headers["Content-Type"] = "application/json"

# Match the request against the registered mocks in pook
mock = self.engine.match(req)

# If cannot match any mock, run real HTTP request if networking
# or silent model are enabled, otherwise this statement won't
# be reached (an exception will be raised before).
if not mock:
return await _request(
session, method, url, data=data, headers=headers, **kw
)

# Simulate network delay
Copy link
Copy Markdown

@hovsater hovsater Mar 18, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did we accidentally break support for network delay simulation? I can't see any equivalent in the new implementation.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in #178.

if mock._delay:
await asyncio.sleep(mock._delay / 1000) # noqa
return await handler(request)

# Shortcut to mock response
res = mock._response
Expand All @@ -131,7 +47,7 @@ async def _on_request(
headers.append((key, res._headers[key]))

# Create mock equivalent HTTP response
_res = HTTPResponse(session, req.method, self._url(urlunparse(req.url)))
_res = HTTPResponse(request.session, req.method, self._url(urlunparse(req.url)))

# response status
_res.version = aiohttp.HttpVersion(1, 1)
Expand All @@ -154,23 +70,24 @@ async def _on_request(
# Return response based on mock definition
return _res

def _patch(self, path: str) -> None:
def _url(self, url) -> Optional[yarl.URL]:
return yarl.URL(url) if yarl else None

def activate(self) -> None:
# If not able to import aiohttp dependencies, skip
if not yarl or not multidict:
return None

async def handler(session, method, url, data=None, headers=None, **kw):
return await self._on_request(
_request, session, method, url, data=data, headers=headers, **kw
)
def _request(session, *args, **kwargs):
request_middlewares = kwargs.get("middlewares", ())
kwargs["middlewares"] = request_middlewares + (self,)
return super_request(session, *args, **kwargs)

try:
# Create a new patcher for Urllib3 urlopen function
# used as entry point for all the HTTP communications
patcher = mock.patch(path, handler)
# Retrieve original patched function that we might need for real
# networking
_request = patcher.get_original()[0]
# Patch ClientSession init to append this interceptor as an aiohttp
# middleware to all session's middlewares
patcher = mock.patch("aiohttp.client.ClientSession._request", _request)
super_request = patcher.get_original()[0]
# Start patching function calls
patcher.start()
except Exception:
Expand All @@ -180,18 +97,33 @@ async def handler(session, method, url, data=None, headers=None, **kw):
else:
self.patchers.append(patcher)

def activate(self) -> None:
"""
Activates the traffic interceptor.
This method must be implemented by any interceptor.
"""
for path in PATCHES:
self._patch(path)

def disable(self) -> None:
"""
Disables the traffic interceptor.
This method must be implemented by any interceptor.
"""
for patch in self.patchers:
patch.stop()


class SimpleContent(EmptyStreamReader):
def __init__(self, content, *args, **kwargs):
super().__init__(*args, **kwargs)
self.content = content

async def read(self, n=-1):
return self.content


def HTTPResponse(session: aiohttp.ClientSession, *args, **kw):
return session._response_class(
*args,
request_info=mock.Mock(),
writer=None,
continue100=None,
timer=TimerNoop(),
traces=[],
loop=mock.Mock(),
session=mock.Mock(),
**kw,
)
2 changes: 0 additions & 2 deletions src/pook/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,6 @@ def headers(self):

@headers.setter
def headers(self, headers):
if not hasattr(headers, "__setitem__"):
raise TypeError("headers must be a dictionary")
Comment on lines -57 to -58
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This check isn't necessary, or even accurate. The underlying HTTPHeaderDict class's extend method handles a host of different types of inputs and should be the source of any relevant TypeErrors anyway.

self._headers.extend(headers)

@property
Expand Down
16 changes: 16 additions & 0 deletions tests/unit/interceptors/aiohttp_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,22 @@ async def test_client_headers_merged(local_responder):
assert await res.read() == b"hello from pook"


@pytest.mark.asyncio
async def test_client_auth_merged(local_responder):
"""Auth headers set on the client should be matched"""
pook.get(local_responder + "/status/404").header(
"Authorization", "Basic dXNlcjpwYXNzd29yZA=="
).reply(200).body("hello from pook")
async with aiohttp.ClientSession(
auth=aiohttp.BasicAuth("user", "password")
) as session:
res = await session.get(
local_responder + "/status/404", headers={"x-pook-secondary": "xyz"}
)
assert res.status == 200
assert await res.read() == b"hello from pook"


@pytest.mark.asyncio
async def test_client_headers_both_session_and_request(local_responder):
"""Headers should be matchable from both the session and request in the same matcher"""
Expand Down