Skip to content

Commit e9ecd9d

Browse files
committed
Gracefully exit message iterator if ConnectionClosedOK is raised
- Similar changes as #3424 that were implemented for ``v6``. - Raise ``PersistentConnectionClosedOK`` if a ``ConnectionClosedOK`` is raised by the server for a websocket provider. Handle this appropriately in a pattern than can be applicable to other persistent connections by creating a new internal exception that is then handled, rather than handling the specific websocket exception.
1 parent 39cde3b commit e9ecd9d

File tree

8 files changed

+55
-25
lines changed

8 files changed

+55
-25
lines changed

newsfragments/3432.bugfix.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Handle ``ConnectionClosedOK`` case for ``WebSocketProvider``. If a persistent connection is closed gracefully, log and raise a silent ``PersistentConnectionClosedOK`` exception, triggering an end to the message listener task and breaking out of the ``process_subscriptions()`` iterator.

tests/core/providers/test_websocket_provider.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
)
1313
from websockets import (
1414
ConnectionClosed,
15+
ConnectionClosedOK,
1516
)
1617

1718
from web3 import (
@@ -365,3 +366,16 @@ async def test_async_iterator_pattern_exception_handling_for_subscriptions():
365366
pytest.fail("Expected `ConnectionClosed` exception.")
366367

367368
assert iterations == 3
369+
370+
371+
@pytest.mark.asyncio
372+
async def test_connection_closed_ok_breaks_message_iteration():
373+
with patch(
374+
"web3.providers.persistent.websocket.connect",
375+
new=lambda *_1, **_2: WebSocketMessageStreamMock(
376+
raise_exception=ConnectionClosedOK(None, None)
377+
),
378+
):
379+
w3 = await AsyncWeb3(WebSocketProvider("ws://mocked"))
380+
async for _ in w3.socket.process_subscriptions():
381+
pytest.fail("Should not reach this point.")

web3/_utils/module_testing/module_testing_utils.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
1-
from collections import (
2-
deque,
3-
)
1+
import asyncio
42
from typing import (
53
TYPE_CHECKING,
64
Any,
@@ -179,7 +177,9 @@ class WebSocketMessageStreamMock:
179177
def __init__(
180178
self, messages: Collection[bytes] = None, raise_exception: Exception = None
181179
) -> None:
182-
self.messages = deque(messages) if messages else deque()
180+
self.queue = asyncio.Queue() # type: ignore # py38 issue
181+
for msg in messages or []:
182+
self.queue.put_nowait(msg)
183183
self.raise_exception = raise_exception
184184

185185
def __await__(self) -> Generator[Any, Any, "Self"]:
@@ -192,13 +192,12 @@ def __aiter__(self) -> "Self":
192192
return self
193193

194194
async def __anext__(self) -> bytes:
195+
return await self.queue.get()
196+
197+
async def recv(self) -> bytes:
195198
if self.raise_exception:
196199
raise self.raise_exception
197-
198-
elif len(self.messages) == 0:
199-
raise StopAsyncIteration
200-
201-
return self.messages.popleft()
200+
return await self.queue.get()
202201

203202
@staticmethod
204203
async def pong() -> Literal[False]:

web3/exceptions.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,12 @@ def __init__(
327327
super().__init__(message)
328328

329329

330+
class PersistentConnectionClosedOK(Web3Exception):
331+
"""
332+
Raised when a persistent connection is closed gracefully by the server.
333+
"""
334+
335+
330336
class Web3RPCError(Web3Exception):
331337
"""
332338
Raised when a JSON-RPC response contains an error field.

web3/manager.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,6 @@
2020
from hexbytes import (
2121
HexBytes,
2222
)
23-
from websockets.exceptions import (
24-
ConnectionClosedOK,
25-
)
2623

2724
from web3._utils.batching import (
2825
RequestBatcher,
@@ -487,7 +484,7 @@ async def send(self, method: RPCEndpoint, params: Any) -> RPCResponse:
487484
def _persistent_message_stream(self) -> "_AsyncPersistentMessageStream":
488485
return _AsyncPersistentMessageStream(self)
489486

490-
async def _get_next_message(self) -> Any:
487+
async def _get_next_message(self) -> Optional[RPCResponse]:
491488
return await self._message_stream().__anext__()
492489

493490
async def _message_stream(self) -> AsyncGenerator[RPCResponse, None]:
@@ -515,12 +512,13 @@ async def _message_stream(self) -> AsyncGenerator[RPCResponse, None]:
515512
# if response is an active subscription response, process it
516513
yield await self._process_response(response)
517514
except TaskNotRunning:
515+
await asyncio.sleep(0)
518516
self._provider._handle_listener_task_exceptions()
519517
self.logger.error(
520518
"Message listener background task has stopped unexpectedly. "
521519
"Stopping message stream."
522520
)
523-
raise StopAsyncIteration
521+
return
524522

525523
async def _process_response(self, response: RPCResponse) -> RPCResponse:
526524
provider = cast(PersistentConnectionProvider, self._provider)
@@ -586,7 +584,4 @@ def __aiter__(self) -> Self:
586584
return self
587585

588586
async def __anext__(self) -> RPCResponse:
589-
try:
590-
return await self.manager._get_next_message()
591-
except ConnectionClosedOK:
592-
raise StopAsyncIteration
587+
return await self.manager._get_next_message()

web3/providers/persistent/persistent.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
generate_cache_key,
2020
)
2121
from web3.exceptions import (
22+
PersistentConnectionClosedOK,
2223
ProviderConnectionError,
2324
TaskNotRunning,
2425
TimeExhausted,
@@ -160,6 +161,13 @@ async def _message_listener(self) -> None:
160161
await asyncio.sleep(0)
161162
try:
162163
await self._provider_specific_message_listener()
164+
except PersistentConnectionClosedOK as e:
165+
self.logger.info(
166+
"Message listener background task has ended gracefully: "
167+
f"{e.user_message}"
168+
)
169+
# trigger a return to end the listener task and initiate the callback fn
170+
return
163171
except Exception as e:
164172
if not self.silence_listener_task_exceptions:
165173
raise e
@@ -202,10 +210,6 @@ async def _match_response_id_to_request_id() -> RPCResponse:
202210
request_cache_key = generate_cache_key(request_id)
203211

204212
while True:
205-
# check if an exception was recorded in the listener task and raise it
206-
# in the main loop if so
207-
self._handle_listener_task_exceptions()
208-
209213
if request_cache_key in self._request_processor._request_response_cache:
210214
self.logger.debug(
211215
f"Popping response for id {request_id} from cache."
@@ -215,6 +219,9 @@ async def _match_response_id_to_request_id() -> RPCResponse:
215219
)
216220
return popped_response
217221
else:
222+
# check if an exception was recorded in the listener task and raise
223+
# it in the main loop if so
224+
self._handle_listener_task_exceptions()
218225
await asyncio.sleep(0)
219226

220227
try:

web3/providers/persistent/persistent_connection.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
TYPE_CHECKING,
33
Any,
44
Dict,
5+
Optional,
56
)
67

78
from web3.types import (
@@ -35,7 +36,7 @@ def subscriptions(self) -> Dict[str, Any]:
3536
async def send(self, method: RPCEndpoint, params: Any) -> RPCResponse:
3637
return await self._manager.send(method, params)
3738

38-
async def recv(self) -> Any:
39+
async def recv(self) -> Optional[RPCResponse]:
3940
return await self._manager._get_next_message()
4041

4142
def process_subscriptions(self) -> "_AsyncPersistentMessageStream":

web3/providers/persistent/websocket.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
connect,
2626
)
2727
from websockets.exceptions import (
28+
ConnectionClosedOK,
2829
WebSocketException,
2930
)
3031

@@ -36,6 +37,7 @@
3637
async_handle_request_caching,
3738
)
3839
from web3.exceptions import (
40+
PersistentConnectionClosedOK,
3941
ProviderConnectionError,
4042
Web3ValidationError,
4143
)
@@ -169,8 +171,13 @@ async def make_batch_request(
169171
return response
170172

171173
async def _provider_specific_message_listener(self) -> None:
172-
async for raw_message in self._ws:
173-
await asyncio.sleep(0)
174+
while True:
175+
try:
176+
raw_message = await self._ws.recv()
177+
except ConnectionClosedOK:
178+
raise PersistentConnectionClosedOK(
179+
user_message="WebSocket connection received `ConnectionClosedOK`."
180+
)
174181

175182
response = json.loads(raw_message)
176183
if isinstance(response, list):

0 commit comments

Comments
 (0)