Skip to content

Commit 39584f6

Browse files
committed
Refactor shared logic for persistent cxn providers:
- Refactor shared logic for persistent connection providers back into the `PersistentConnectionProvider` base class. - Refactor reading from and sending to the socket into their own methods. Add these as abstract methods on the base class and force subclasses to implement them. Important: - Refactor the public `socket` API (`PersistentConnection` class) to properly define a `send()`, `recv()`, and `make_request()` which do what they say they do. Previously, `send()` was actually a "send_and_recv()" method. This was confusing and not a good API design. Note, these methods don't and cannot process the response in the way that invoking a method via its module within the web3 class does.
1 parent acc6de2 commit 39584f6

File tree

11 files changed

+266
-171
lines changed

11 files changed

+266
-171
lines changed

newsfragments/3433.breaking.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Refactor the public ``socket`` api for persistent connection providers to properly define ``send()``, ``recv()``, and ``make_request()`` (send and wait for response) methods for interacting with the open socket.

newsfragments/3433.feature.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add ``popitem()`` functionality to the ``SimpleCache`` class as well as an async utility method to wait for the next item, ``async_await_and_popitem()``.

newsfragments/3433.internal.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Refactor some common logic for persistent connection providers back into the base ``PersistentConnectionProvider`` class to reduce code duplication and improve maintainability.

web3/_utils/module_testing/persistent_connection_provider.py

Lines changed: 48 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
)
2424
from web3.types import (
2525
FormattedEthSubscriptionResponse,
26+
RPCEndpoint,
2627
)
2728

2829
if TYPE_CHECKING:
@@ -31,6 +32,22 @@
3132
)
3233

3334

35+
SOME_BLOCK_KEYS = [
36+
"number",
37+
"hash",
38+
"parentHash",
39+
"transactionsRoot",
40+
"stateRoot",
41+
"receiptsRoot",
42+
"size",
43+
"gasLimit",
44+
"gasUsed",
45+
"timestamp",
46+
"transactions",
47+
"baseFeePerGas",
48+
]
49+
50+
3451
class PersistentConnectionProviderTest:
3552
@pytest.mark.asyncio
3653
@pytest.mark.parametrize(
@@ -346,11 +363,40 @@ async def test_async_extradata_poa_middleware_on_eth_subscription(
346363
# clean up
347364
async_w3.middleware_onion.remove("poa_middleware")
348365

366+
@pytest.mark.asyncio
367+
async def test_public_socket_api(self, async_w3: "AsyncWeb3") -> None:
368+
async_w3.provider._request_processor.clear_caches()
369+
370+
# send a request over the socket
371+
await async_w3.socket.send(
372+
RPCEndpoint("eth_getBlockByNumber"), ["latest", True]
373+
)
374+
375+
# recv and validate the unprocessed response
376+
response = await async_w3.socket.recv()
377+
assert "id" in response, "Expected 'id' key in response."
378+
assert "jsonrpc" in response, "Expected 'jsonrpc' key in response."
379+
assert "result" in response, "Expected 'result' key in response."
380+
assert all(k in response["result"].keys() for k in SOME_BLOCK_KEYS)
381+
assert not isinstance(response["result"]["number"], int) # assert not processed
382+
383+
# make a request over the socket
384+
response = await async_w3.socket.make_request(
385+
RPCEndpoint("eth_getBlockByNumber"), ["latest", True]
386+
)
387+
assert "id" in response, "Expected 'id' key in response."
388+
assert "jsonrpc" in response, "Expected 'jsonrpc' key in response."
389+
assert "result" in response, "Expected 'result' key in response."
390+
assert all(k in response["result"].keys() for k in SOME_BLOCK_KEYS)
391+
assert not isinstance(response["result"]["number"], int) # assert not processed
392+
349393
@pytest.mark.asyncio
350394
async def test_asyncio_gather_for_multiple_requests_matches_the_responses(
351395
self,
352396
async_w3: "AsyncWeb3",
353397
) -> None:
398+
async_w3.provider._request_processor.clear_caches()
399+
354400
(
355401
latest,
356402
chain_id,
@@ -372,22 +418,8 @@ async def test_asyncio_gather_for_multiple_requests_matches_the_responses(
372418
assert isinstance(pending, AttributeDict)
373419

374420
# assert block values
375-
some_block_keys = [
376-
"number",
377-
"hash",
378-
"parentHash",
379-
"transactionsRoot",
380-
"stateRoot",
381-
"receiptsRoot",
382-
"size",
383-
"gasLimit",
384-
"gasUsed",
385-
"timestamp",
386-
"transactions",
387-
"baseFeePerGas",
388-
]
389-
assert all(k in latest.keys() for k in some_block_keys)
390-
assert all(k in pending.keys() for k in some_block_keys)
421+
assert all(k in latest.keys() for k in SOME_BLOCK_KEYS)
422+
assert all(k in pending.keys() for k in SOME_BLOCK_KEYS)
391423

392424
assert isinstance(block_num, int)
393425
assert latest["number"] == block_num

web3/manager.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -469,22 +469,50 @@ def _format_batched_response(
469469

470470
# -- persistent connection -- #
471471

472-
async def send(self, method: RPCEndpoint, params: Any) -> RPCResponse:
472+
async def socket_request(self, method: RPCEndpoint, params: Any) -> RPCResponse:
473473
provider = cast(PersistentConnectionProvider, self._provider)
474474
request_func = await provider.request_func(
475475
cast("AsyncWeb3", self.w3), cast("MiddlewareOnion", self.middleware_onion)
476476
)
477477
self.logger.debug(
478-
"Making request to open socket connection: "
478+
"Making request to open socket connection and waiting for response: "
479479
f"{provider.get_endpoint_uri_or_ipc_path()}, method: {method}"
480480
)
481481
response = await request_func(method, params)
482482
return await self._process_response(response)
483483

484+
async def send(self, method: RPCEndpoint, params: Any) -> None:
485+
provider = cast(PersistentConnectionProvider, self._provider)
486+
# run through the request processors of the middleware
487+
for mw_class in self.middleware_onion.as_tuple_of_middleware():
488+
mw = mw_class(self.w3)
489+
method, params = mw.request_processor(method, params)
490+
491+
self.logger.debug(
492+
"Sending request to open socket connection: "
493+
f"{provider.get_endpoint_uri_or_ipc_path()}, method: {method}"
494+
)
495+
await provider.socket_send(provider.encode_rpc_request(method, params))
496+
497+
async def recv(self) -> RPCResponse:
498+
provider = cast(PersistentConnectionProvider, self._provider)
499+
self.logger.debug(
500+
"Getting next response from open socket connection: "
501+
f"{provider.get_endpoint_uri_or_ipc_path()}"
502+
)
503+
# pop from the queue since the listener task is responsible for reading
504+
# directly from the socket
505+
request_response_cache = self._request_processor._request_response_cache
506+
_key, response = await request_response_cache.async_await_and_popitem(
507+
last=False,
508+
timeout=provider.request_timeout,
509+
)
510+
return await self._process_response(response)
511+
484512
def _persistent_message_stream(self) -> "_AsyncPersistentMessageStream":
485513
return _AsyncPersistentMessageStream(self)
486514

487-
async def _get_next_message(self) -> Optional[RPCResponse]:
515+
async def _get_next_message(self) -> RPCResponse:
488516
return await self._message_stream().__anext__()
489517

490518
async def _message_stream(self) -> AsyncGenerator[RPCResponse, None]:

web3/module.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ async def caller(*args: Any, **kwargs: Any) -> Union[RPCResponse, AsyncLogFilter
138138

139139
try:
140140
method_str = cast(RPCEndpoint, method_str)
141-
return await async_w3.manager.send(method_str, params)
141+
return await async_w3.manager.socket_request(method_str, params)
142142
except Exception as e:
143143
if (
144144
cache_key is not None

web3/providers/persistent/async_ipc.py

Lines changed: 34 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -11,32 +11,22 @@
1111
import sys
1212
from typing import (
1313
Any,
14-
List,
1514
Optional,
1615
Tuple,
1716
Union,
18-
cast,
1917
)
2018

2119
from eth_utils import (
2220
to_text,
2321
)
2422

2523
from web3.types import (
26-
RPCEndpoint,
2724
RPCResponse,
2825
)
2926

3027
from . import (
3128
PersistentConnectionProvider,
3229
)
33-
from ..._utils.batching import (
34-
BATCH_REQUEST_ID,
35-
sort_batch_response_by_response_ids,
36-
)
37-
from ..._utils.caching import (
38-
async_handle_request_caching,
39-
)
4030
from ...exceptions import (
4131
ProviderConnectionError,
4232
Web3TypeError,
@@ -91,12 +81,7 @@ async def is_connected(self, show_traceback: bool = False) -> bool:
9181
return False
9282

9383
try:
94-
request_data = self.encode_rpc_request(
95-
RPCEndpoint("web3_clientVersions"), []
96-
)
97-
self._writer.write(request_data)
98-
current_request_id = json.loads(request_data)["id"]
99-
await self._get_response_for_request_id(current_request_id, timeout=2)
84+
await self.make_request("web3_clientVersion", [])
10085
return True
10186
except (OSError, ProviderConnectionError) as e:
10287
if show_traceback:
@@ -105,55 +90,33 @@ async def is_connected(self, show_traceback: bool = False) -> bool:
10590
)
10691
return False
10792

108-
async def _provider_specific_connect(self) -> None:
109-
self._reader, self._writer = await async_get_ipc_socket(self.ipc_path)
110-
111-
async def _provider_specific_disconnect(self) -> None:
112-
if self._writer and not self._writer.is_closing():
113-
self._writer.close()
114-
await self._writer.wait_closed()
115-
self._writer = None
116-
if self._reader:
117-
self._reader = None
118-
119-
async def _reset_socket(self) -> None:
120-
self._writer.close()
121-
await self._writer.wait_closed()
122-
self._reader, self._writer = await async_get_ipc_socket(self.ipc_path)
123-
124-
@async_handle_request_caching
125-
async def make_request(self, method: RPCEndpoint, params: Any) -> RPCResponse:
93+
async def socket_send(self, request_data: bytes) -> None:
12694
if self._writer is None:
12795
raise ProviderConnectionError(
12896
"Connection to ipc socket has not been initiated for the provider."
12997
)
13098

131-
request_data = self.encode_rpc_request(method, params)
132-
try:
133-
self._writer.write(request_data)
134-
await self._writer.drain()
135-
except OSError as e:
136-
# Broken pipe
137-
if e.errno == errno.EPIPE:
138-
# one extra attempt, then give up
139-
await self._reset_socket()
140-
self._writer.write(request_data)
141-
await self._writer.drain()
99+
return await asyncio.wait_for(
100+
self._socket_send(request_data), timeout=self.request_timeout
101+
)
142102

143-
current_request_id = json.loads(request_data)["id"]
144-
response = await self._get_response_for_request_id(current_request_id)
103+
async def socket_recv(self) -> RPCResponse:
104+
while True:
105+
# yield to the event loop to allow other tasks to run
106+
await asyncio.sleep(0)
145107

146-
return response
108+
try:
109+
response, pos = self._decoder.raw_decode(self._raw_message)
110+
self._raw_message = self._raw_message[pos:].lstrip()
111+
return response
112+
except JSONDecodeError:
113+
# read more data from the socket if the current raw message is
114+
# incomplete
115+
self._raw_message += to_text(await self._reader.read(4096)).lstrip()
147116

148-
async def make_batch_request(
149-
self, requests: List[Tuple[RPCEndpoint, Any]]
150-
) -> List[RPCResponse]:
151-
if self._writer is None:
152-
raise ProviderConnectionError(
153-
"Connection to ipc socket has not been initiated for the provider."
154-
)
117+
# -- private methods -- #
155118

156-
request_data = self.encode_batch_rpc_request(requests)
119+
async def _socket_send(self, request_data: bytes) -> None:
157120
try:
158121
self._writer.write(request_data)
159122
await self._writer.drain()
@@ -165,32 +128,24 @@ async def make_batch_request(
165128
self._writer.write(request_data)
166129
await self._writer.drain()
167130

168-
response = cast(
169-
List[RPCResponse], await self._get_response_for_request_id(BATCH_REQUEST_ID)
170-
)
171-
return response
172-
173-
async def _provider_specific_message_listener(self) -> None:
174-
self._raw_message += to_text(await self._reader.read(4096)).lstrip()
131+
async def _reset_socket(self) -> None:
132+
self._writer.close()
133+
await self._writer.wait_closed()
134+
self._reader, self._writer = await async_get_ipc_socket(self.ipc_path)
175135

176-
while self._raw_message:
177-
try:
178-
response, pos = self._decoder.raw_decode(self._raw_message)
179-
except JSONDecodeError:
180-
break
136+
async def _provider_specific_connect(self) -> None:
137+
self._reader, self._writer = await async_get_ipc_socket(self.ipc_path)
181138

182-
if isinstance(response, list):
183-
response = sort_batch_response_by_response_ids(response)
139+
async def _provider_specific_disconnect(self) -> None:
140+
if self._writer and not self._writer.is_closing():
141+
self._writer.close()
142+
await self._writer.wait_closed()
143+
self._writer = None
144+
if self._reader:
145+
self._reader = None
184146

185-
is_subscription = (
186-
response.get("method") == "eth_subscription"
187-
if not isinstance(response, list)
188-
else False
189-
)
190-
await self._request_processor.cache_raw_response(
191-
response, subscription=is_subscription
192-
)
193-
self._raw_message = self._raw_message[pos:].lstrip()
147+
async def _provider_specific_socket_reader(self) -> RPCResponse:
148+
return await self.socket_recv()
194149

195150
def _error_log_listener_task_exception(self, e: Exception) -> None:
196151
super()._error_log_listener_task_exception(e)

0 commit comments

Comments
 (0)