10
10
cast ,
11
11
)
12
12
13
- from toolz import (
13
+ from eth_utils . toolz import (
14
14
merge ,
15
15
)
16
16
17
+ from web3 .providers .persistent import (
18
+ PersistentConnectionProvider ,
19
+ )
20
+
17
21
if TYPE_CHECKING :
18
22
from web3 import ( # noqa: F401
19
23
AsyncWeb3 ,
26
30
AsyncMakeRequestFn ,
27
31
MakeRequestFn ,
28
32
RPCEndpoint ,
33
+ RPCRequest ,
29
34
RPCResponse ,
30
35
)
31
36
@@ -101,9 +106,21 @@ def __init__(
101
106
self .mock_results = mock_results or {}
102
107
self .mock_errors = mock_errors or {}
103
108
self .mock_responses = mock_responses or {}
104
- self ._make_request : Union [
105
- "AsyncMakeRequestFn" , "MakeRequestFn"
106
- ] = w3 .provider .make_request
109
+ if isinstance (w3 .provider , PersistentConnectionProvider ):
110
+ self ._send_request = w3 .provider .send_request
111
+ self ._recv_for_request = w3 .provider .recv_for_request
112
+ else :
113
+ self ._make_request : Union [
114
+ "AsyncMakeRequestFn" , "MakeRequestFn"
115
+ ] = w3 .provider .make_request
116
+
117
+ def _build_request_id (self ) -> int :
118
+ request_id = (
119
+ next (copy .deepcopy (self .w3 .provider .request_counter ))
120
+ if hasattr (self .w3 .provider , "request_counter" )
121
+ else 1
122
+ )
123
+ return request_id
107
124
108
125
def __enter__ (self ) -> "Self" :
109
126
# mypy error: Cannot assign to a method
@@ -131,11 +148,7 @@ def _mock_request_handler(
131
148
):
132
149
return self ._make_request (method , params )
133
150
134
- request_id = (
135
- next (copy .deepcopy (self .w3 .provider .request_counter ))
136
- if hasattr (self .w3 .provider , "request_counter" )
137
- else 1
138
- )
151
+ request_id = self ._build_request_id ()
139
152
response_dict = {"jsonrpc" : "2.0" , "id" : request_id }
140
153
141
154
if method in self .mock_responses :
@@ -176,35 +189,34 @@ def _mock_request_handler(
176
189
# -- async -- #
177
190
178
191
async def __aenter__ (self ) -> "Self" :
179
- # mypy error: Cannot assign to a method
180
- self .w3 .provider .make_request = self ._async_mock_request_handler # type: ignore[method-assign] # noqa: E501
181
- # reset request func cache to re-build request_func with mocked make_request
182
- self .w3 .provider ._request_func_cache = (None , None )
192
+ if not isinstance (self .w3 .provider , PersistentConnectionProvider ):
193
+ # mypy error: Cannot assign to a method
194
+ self .w3 .provider .make_request = self ._async_mock_request_handler # type: ignore[method-assign] # noqa: E501
195
+ # reset request func cache to re-build request_func w/ mocked make_request
196
+ self .w3 .provider ._request_func_cache = (None , None )
197
+ else :
198
+ self .w3 .provider .send_request = self ._async_mock_send_handler # type: ignore[method-assign] # noqa: E501
199
+ self .w3 .provider .recv_for_request = self ._async_mock_recv_handler # type: ignore[method-assign] # noqa: E501
200
+ self .w3 .provider ._send_func_cache = (None , None )
201
+ self .w3 .provider ._recv_func_cache = (None , None )
183
202
return self
184
203
185
204
async def __aexit__ (self , exc_type : Any , exc_value : Any , traceback : Any ) -> None :
186
- # mypy error: Cannot assign to a method
187
- self .w3 .provider .make_request = self ._make_request # type: ignore[assignment]
188
- # reset request func cache to re-build request_func with original make_request
189
- self .w3 .provider ._request_func_cache = (None , None )
205
+ if not isinstance (self .w3 .provider , PersistentConnectionProvider ):
206
+ # mypy error: Cannot assign to a method
207
+ self .w3 .provider .make_request = self ._make_request # type: ignore[assignment] # noqa: E501
208
+ # reset request func cache to re-build request_func w/ original make_request
209
+ self .w3 .provider ._request_func_cache = (None , None )
210
+ else :
211
+ self .w3 .provider .send_request = self ._send_request # type: ignore[method-assign] # noqa: E501
212
+ self .w3 .provider .recv_for_request = self ._recv_for_request # type: ignore[method-assign] # noqa: E501
213
+ self .w3 .provider ._send_func_cache = (None , None )
214
+ self .w3 .provider ._recv_func_cache = (None , None )
190
215
191
- async def _async_mock_request_handler (
192
- self , method : "RPCEndpoint" , params : Any
216
+ async def _async_build_mock_result (
217
+ self , method : "RPCEndpoint" , params : Any , request_id : int = None
193
218
) -> "RPCResponse" :
194
- self .w3 = cast ("AsyncWeb3" , self .w3 )
195
- self ._make_request = cast ("AsyncMakeRequestFn" , self ._make_request )
196
-
197
- if all (
198
- method not in mock_dict
199
- for mock_dict in (self .mock_errors , self .mock_results , self .mock_responses )
200
- ):
201
- return await self ._make_request (method , params )
202
-
203
- request_id = (
204
- next (copy .deepcopy (self .w3 .provider .request_counter ))
205
- if hasattr (self .w3 .provider , "request_counter" )
206
- else 1
207
- )
219
+ request_id = request_id if request_id else self ._build_request_id ()
208
220
response_dict = {"jsonrpc" : "2.0" , "id" : request_id }
209
221
210
222
if method in self .mock_responses :
@@ -244,6 +256,19 @@ async def _async_mock_request_handler(
244
256
else :
245
257
raise Exception ("Invariant: unreachable code path" )
246
258
259
+ return mocked_result
260
+
261
+ async def _async_mock_request_handler (
262
+ self , method : "RPCEndpoint" , params : Any
263
+ ) -> "RPCResponse" :
264
+ self .w3 = cast ("AsyncWeb3" , self .w3 )
265
+ self ._make_request = cast ("AsyncMakeRequestFn" , self ._make_request )
266
+ if all (
267
+ method not in mock_dict
268
+ for mock_dict in (self .mock_errors , self .mock_results , self .mock_responses )
269
+ ):
270
+ return await self ._make_request (method , params )
271
+ mocked_result = await self ._async_build_mock_result (method , params )
247
272
decorator = getattr (self ._make_request , "_decorator" , None )
248
273
if decorator is not None :
249
274
# If the original make_request was decorated, we need to re-apply
@@ -259,6 +284,47 @@ async def _coro(
259
284
else :
260
285
return mocked_result
261
286
287
+ async def _async_mock_send_handler (
288
+ self , method : "RPCEndpoint" , params : Any
289
+ ) -> "RPCRequest" :
290
+ if all (
291
+ method not in mock_dict
292
+ for mock_dict in (self .mock_errors , self .mock_results , self .mock_responses )
293
+ ):
294
+ return await self ._send_request (method , params )
295
+ else :
296
+ request_id = self ._build_request_id ()
297
+ return {"id" : request_id , "method" : method , "params" : params }
298
+
299
+ async def _async_mock_recv_handler (
300
+ self , rpc_request : "RPCRequest"
301
+ ) -> "RPCResponse" :
302
+ self .w3 = cast ("AsyncWeb3" , self .w3 )
303
+ method = rpc_request ["method" ]
304
+ request_id = rpc_request ["id" ]
305
+ if all (
306
+ method not in mock_dict
307
+ for mock_dict in (self .mock_errors , self .mock_results , self .mock_responses )
308
+ ):
309
+ return await self ._recv_for_request (request_id )
310
+ mocked_result = await self ._async_build_mock_result (
311
+ method , rpc_request ["params" ], request_id = int (request_id )
312
+ )
313
+ decorator = getattr (self ._recv_for_request , "_decorator" , None )
314
+ if decorator is not None :
315
+ # If the original recv_for_request was decorated, we need to re-apply
316
+ # the decorator to the mocked recv_for_request. This is necessary for
317
+ # the request caching decorator to work properly.
318
+
319
+ async def _coro (
320
+ _provider : Any , _rpc_request : "RPCRequest"
321
+ ) -> "RPCResponse" :
322
+ return mocked_result
323
+
324
+ return await decorator (_coro )(self .w3 .provider , rpc_request )
325
+ else :
326
+ return mocked_result
327
+
262
328
@staticmethod
263
329
def _create_error_object (error : Dict [str , Any ]) -> Dict [str , Any ]:
264
330
code = error .get ("code" , - 32000 )
0 commit comments