Skip to content

Commit e06237c

Browse files
committed
Handle single rpc error responses for batch requests:
- With some clients, when executing a batch request, the client will return a single error response for the entire batch. Previously, we were assuming that batch requests always received a list of responses. This commit updates the batch request logic to handle this case. - Add tests for batch requests with a single error response.
1 parent 0d44d22 commit e06237c

File tree

13 files changed

+181
-28
lines changed

13 files changed

+181
-28
lines changed

newsfragments/3585.bugfix.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Handle the case when a single RPC error response is returned for a batch request, instead of always expecting a list of responses.

tests/core/providers/test_async_http_provider.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
11
import pytest
2+
from unittest.mock import (
3+
AsyncMock,
4+
patch,
5+
)
26

37
from aiohttp import (
48
ClientSession,
@@ -13,6 +17,7 @@
1317
)
1418
from web3.exceptions import (
1519
ProviderConnectionError,
20+
Web3RPCError,
1621
)
1722
from web3.geth import (
1823
AsyncGeth,
@@ -107,3 +112,21 @@ def test_get_request_headers(provider):
107112
headers["User-Agent"] == f"web3.py/{web3py_version}/"
108113
f"{AsyncHTTPProvider.__module__}.{AsyncHTTPProvider.__qualname__}"
109114
)
115+
116+
117+
@patch(
118+
"web3._utils.http_session_manager.HTTPSessionManager.async_make_post_request",
119+
new_callable=AsyncMock,
120+
)
121+
@pytest.mark.asyncio
122+
async def test_async_http_empty_batch_response(mock_async_post):
123+
mock_async_post.return_value = (
124+
b'{"jsonrpc":"2.0","id":null,"error":{"code":-32600,"message":"empty batch"}}'
125+
)
126+
async_w3 = AsyncWeb3(AsyncHTTPProvider())
127+
async with async_w3.batch_requests() as batch:
128+
with pytest.raises(Web3RPCError, match="empty batch"):
129+
await batch.async_execute()
130+
131+
# assert that even though there was an error, we have reset the batching state
132+
assert not async_w3.provider._is_batching

tests/core/providers/test_async_ipc_provider.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
)
2626
from web3.exceptions import (
2727
ReadBufferLimitReached,
28+
Web3RPCError,
2829
)
2930
from web3.providers import (
3031
AsyncIPCProvider,
@@ -366,3 +367,24 @@ async def test_async_ipc_provider_write_messages_end_with_new_line_delimiter(
366367

367368
request_data = b'{"id": 0, "jsonrpc": "2.0", "method": "method", "params": []}'
368369
w3.provider._writer.write.assert_called_with(request_data + b"\n")
370+
371+
372+
@pytest.mark.asyncio
373+
async def test_persistent_connection_provider_empty_batch_response(
374+
simple_ipc_server,
375+
jsonrpc_ipc_pipe_path,
376+
):
377+
async with AsyncWeb3(
378+
AsyncIPCProvider(pathlib.Path(jsonrpc_ipc_pipe_path))
379+
) as async_w3:
380+
async_w3.provider._reader.readline = AsyncMock()
381+
async_w3.provider._reader.readline.return_value = (
382+
b'{"jsonrpc": "2.0","id":null,"error": {"code": -32600, "message": '
383+
b'"empty batch"}}\n'
384+
)
385+
async with async_w3.batch_requests() as batch:
386+
with pytest.raises(Web3RPCError, match="empty batch"):
387+
await batch.async_execute()
388+
389+
# assert that even though there was an error, we have reset the batching state
390+
assert not async_w3.provider._is_batching

tests/core/providers/test_http_provider.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
11
import pytest
2+
from unittest.mock import (
3+
Mock,
4+
patch,
5+
)
26

37
from requests import (
48
Session,
@@ -16,6 +20,7 @@
1620
)
1721
from web3.exceptions import (
1822
ProviderConnectionError,
23+
Web3RPCError,
1924
)
2025
from web3.geth import (
2126
Geth,
@@ -110,3 +115,20 @@ def test_get_request_headers(provider):
110115
headers["User-Agent"] == f"web3.py/{web3py_version}/"
111116
f"{HTTPProvider.__module__}.{HTTPProvider.__qualname__}"
112117
)
118+
119+
120+
@patch(
121+
"web3._utils.http_session_manager.HTTPSessionManager.make_post_request",
122+
new_callable=Mock,
123+
)
124+
def test_http_empty_batch_response(mock_post):
125+
mock_post.return_value = (
126+
b'{"jsonrpc":"2.0","id":null,"error":{"code":-32600,"message":"empty batch"}}'
127+
)
128+
w3 = Web3(HTTPProvider())
129+
with w3.batch_requests() as batch:
130+
with pytest.raises(Web3RPCError, match="empty batch"):
131+
batch.execute()
132+
133+
# assert that even though there was an error, we have reset the batching state
134+
assert not w3.provider._is_batching

tests/core/providers/test_websocket_provider.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
)
2828
from web3.exceptions import (
2929
TimeExhausted,
30+
Web3RPCError,
3031
)
3132
from web3.providers.persistent import (
3233
WebSocketProvider,
@@ -432,3 +433,24 @@ async def test_listener_task_breaks_out_of_handle_subscriptions_when_cancelled()
432433
# up the exception.
433434
with pytest.raises(asyncio.CancelledError):
434435
await async_w3.subscription_manager.handle_subscriptions(run_forever=True)
436+
437+
438+
@pytest.mark.asyncio
439+
async def test_persistent_connection_provider_empty_batch_response():
440+
with patch(
441+
"web3.providers.persistent.websocket.connect",
442+
new=lambda *_1, **_2: _mocked_ws_conn(),
443+
):
444+
async with AsyncWeb3(WebSocketProvider("ws://mocked")) as async_w3:
445+
async_w3.provider._ws.recv = AsyncMock()
446+
async_w3.provider._ws.recv.return_value = (
447+
b'{"jsonrpc": "2.0","id":null,"error": {"code": -32600, "message": '
448+
b'"empty batch"}}\n'
449+
)
450+
async with async_w3.batch_requests() as batch:
451+
with pytest.raises(Web3RPCError, match="empty batch"):
452+
await batch.async_execute()
453+
454+
# assert that even though there was an error, we have reset the batching
455+
# state
456+
assert not async_w3.provider._is_batching

web3/manager.py

Lines changed: 37 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -436,15 +436,27 @@ def _make_batch_request(
436436
request_func = provider.batch_request_func(
437437
cast("Web3", self.w3), cast("MiddlewareOnion", self.middleware_onion)
438438
)
439-
responses = request_func(
439+
response = request_func(
440440
[
441441
(method, params)
442442
for (method, params), _response_formatters in requests_info
443443
]
444444
)
445+
446+
if not isinstance(response, list):
447+
# Expect a JSON-RPC error response and call _validate_response to raise
448+
# the appropriate exception
449+
_validate_response(
450+
cast(RPCResponse, response),
451+
None,
452+
is_subscription_response=False,
453+
logger=self.logger,
454+
params=[],
455+
)
456+
445457
formatted_responses = [
446-
self._format_batched_response(info, resp)
447-
for info, resp in zip(requests_info, responses)
458+
self._format_batched_response(info, cast(RPCResponse, resp))
459+
for info, resp in zip(requests_info, response)
448460
]
449461
return list(formatted_responses)
450462

@@ -465,23 +477,35 @@ async def _async_make_batch_request(
465477
# since we add items to the batch without awaiting, we unpack the coroutines
466478
# and await them all here
467479
unpacked_requests_info = await asyncio.gather(*requests_info)
468-
responses = await request_func(
480+
response = await request_func(
469481
[
470482
(method, params)
471483
for (method, params), _response_formatters in unpacked_requests_info
472484
]
473485
)
474486

487+
if not isinstance(response, list):
488+
# Expect a JSON-RPC error response and call _validate_response to raise
489+
# the appropriate exception
490+
_validate_response(
491+
cast(RPCResponse, response),
492+
None,
493+
is_subscription_response=False,
494+
logger=self.logger,
495+
params=[],
496+
)
497+
498+
response = cast(List[RPCResponse], response)
475499
if isinstance(self.provider, PersistentConnectionProvider):
476500
# call _process_response for each response in the batch
477501
return [
478502
cast(RPCResponse, await self._process_response(resp))
479-
for resp in responses
503+
for resp in response
480504
]
481505

482506
formatted_responses = [
483507
self._format_batched_response(info, resp)
484-
for info, resp in zip(unpacked_requests_info, responses)
508+
for info, resp in zip(unpacked_requests_info, response)
485509
]
486510
return list(formatted_responses)
487511

@@ -491,6 +515,13 @@ def _format_batched_response(
491515
response: RPCResponse,
492516
) -> RPCResponse:
493517
result_formatters, error_formatters, null_result_formatters = requests_info[1]
518+
_validate_response(
519+
response,
520+
error_formatters,
521+
is_subscription_response=False,
522+
logger=self.logger,
523+
params=requests_info[0][1],
524+
)
494525
return apply_result_formatters(
495526
result_formatters,
496527
self.formatted_response(

web3/middleware/base.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -62,15 +62,19 @@ def wrap_make_batch_request(
6262
) -> "MakeBatchRequestFn":
6363
def middleware(
6464
requests_info: List[Tuple["RPCEndpoint", Any]]
65-
) -> List["RPCResponse"]:
65+
) -> Union[List["RPCResponse"], "RPCResponse"]:
6666
req_processed = [
6767
self.request_processor(method, params)
6868
for (method, params) in requests_info
6969
]
70-
responses = make_batch_request(req_processed)
70+
response = make_batch_request(req_processed)
71+
if not isinstance(response, list):
72+
# RPC errors return only one response with the error object
73+
return response
74+
7175
methods, _params = zip(*req_processed)
7276
formatted_responses = [
73-
self.response_processor(m, r) for m, r in zip(methods, responses)
77+
self.response_processor(m, r) for m, r in zip(methods, response)
7478
]
7579
return formatted_responses
7680

@@ -103,16 +107,20 @@ async def async_wrap_make_batch_request(
103107
) -> "AsyncMakeBatchRequestFn":
104108
async def middleware(
105109
requests_info: List[Tuple["RPCEndpoint", Any]]
106-
) -> List["RPCResponse"]:
110+
) -> Union[List["RPCResponse"], "RPCResponse"]:
107111
req_processed = [
108112
await self.async_request_processor(method, params)
109113
for (method, params) in requests_info
110114
]
111-
responses = await make_batch_request(req_processed)
115+
response = await make_batch_request(req_processed)
116+
if not isinstance(response, list):
117+
# RPC errors return only one response with the error object
118+
return response
119+
112120
methods, _params = zip(*req_processed)
113121
formatted_responses = [
114122
await self.async_response_processor(m, r)
115-
for m, r in zip(methods, responses)
123+
for m, r in zip(methods, response)
116124
]
117125
return formatted_responses
118126

web3/providers/async_base.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,8 @@ class AsyncBaseProvider:
7777

7878
_is_batching: bool = False
7979
_batch_request_func_cache: Tuple[
80-
Tuple[Middleware, ...], Callable[..., Coroutine[Any, Any, List[RPCResponse]]]
80+
Tuple[Middleware, ...],
81+
Callable[..., Coroutine[Any, Any, Union[List[RPCResponse], RPCResponse]]],
8182
] = (None, None)
8283

8384
is_async = True
@@ -119,7 +120,7 @@ async def request_func(
119120

120121
async def batch_request_func(
121122
self, async_w3: "AsyncWeb3", middleware_onion: MiddlewareOnion
122-
) -> Callable[..., Coroutine[Any, Any, List[RPCResponse]]]:
123+
) -> Callable[..., Coroutine[Any, Any, Union[List[RPCResponse], RPCResponse]]]:
123124
middleware: Tuple[Middleware, ...] = middleware_onion.as_tuple_of_middleware()
124125

125126
cache_key = self._batch_request_func_cache[0]
@@ -141,8 +142,8 @@ async def make_request(self, method: RPCEndpoint, params: Any) -> RPCResponse:
141142

142143
async def make_batch_request(
143144
self, requests: List[Tuple[RPCEndpoint, Any]]
144-
) -> List[RPCResponse]:
145-
raise NotImplementedError("Only AsyncHTTPProvider supports this method")
145+
) -> Union[List[RPCResponse], RPCResponse]:
146+
raise NotImplementedError("Providers must implement this method")
146147

147148
async def is_connected(self, show_traceback: bool = False) -> bool:
148149
raise NotImplementedError("Providers must implement this method")

web3/providers/base.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ class JSONBaseProvider(BaseProvider):
118118

119119
_is_batching: bool = False
120120
_batch_request_func_cache: Tuple[
121-
Tuple[Middleware, ...], Callable[..., List[RPCResponse]]
121+
Tuple[Middleware, ...], Callable[..., Union[List[RPCResponse], RPCResponse]]
122122
] = (None, None)
123123

124124
def __init__(self, **kwargs: Any) -> None:
@@ -168,7 +168,7 @@ def is_connected(self, show_traceback: bool = False) -> bool:
168168

169169
def batch_request_func(
170170
self, w3: "Web3", middleware_onion: MiddlewareOnion
171-
) -> Callable[..., List[RPCResponse]]:
171+
) -> Callable[..., Union[List[RPCResponse], RPCResponse]]:
172172
middleware: Tuple[Middleware, ...] = middleware_onion.as_tuple_of_middleware()
173173

174174
cache_key = self._batch_request_func_cache[0]
@@ -199,5 +199,5 @@ def encode_batch_rpc_request(
199199

200200
def make_batch_request(
201201
self, requests: List[Tuple[RPCEndpoint, Any]]
202-
) -> List[RPCResponse]:
202+
) -> Union[List[RPCResponse], RPCResponse]:
203203
raise NotImplementedError("Providers must implement this method")

web3/providers/persistent/request_processor.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
Callable,
77
Dict,
88
Generic,
9+
List,
910
Optional,
1011
Tuple,
1112
TypeVar,
@@ -270,6 +271,15 @@ def append_middleware_response_processor(
270271

271272
# raw response cache
272273

274+
def _is_batch_response(
275+
self, raw_response: Union[List[RPCResponse], RPCResponse]
276+
) -> bool:
277+
return isinstance(raw_response, list) or (
278+
isinstance(raw_response, dict)
279+
and raw_response.get("id") is None
280+
and self._provider._is_batching
281+
)
282+
273283
async def cache_raw_response(
274284
self, raw_response: Any, subscription: bool = False
275285
) -> None:
@@ -296,7 +306,7 @@ async def cache_raw_response(
296306
# otherwise, put it in the subscription response queue so a response
297307
# can be yielded by the message stream
298308
await self._subscription_response_queue.put(raw_response)
299-
elif isinstance(raw_response, list):
309+
elif self._is_batch_response(raw_response):
300310
# Since only one batch should be in the cache at all times, we use a
301311
# constant cache key for the batch response.
302312
cache_key = generate_cache_key(BATCH_REQUEST_ID)

0 commit comments

Comments
 (0)