1- import asyncio
21from http .client import responses as http_reasons
3- from typing import Callable , Optional
2+ from typing import Optional
43from unittest import mock
5- from urllib .parse import urlencode , urlunparse
6- from collections .abc import Mapping
4+ from urllib .parse import urlunparse
75
86import aiohttp
97from aiohttp .helpers import TimerNoop
1614import multidict
1715import yarl
1816
19- PATCHES = ("aiohttp.client.ClientSession._request" ,)
20-
2117RESPONSE_CLASS = "ClientResponse"
2218RESPONSE_PATH = "aiohttp.client_reqrep"
2319
2420
25- class SimpleContent (EmptyStreamReader ):
26- def __init__ (self , content , * args , ** kwargs ):
27- super ().__init__ (* args , ** kwargs )
28- self .content = content
29-
30- async def read (self , n = - 1 ):
31- return self .content
32-
33-
34- def HTTPResponse (session : aiohttp .ClientSession , * args , ** kw ):
35- return session ._response_class (
36- * args ,
37- request_info = mock .Mock (),
38- writer = None ,
39- continue100 = None ,
40- timer = TimerNoop (),
41- traces = [],
42- loop = mock .Mock (),
43- session = mock .Mock (),
44- ** kw ,
45- )
46-
47-
4821class AIOHTTPInterceptor (BaseInterceptor ):
49- """
50- aiohttp HTTP client traffic interceptor.
51- """
52-
53- def _url (self , url ) -> Optional [yarl .URL ]:
54- return yarl .URL (url ) if yarl else None
55-
56- def set_headers (self , req , headers ) -> None :
57- # aiohttp's interface allows various mappings, as well as an iterable of key/value tuples
58- # ``pook.request`` only allows a dict, so we need to map the iterable to the matchable interface
59- if headers :
60- if isinstance (headers , Mapping ):
61- req .headers .update (** headers )
62- else :
63- # If it isn't a mapping, then its an Iterable[Tuple[Union[str, istr], str]]
64- for req_header , req_header_value in headers :
65- normalised_header = req_header .lower ()
66- if normalised_header in req .headers :
67- req .headers [normalised_header ] += f", { req_header_value } "
68- else :
69- req .headers [normalised_header ] = req_header_value
70-
71- async def _on_request (
72- self ,
73- _request : Callable ,
74- session : aiohttp .ClientSession ,
75- method : str ,
76- url : str ,
77- data = None ,
78- headers = None ,
79- ** kw ,
22+ # Implements aiohttp.ClientMiddlewareType
23+ async def __call__ (
24+ self , request : aiohttp .ClientRequest , handler : aiohttp .ClientHandlerType
8025 ) -> aiohttp .ClientResponse :
81- # Create request contract based on incoming params
82- req = Request (method )
83-
84- self .set_headers (req , headers )
85- self .set_headers (req , session .headers )
86-
87- req .body = data
88-
89- # Expose extra variadic arguments
90- req .extra = kw
91-
92- full_url = session ._build_url (url )
26+ req = Request (
27+ method = request .method ,
28+ headers = request .headers .items (),
29+ body = request .body ,
30+ url = str (request .url ),
31+ )
9332
94- # Compose URL
95- if not kw .get ("params" ):
96- req .url = str (full_url )
97- else :
98- # Transform params as a list of tuple
99- params = kw ["params" ]
100- if isinstance (params , dict ):
101- params = [(x , y ) for x , y in kw ["params" ].items ()]
102- req .url = str (full_url ) + "?" + urlencode (params )
103-
104- # If a json payload is provided, serialize it for JSONMatcher support
105- if json_body := kw .get ("json" ):
106- req .json = json_body
107- if "Content-Type" not in req .headers :
108- req .headers ["Content-Type" ] = "application/json"
109-
110- # Match the request against the registered mocks in pook
11133 mock = self .engine .match (req )
11234
11335 # If cannot match any mock, run real HTTP request if networking
11436 # or silent model are enabled, otherwise this statement won't
11537 # be reached (an exception will be raised before).
11638 if not mock :
117- return await _request (
118- session , method , url , data = data , headers = headers , ** kw
119- )
120-
121- # Simulate network delay
122- if mock ._delay :
123- await asyncio .sleep (mock ._delay / 1000 ) # noqa
39+ return await handler (request )
12440
12541 # Shortcut to mock response
12642 res = mock ._response
@@ -131,7 +47,7 @@ async def _on_request(
13147 headers .append ((key , res ._headers [key ]))
13248
13349 # Create mock equivalent HTTP response
134- _res = HTTPResponse (session , req .method , self ._url (urlunparse (req .url )))
50+ _res = HTTPResponse (request . session , req .method , self ._url (urlunparse (req .url )))
13551
13652 # response status
13753 _res .version = aiohttp .HttpVersion (1 , 1 )
@@ -154,23 +70,24 @@ async def _on_request(
15470 # Return response based on mock definition
15571 return _res
15672
157- def _patch (self , path : str ) -> None :
73+ def _url (self , url ) -> Optional [yarl .URL ]:
74+ return yarl .URL (url ) if yarl else None
75+
76+ def activate (self ) -> None :
15877 # If not able to import aiohttp dependencies, skip
15978 if not yarl or not multidict :
16079 return None
16180
162- async def handler (session , method , url , data = None , headers = None , ** kw ):
163- return await self . _on_request (
164- _request , session , method , url , data = data , headers = headers , ** kw
165- )
81+ def _request (session , * args , ** kwargs ):
82+ request_middlewares = kwargs . get ( "middlewares" , ())
83+ kwargs [ "middlewares" ] = request_middlewares + ( self ,)
84+ return super_request ( session , * args , ** kwargs )
16685
16786 try :
168- # Create a new patcher for Urllib3 urlopen function
169- # used as entry point for all the HTTP communications
170- patcher = mock .patch (path , handler )
171- # Retrieve original patched function that we might need for real
172- # networking
173- _request = patcher .get_original ()[0 ]
87+ # Patch ClientSession init to append this interceptor as an aiohttp
88+ # middleware to all session's middlewares
89+ patcher = mock .patch ("aiohttp.client.ClientSession._request" , _request )
90+ super_request = patcher .get_original ()[0 ]
17491 # Start patching function calls
17592 patcher .start ()
17693 except Exception :
@@ -180,18 +97,33 @@ async def handler(session, method, url, data=None, headers=None, **kw):
18097 else :
18198 self .patchers .append (patcher )
18299
183- def activate (self ) -> None :
184- """
185- Activates the traffic interceptor.
186- This method must be implemented by any interceptor.
187- """
188- for path in PATCHES :
189- self ._patch (path )
190-
191100 def disable (self ) -> None :
192101 """
193102 Disables the traffic interceptor.
194103 This method must be implemented by any interceptor.
195104 """
196105 for patch in self .patchers :
197106 patch .stop ()
107+
108+
109+ class SimpleContent (EmptyStreamReader ):
110+ def __init__ (self , content , * args , ** kwargs ):
111+ super ().__init__ (* args , ** kwargs )
112+ self .content = content
113+
114+ async def read (self , n = - 1 ):
115+ return self .content
116+
117+
118+ def HTTPResponse (session : aiohttp .ClientSession , * args , ** kw ):
119+ return session ._response_class (
120+ * args ,
121+ request_info = mock .Mock (),
122+ writer = None ,
123+ continue100 = None ,
124+ timer = TimerNoop (),
125+ traces = [],
126+ loop = mock .Mock (),
127+ session = mock .Mock (),
128+ ** kw ,
129+ )
0 commit comments