diff --git a/src/pook/interceptors/aiohttp.py b/src/pook/interceptors/aiohttp.py index 5646513..be5a9ac 100644 --- a/src/pook/interceptors/aiohttp.py +++ b/src/pook/interceptors/aiohttp.py @@ -1,9 +1,7 @@ -import asyncio from http.client import responses as http_reasons -from typing import Callable, Optional +from typing import Optional from unittest import mock -from urllib.parse import urlencode, urlunparse -from collections.abc import Mapping +from urllib.parse import urlunparse import aiohttp from aiohttp.helpers import TimerNoop @@ -16,111 +14,29 @@ import multidict import yarl -PATCHES = ("aiohttp.client.ClientSession._request",) - RESPONSE_CLASS = "ClientResponse" RESPONSE_PATH = "aiohttp.client_reqrep" -class SimpleContent(EmptyStreamReader): - def __init__(self, content, *args, **kwargs): - super().__init__(*args, **kwargs) - self.content = content - - async def read(self, n=-1): - return self.content - - -def HTTPResponse(session: aiohttp.ClientSession, *args, **kw): - return session._response_class( - *args, - request_info=mock.Mock(), - writer=None, - continue100=None, - timer=TimerNoop(), - traces=[], - loop=mock.Mock(), - session=mock.Mock(), - **kw, - ) - - class AIOHTTPInterceptor(BaseInterceptor): - """ - aiohttp HTTP client traffic interceptor. - """ - - def _url(self, url) -> Optional[yarl.URL]: - return yarl.URL(url) if yarl else None - - def set_headers(self, req, headers) -> None: - # aiohttp's interface allows various mappings, as well as an iterable of key/value tuples - # ``pook.request`` only allows a dict, so we need to map the iterable to the matchable interface - if headers: - if isinstance(headers, Mapping): - req.headers.update(**headers) - else: - # If it isn't a mapping, then its an Iterable[Tuple[Union[str, istr], str]] - for req_header, req_header_value in headers: - normalised_header = req_header.lower() - if normalised_header in req.headers: - req.headers[normalised_header] += f", {req_header_value}" - else: - req.headers[normalised_header] = req_header_value - - async def _on_request( - self, - _request: Callable, - session: aiohttp.ClientSession, - method: str, - url: str, - data=None, - headers=None, - **kw, + # Implements aiohttp.ClientMiddlewareType + async def __call__( + self, request: aiohttp.ClientRequest, handler: aiohttp.ClientHandlerType ) -> aiohttp.ClientResponse: - # Create request contract based on incoming params - req = Request(method) - - self.set_headers(req, headers) - self.set_headers(req, session.headers) - - req.body = data - - # Expose extra variadic arguments - req.extra = kw - - full_url = session._build_url(url) + req = Request( + method=request.method, + headers=request.headers.items(), + body=request.body, + url=str(request.url), + ) - # Compose URL - if not kw.get("params"): - req.url = str(full_url) - else: - # Transform params as a list of tuple - params = kw["params"] - if isinstance(params, dict): - params = [(x, y) for x, y in kw["params"].items()] - req.url = str(full_url) + "?" + urlencode(params) - - # If a json payload is provided, serialize it for JSONMatcher support - if json_body := kw.get("json"): - req.json = json_body - if "Content-Type" not in req.headers: - req.headers["Content-Type"] = "application/json" - - # Match the request against the registered mocks in pook mock = self.engine.match(req) # If cannot match any mock, run real HTTP request if networking # or silent model are enabled, otherwise this statement won't # be reached (an exception will be raised before). if not mock: - return await _request( - session, method, url, data=data, headers=headers, **kw - ) - - # Simulate network delay - if mock._delay: - await asyncio.sleep(mock._delay / 1000) # noqa + return await handler(request) # Shortcut to mock response res = mock._response @@ -131,7 +47,7 @@ async def _on_request( headers.append((key, res._headers[key])) # Create mock equivalent HTTP response - _res = HTTPResponse(session, req.method, self._url(urlunparse(req.url))) + _res = HTTPResponse(request.session, req.method, self._url(urlunparse(req.url))) # response status _res.version = aiohttp.HttpVersion(1, 1) @@ -154,23 +70,24 @@ async def _on_request( # Return response based on mock definition return _res - def _patch(self, path: str) -> None: + def _url(self, url) -> Optional[yarl.URL]: + return yarl.URL(url) if yarl else None + + def activate(self) -> None: # If not able to import aiohttp dependencies, skip if not yarl or not multidict: return None - async def handler(session, method, url, data=None, headers=None, **kw): - return await self._on_request( - _request, session, method, url, data=data, headers=headers, **kw - ) + def _request(session, *args, **kwargs): + request_middlewares = kwargs.get("middlewares", ()) + kwargs["middlewares"] = request_middlewares + (self,) + return super_request(session, *args, **kwargs) try: - # Create a new patcher for Urllib3 urlopen function - # used as entry point for all the HTTP communications - patcher = mock.patch(path, handler) - # Retrieve original patched function that we might need for real - # networking - _request = patcher.get_original()[0] + # Patch ClientSession init to append this interceptor as an aiohttp + # middleware to all session's middlewares + patcher = mock.patch("aiohttp.client.ClientSession._request", _request) + super_request = patcher.get_original()[0] # Start patching function calls patcher.start() except Exception: @@ -180,14 +97,6 @@ async def handler(session, method, url, data=None, headers=None, **kw): else: self.patchers.append(patcher) - def activate(self) -> None: - """ - Activates the traffic interceptor. - This method must be implemented by any interceptor. - """ - for path in PATCHES: - self._patch(path) - def disable(self) -> None: """ Disables the traffic interceptor. @@ -195,3 +104,26 @@ def disable(self) -> None: """ for patch in self.patchers: patch.stop() + + +class SimpleContent(EmptyStreamReader): + def __init__(self, content, *args, **kwargs): + super().__init__(*args, **kwargs) + self.content = content + + async def read(self, n=-1): + return self.content + + +def HTTPResponse(session: aiohttp.ClientSession, *args, **kw): + return session._response_class( + *args, + request_info=mock.Mock(), + writer=None, + continue100=None, + timer=TimerNoop(), + traces=[], + loop=mock.Mock(), + session=mock.Mock(), + **kw, + ) diff --git a/src/pook/request.py b/src/pook/request.py index 76ec43f..14940e5 100644 --- a/src/pook/request.py +++ b/src/pook/request.py @@ -54,8 +54,6 @@ def headers(self): @headers.setter def headers(self, headers): - if not hasattr(headers, "__setitem__"): - raise TypeError("headers must be a dictionary") self._headers.extend(headers) @property diff --git a/tests/unit/interceptors/aiohttp_test.py b/tests/unit/interceptors/aiohttp_test.py index c4bb193..0429eb1 100644 --- a/tests/unit/interceptors/aiohttp_test.py +++ b/tests/unit/interceptors/aiohttp_test.py @@ -97,6 +97,22 @@ async def test_client_headers_merged(local_responder): assert await res.read() == b"hello from pook" +@pytest.mark.asyncio +async def test_client_auth_merged(local_responder): + """Auth headers set on the client should be matched""" + pook.get(local_responder + "/status/404").header( + "Authorization", "Basic dXNlcjpwYXNzd29yZA==" + ).reply(200).body("hello from pook") + async with aiohttp.ClientSession( + auth=aiohttp.BasicAuth("user", "password") + ) as session: + res = await session.get( + local_responder + "/status/404", headers={"x-pook-secondary": "xyz"} + ) + assert res.status == 200 + assert await res.read() == b"hello from pook" + + @pytest.mark.asyncio async def test_client_headers_both_session_and_request(local_responder): """Headers should be matchable from both the session and request in the same matcher"""