Skip to content

Commit b431197

Browse files
Use aiohttp middleware for mocking (#170)
This solves a host of issues, including #166, as well as that middlwares would not have previously worked. It might also solve #152 but I have not tested that yet.
1 parent 08fe427 commit b431197

File tree

3 files changed

+64
-118
lines changed

3 files changed

+64
-118
lines changed

src/pook/interceptors/aiohttp.py

Lines changed: 48 additions & 116 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
1-
import asyncio
21
from http.client import responses as http_reasons
3-
from typing import Callable, Optional
2+
from typing import Optional
43
from unittest import mock
5-
from urllib.parse import urlencode, urlunparse
6-
from collections.abc import Mapping
4+
from urllib.parse import urlunparse
75

86
import aiohttp
97
from aiohttp.helpers import TimerNoop
@@ -16,111 +14,29 @@
1614
import multidict
1715
import yarl
1816

19-
PATCHES = ("aiohttp.client.ClientSession._request",)
20-
2117
RESPONSE_CLASS = "ClientResponse"
2218
RESPONSE_PATH = "aiohttp.client_reqrep"
2319

2420

25-
class SimpleContent(EmptyStreamReader):
26-
def __init__(self, content, *args, **kwargs):
27-
super().__init__(*args, **kwargs)
28-
self.content = content
29-
30-
async def read(self, n=-1):
31-
return self.content
32-
33-
34-
def HTTPResponse(session: aiohttp.ClientSession, *args, **kw):
35-
return session._response_class(
36-
*args,
37-
request_info=mock.Mock(),
38-
writer=None,
39-
continue100=None,
40-
timer=TimerNoop(),
41-
traces=[],
42-
loop=mock.Mock(),
43-
session=mock.Mock(),
44-
**kw,
45-
)
46-
47-
4821
class AIOHTTPInterceptor(BaseInterceptor):
49-
"""
50-
aiohttp HTTP client traffic interceptor.
51-
"""
52-
53-
def _url(self, url) -> Optional[yarl.URL]:
54-
return yarl.URL(url) if yarl else None
55-
56-
def set_headers(self, req, headers) -> None:
57-
# aiohttp's interface allows various mappings, as well as an iterable of key/value tuples
58-
# ``pook.request`` only allows a dict, so we need to map the iterable to the matchable interface
59-
if headers:
60-
if isinstance(headers, Mapping):
61-
req.headers.update(**headers)
62-
else:
63-
# If it isn't a mapping, then its an Iterable[Tuple[Union[str, istr], str]]
64-
for req_header, req_header_value in headers:
65-
normalised_header = req_header.lower()
66-
if normalised_header in req.headers:
67-
req.headers[normalised_header] += f", {req_header_value}"
68-
else:
69-
req.headers[normalised_header] = req_header_value
70-
71-
async def _on_request(
72-
self,
73-
_request: Callable,
74-
session: aiohttp.ClientSession,
75-
method: str,
76-
url: str,
77-
data=None,
78-
headers=None,
79-
**kw,
22+
# Implements aiohttp.ClientMiddlewareType
23+
async def __call__(
24+
self, request: aiohttp.ClientRequest, handler: aiohttp.ClientHandlerType
8025
) -> aiohttp.ClientResponse:
81-
# Create request contract based on incoming params
82-
req = Request(method)
83-
84-
self.set_headers(req, headers)
85-
self.set_headers(req, session.headers)
86-
87-
req.body = data
88-
89-
# Expose extra variadic arguments
90-
req.extra = kw
91-
92-
full_url = session._build_url(url)
26+
req = Request(
27+
method=request.method,
28+
headers=request.headers.items(),
29+
body=request.body,
30+
url=str(request.url),
31+
)
9332

94-
# Compose URL
95-
if not kw.get("params"):
96-
req.url = str(full_url)
97-
else:
98-
# Transform params as a list of tuple
99-
params = kw["params"]
100-
if isinstance(params, dict):
101-
params = [(x, y) for x, y in kw["params"].items()]
102-
req.url = str(full_url) + "?" + urlencode(params)
103-
104-
# If a json payload is provided, serialize it for JSONMatcher support
105-
if json_body := kw.get("json"):
106-
req.json = json_body
107-
if "Content-Type" not in req.headers:
108-
req.headers["Content-Type"] = "application/json"
109-
110-
# Match the request against the registered mocks in pook
11133
mock = self.engine.match(req)
11234

11335
# If cannot match any mock, run real HTTP request if networking
11436
# or silent model are enabled, otherwise this statement won't
11537
# be reached (an exception will be raised before).
11638
if not mock:
117-
return await _request(
118-
session, method, url, data=data, headers=headers, **kw
119-
)
120-
121-
# Simulate network delay
122-
if mock._delay:
123-
await asyncio.sleep(mock._delay / 1000) # noqa
39+
return await handler(request)
12440

12541
# Shortcut to mock response
12642
res = mock._response
@@ -131,7 +47,7 @@ async def _on_request(
13147
headers.append((key, res._headers[key]))
13248

13349
# Create mock equivalent HTTP response
134-
_res = HTTPResponse(session, req.method, self._url(urlunparse(req.url)))
50+
_res = HTTPResponse(request.session, req.method, self._url(urlunparse(req.url)))
13551

13652
# response status
13753
_res.version = aiohttp.HttpVersion(1, 1)
@@ -154,23 +70,24 @@ async def _on_request(
15470
# Return response based on mock definition
15571
return _res
15672

157-
def _patch(self, path: str) -> None:
73+
def _url(self, url) -> Optional[yarl.URL]:
74+
return yarl.URL(url) if yarl else None
75+
76+
def activate(self) -> None:
15877
# If not able to import aiohttp dependencies, skip
15978
if not yarl or not multidict:
16079
return None
16180

162-
async def handler(session, method, url, data=None, headers=None, **kw):
163-
return await self._on_request(
164-
_request, session, method, url, data=data, headers=headers, **kw
165-
)
81+
def _request(session, *args, **kwargs):
82+
request_middlewares = kwargs.get("middlewares", ())
83+
kwargs["middlewares"] = request_middlewares + (self,)
84+
return super_request(session, *args, **kwargs)
16685

16786
try:
168-
# Create a new patcher for Urllib3 urlopen function
169-
# used as entry point for all the HTTP communications
170-
patcher = mock.patch(path, handler)
171-
# Retrieve original patched function that we might need for real
172-
# networking
173-
_request = patcher.get_original()[0]
87+
# Patch ClientSession init to append this interceptor as an aiohttp
88+
# middleware to all session's middlewares
89+
patcher = mock.patch("aiohttp.client.ClientSession._request", _request)
90+
super_request = patcher.get_original()[0]
17491
# Start patching function calls
17592
patcher.start()
17693
except Exception:
@@ -180,18 +97,33 @@ async def handler(session, method, url, data=None, headers=None, **kw):
18097
else:
18198
self.patchers.append(patcher)
18299

183-
def activate(self) -> None:
184-
"""
185-
Activates the traffic interceptor.
186-
This method must be implemented by any interceptor.
187-
"""
188-
for path in PATCHES:
189-
self._patch(path)
190-
191100
def disable(self) -> None:
192101
"""
193102
Disables the traffic interceptor.
194103
This method must be implemented by any interceptor.
195104
"""
196105
for patch in self.patchers:
197106
patch.stop()
107+
108+
109+
class SimpleContent(EmptyStreamReader):
110+
def __init__(self, content, *args, **kwargs):
111+
super().__init__(*args, **kwargs)
112+
self.content = content
113+
114+
async def read(self, n=-1):
115+
return self.content
116+
117+
118+
def HTTPResponse(session: aiohttp.ClientSession, *args, **kw):
119+
return session._response_class(
120+
*args,
121+
request_info=mock.Mock(),
122+
writer=None,
123+
continue100=None,
124+
timer=TimerNoop(),
125+
traces=[],
126+
loop=mock.Mock(),
127+
session=mock.Mock(),
128+
**kw,
129+
)

src/pook/request.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,6 @@ def headers(self):
5454

5555
@headers.setter
5656
def headers(self, headers):
57-
if not hasattr(headers, "__setitem__"):
58-
raise TypeError("headers must be a dictionary")
5957
self._headers.extend(headers)
6058

6159
@property

tests/unit/interceptors/aiohttp_test.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,22 @@ async def test_client_headers_merged(local_responder):
9797
assert await res.read() == b"hello from pook"
9898

9999

100+
@pytest.mark.asyncio
101+
async def test_client_auth_merged(local_responder):
102+
"""Auth headers set on the client should be matched"""
103+
pook.get(local_responder + "/status/404").header(
104+
"Authorization", "Basic dXNlcjpwYXNzd29yZA=="
105+
).reply(200).body("hello from pook")
106+
async with aiohttp.ClientSession(
107+
auth=aiohttp.BasicAuth("user", "password")
108+
) as session:
109+
res = await session.get(
110+
local_responder + "/status/404", headers={"x-pook-secondary": "xyz"}
111+
)
112+
assert res.status == 200
113+
assert await res.read() == b"hello from pook"
114+
115+
100116
@pytest.mark.asyncio
101117
async def test_client_headers_both_session_and_request(local_responder):
102118
"""Headers should be matchable from both the session and request in the same matcher"""

0 commit comments

Comments
 (0)