Skip to content

Commit ee602a3

Browse files
committed
wrap all make_batch_request in batching_context decorator to manage
batching status
1 parent 0405dcb commit ee602a3

File tree

11 files changed

+128
-8
lines changed

11 files changed

+128
-8
lines changed

newsfragments/3642.bugfix.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
Checks that ``PersistentConnectionProvider`` response cache value is a dict before attempting to access it like one.
1+
Checks that ``PersistentConnectionProvider`` response cache value is a dict before attempting to access it like one. Also adds checks around ``make_batch_request`` to ensure the provider is in batching mode while the batch request is being processed and that batching mode is reset afterward, even if an error occurs.

tests/core/providers/test_async_http_provider.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,3 +130,21 @@ async def test_async_http_empty_batch_response(mock_async_post):
130130

131131
# assert that even though there was an error, we have reset the batching state
132132
assert not async_w3.provider._is_batching
133+
134+
135+
@patch(
136+
"web3._utils.http_session_manager.HTTPSessionManager.async_make_post_request",
137+
new_callable=AsyncMock,
138+
)
139+
@pytest.mark.asyncio
140+
async def test_async_provider_is_batching_when_make_batch_request(mock_post):
141+
def assert_is_batching_and_return_response(*_args, **_kwargs) -> bytes:
142+
assert provider._is_batching
143+
return b'{"jsonrpc":"2.0","id":1,"result":["0x1"]}'
144+
145+
mock_post.side_effect = assert_is_batching_and_return_response
146+
provider = AsyncHTTPProvider()
147+
148+
assert not provider._is_batching
149+
await provider.make_batch_request([("eth_blockNumber", [])])
150+
assert not provider._is_batching

tests/core/providers/test_http_provider.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,3 +132,21 @@ def test_http_empty_batch_response(mock_post):
132132

133133
# assert that even though there was an error, we have reset the batching state
134134
assert not w3.provider._is_batching
135+
136+
137+
@patch(
138+
"web3._utils.http_session_manager.HTTPSessionManager.make_post_request",
139+
new_callable=Mock,
140+
)
141+
def test_sync_provider_is_batching_when_make_batch_request(mock_post):
142+
def assert_is_batching_and_return_response(*_args, **_kwargs) -> bytes:
143+
assert provider._is_batching
144+
return b'{"jsonrpc":"2.0","id":1,"result":["0x1"]}'
145+
146+
provider = HTTPProvider()
147+
assert not provider._is_batching
148+
149+
mock_post.side_effect = assert_is_batching_and_return_response
150+
151+
provider.make_batch_request([("eth_blockNumber", [])])
152+
assert not provider._is_batching

tests/core/providers/test_ipc_provider.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,3 +199,18 @@ def test_ipc_provider_write_messages_end_with_new_line_delimiter(jsonrpc_ipc_pip
199199

200200
request_data = b'{"jsonrpc": "2.0", "method": "method", "params": [], "id": 0}'
201201
provider._socket.sock.sendall.assert_called_with(request_data + b"\n")
202+
203+
204+
def test_ipc_provider_is_batching_when_make_batch_request(jsonrpc_ipc_pipe_path):
205+
def assert_is_batching_and_return_response(*_args, **_kwargs) -> list:
206+
assert provider._is_batching
207+
return [{"id": 0, "jsonrpc": "2.0", "result": {}}]
208+
209+
provider = IPCProvider(pathlib.Path(jsonrpc_ipc_pipe_path), timeout=3)
210+
provider._make_request = Mock()
211+
provider._make_request.side_effect = assert_is_batching_and_return_response
212+
213+
assert not provider._is_batching
214+
215+
provider.make_batch_request([("eth_blockNumber", [])])
216+
assert not provider._is_batching

tests/core/providers/test_websocket_provider.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -537,7 +537,7 @@ async def test_req_info_cache_size_can_be_set_and_warns_when_full(caplog):
537537

538538

539539
@pytest.mark.asyncio
540-
async def test_raise_stray_errors_from_cache_handles_list_response():
540+
async def test_raise_stray_errors_from_cache_handles_list_response_without_error():
541541
provider = WebSocketProvider("ws://mocked")
542542
_mock_ws(provider)
543543

@@ -546,7 +546,23 @@ async def test_raise_stray_errors_from_cache_handles_list_response():
546546
]
547547
provider._request_processor._request_response_cache._data["bad_key"] = bad_response
548548

549-
try:
550-
provider._raise_stray_errors_from_cache()
551-
except Exception as e:
552-
pytest.fail(f"{e}")
549+
# assert no errors raised
550+
provider._raise_stray_errors_from_cache()
551+
552+
553+
@pytest.mark.asyncio
554+
async def test_websocket_provider_is_batching_when_make_batch_request():
555+
def assert_is_batching_and_return_response(*_args, **_kwargs) -> bytes:
556+
assert provider._is_batching
557+
return b'{"jsonrpc":"2.0","id":1,"result":["0x1"]}'
558+
559+
provider = WebSocketProvider("ws://mocked")
560+
_mock_ws(provider)
561+
provider._get_response_for_request_id = AsyncMock()
562+
provider._get_response_for_request_id.side_effect = (
563+
assert_is_batching_and_return_response
564+
)
565+
566+
assert not provider._is_batching
567+
await provider.make_batch_request([("eth_blockNumber", [])])
568+
assert not provider._is_batching

web3/_utils/batching.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
from copy import (
22
copy,
33
)
4+
from functools import (
5+
wraps,
6+
)
47
from types import (
58
TracebackType,
69
)
@@ -12,9 +15,11 @@
1215
Dict,
1316
Generic,
1417
List,
18+
Protocol,
1519
Sequence,
1620
Tuple,
1721
Type,
22+
TypeVar,
1823
Union,
1924
cast,
2025
)
@@ -33,6 +38,7 @@
3338
Web3ValueError,
3439
)
3540
from web3.types import (
41+
RPCEndpoint,
3642
TFunc,
3743
TReturn,
3844
)
@@ -55,7 +61,6 @@
5561
JSONBaseProvider,
5662
)
5763
from web3.types import ( # noqa: F401
58-
RPCEndpoint,
5964
RPCResponse,
6065
)
6166

@@ -215,3 +220,39 @@ def sort_batch_response_by_response_ids(
215220
stacklevel=2,
216221
)
217222
return responses
223+
224+
225+
class SupportsBatching(Protocol):
226+
_is_batching: bool
227+
228+
229+
R = TypeVar("R")
230+
T = TypeVar("T", bound=SupportsBatching)
231+
232+
233+
def async_batching_context(
234+
method: Callable[[T, List[Tuple[RPCEndpoint, Any]]], Coroutine[Any, Any, R]]
235+
) -> Callable[[T, List[Tuple[RPCEndpoint, Any]]], Coroutine[Any, Any, R]]:
236+
@wraps(method)
237+
async def wrapper(self: T, requests: List[Tuple[RPCEndpoint, Any]]) -> R:
238+
self._is_batching = True
239+
try:
240+
return await method(self, requests)
241+
finally:
242+
self._is_batching = False
243+
244+
return wrapper
245+
246+
247+
def batching_context(
248+
method: Callable[[T, List[Tuple[RPCEndpoint, Any]]], R]
249+
) -> Callable[[T, List[Tuple[RPCEndpoint, Any]]], R]:
250+
@wraps(method)
251+
def wrapper(self: T, requests: List[Tuple[RPCEndpoint, Any]]) -> R:
252+
self._is_batching = True
253+
try:
254+
return method(self, requests)
255+
finally:
256+
self._is_batching = False
257+
258+
return wrapper

web3/providers/ipc.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
)
3131

3232
from .._utils.batching import (
33+
batching_context,
3334
sort_batch_response_by_response_ids,
3435
)
3536
from .._utils.caching import (
@@ -201,6 +202,7 @@ def make_request(self, method: RPCEndpoint, params: Any) -> RPCResponse:
201202
request = self.encode_rpc_request(method, params)
202203
return self._make_request(request)
203204

205+
@batching_context
204206
def make_batch_request(
205207
self, requests: List[Tuple[RPCEndpoint, Any]]
206208
) -> List[RPCResponse]:

web3/providers/legacy_websocket.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
)
2828

2929
from web3._utils.batching import (
30+
batching_context,
3031
sort_batch_response_by_response_ids,
3132
)
3233
from web3._utils.caching import (
@@ -143,6 +144,7 @@ def make_request(self, method: RPCEndpoint, params: Any) -> RPCResponse:
143144
)
144145
return future.result()
145146

147+
@batching_context
146148
def make_batch_request(
147149
self, requests: List[Tuple[RPCEndpoint, Any]]
148150
) -> List[RPCResponse]:

web3/providers/persistent/persistent.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424

2525
from web3._utils.batching import (
2626
BATCH_REQUEST_ID,
27+
async_batching_context,
2728
sort_batch_response_by_response_ids,
2829
)
2930
from web3._utils.caching import (
@@ -237,14 +238,17 @@ async def make_request(
237238
rpc_request = await self.send_request(method, params)
238239
return await self.recv_for_request(rpc_request)
239240

241+
@async_batching_context
240242
async def make_batch_request(
241243
self, requests: List[Tuple[RPCEndpoint, Any]]
242244
) -> List[RPCResponse]:
243245
request_data = self.encode_batch_rpc_request(requests)
244246
await self.socket_send(request_data)
245247

248+
# breakpoint()
246249
response = cast(
247-
List[RPCResponse], await self._get_response_for_request_id(BATCH_REQUEST_ID)
250+
List[RPCResponse],
251+
await self._get_response_for_request_id(BATCH_REQUEST_ID),
248252
)
249253
return response
250254

web3/providers/rpc/async_rpc.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
)
3737

3838
from ..._utils.batching import (
39+
async_batching_context,
3940
sort_batch_response_by_response_ids,
4041
)
4142
from ..._utils.caching import (
@@ -166,6 +167,7 @@ async def make_request(self, method: RPCEndpoint, params: Any) -> RPCResponse:
166167
)
167168
return response
168169

170+
@async_batching_context
169171
async def make_batch_request(
170172
self, batch_requests: List[Tuple[RPCEndpoint, Any]]
171173
) -> Union[List[RPCResponse], RPCResponse]:

0 commit comments

Comments
 (0)