Skip to content

Commit ee1d904

Browse files
committed
Separate handler subscriptions into their own queue:
- Make use of a ``TaskReliant`` queue that is specific for handler subscriptions. Put a ``SubscriptionProcessingFinished`` exceptions in the queue to signal when all handler subscriptions have been unsubscribed from. - Add testing around running the ``handler_subscriptions`` method while listening for un-handled subscriptions via the socket. - Create a container to manage all subscriptions, since we store subs in more than one way: by ``id``, by ``label``, etc... - The ``SubscriptionContainer`` class has an object reference, so we changes share this with the provider's ``RequestProcessor`` class. The request_processor can use the container to make decisions about which queue to store the subscription in, based on the presence or absence of a ``handler``.
1 parent 8d840b3 commit ee1d904

File tree

9 files changed

+382
-67
lines changed

9 files changed

+382
-67
lines changed

tests/core/providers/test_async_ipc_provider.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,9 +180,11 @@ async def test_disconnect_cleanup(
180180
provider._request_processor._request_response_cache.cache("0", "0x1337")
181181
provider._request_processor._request_information_cache.cache("0", "0x1337")
182182
provider._request_processor._subscription_response_queue.put_nowait({"id": "0"})
183+
provider._request_processor._handler_subscription_queue.put_nowait({"id": "0"})
183184
assert len(provider._request_processor._request_response_cache) == 1
184185
assert len(provider._request_processor._request_information_cache) == 1
185186
assert provider._request_processor._subscription_response_queue.qsize() == 1
187+
assert provider._request_processor._handler_subscription_queue.qsize() == 1
186188

187189
await w3.provider.disconnect()
188190

@@ -192,6 +194,7 @@ async def test_disconnect_cleanup(
192194
assert len(provider._request_processor._request_response_cache) == 0
193195
assert len(provider._request_processor._request_information_cache) == 0
194196
assert provider._request_processor._subscription_response_queue.empty()
197+
assert provider._request_processor._handler_subscription_queue.empty()
195198

196199

197200
async def _raise_connection_closed(*_args, **_kwargs):

tests/core/providers/test_websocket_provider.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@
3434
from web3.types import (
3535
RPCEndpoint,
3636
)
37+
from web3.utils import (
38+
EthSubscription,
39+
)
3740

3841

3942
def _mock_ws(provider):
@@ -80,16 +83,19 @@ async def test_disconnect_cleanup():
8083
provider._request_processor._request_response_cache.cache("0", "0x1337")
8184
provider._request_processor._request_information_cache.cache("0", "0x1337")
8285
provider._request_processor._subscription_response_queue.put_nowait({"id": "0"})
86+
provider._request_processor._handler_subscription_queue.put_nowait({"id": "0"})
8387
assert len(provider._request_processor._request_response_cache) == 1
8488
assert len(provider._request_processor._request_information_cache) == 1
8589
assert provider._request_processor._subscription_response_queue.qsize() == 1
90+
assert provider._request_processor._handler_subscription_queue.qsize() == 1
8691

8792
await provider.disconnect()
8893

8994
assert provider._ws is None
9095
assert len(provider._request_processor._request_response_cache) == 0
9196
assert len(provider._request_processor._request_information_cache) == 0
9297
assert provider._request_processor._subscription_response_queue.empty()
98+
assert provider._request_processor._handler_subscription_queue.empty()
9399

94100

95101
@pytest.mark.asyncio
@@ -278,6 +284,9 @@ async def test_listen_event_awaits_msg_processing_when_subscription_queue_is_ful
278284
generate_cache_key(sub_id),
279285
sub_request_information,
280286
)
287+
sub = EthSubscription()
288+
sub._id = sub_id
289+
async_w3.subscription_manager._add_subscription(sub)
281290

282291
mocked_sub = {
283292
"jsonrpc": "2.0",
@@ -381,3 +390,45 @@ async def test_connection_closed_ok_breaks_message_iteration():
381390
w3 = await AsyncWeb3(WebSocketProvider("ws://mocked"))
382391
async for _ in w3.socket.process_subscriptions():
383392
pytest.fail("Should not reach this point.")
393+
394+
395+
@pytest.mark.asyncio
396+
async def test_listener_task_breaks_out_of_stream_when_cancelled():
397+
with patch(
398+
"web3.providers.persistent.websocket.connect",
399+
new=lambda *_1, **_2: _mocked_ws_conn(),
400+
):
401+
async_w3 = await AsyncWeb3(WebSocketProvider("ws://mocked"))
402+
403+
async_w3.provider._message_listener_task.cancel()
404+
sub = EthSubscription()
405+
sub._id = "0x1"
406+
async_w3.subscription_manager._add_subscription(sub)
407+
# this should hang indefinitely if the listener task does not put a
408+
# ``TaskNotRunning`` in the ``_subscription_response_queue`` to break out of
409+
# listening. The call to ``provider._handle_listener_task_exceptions`` bubbles up
410+
# the exception.
411+
with pytest.raises(asyncio.CancelledError):
412+
async for _ in async_w3.socket.process_subscriptions():
413+
...
414+
415+
416+
@pytest.mark.asyncio
417+
async def test_listener_task_breaks_out_of_handle_subscriptions_when_cancelled():
418+
with patch(
419+
"web3.providers.persistent.websocket.connect",
420+
new=lambda *_1, **_2: _mocked_ws_conn(),
421+
):
422+
async_w3 = await AsyncWeb3(WebSocketProvider("ws://mocked"))
423+
424+
async_w3.provider._message_listener_task.cancel()
425+
sub = EthSubscription(handler=AsyncMock())
426+
sub._id = "0x1"
427+
428+
async_w3.subscription_manager._add_subscription(sub)
429+
# this should hang indefinitely if the listener task does not put a
430+
# ``TaskNotRunning`` in the ``_handler_subscription_queue`` to break out of
431+
# listening. The call to ``provider._handle_listener_task_exceptions`` bubbles
432+
# up the exception.
433+
with pytest.raises(asyncio.CancelledError):
434+
await async_w3.subscription_manager.handle_subscriptions(run_forever=True)

tests/core/subscriptions/test_subscription_manager.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,11 @@ async def test_subscription_manager_raises_for_sub_with_the_same_label(
5656

5757
# make sure the subscription was subscribed to and not added to the manager
5858
assert subscription_manager.subscriptions == [sub1]
59-
assert subscription_manager._subscriptions_by_label == {"foo": sub1}
60-
assert subscription_manager._subscriptions_by_id == {"0x0": sub1}
59+
sub_container = subscription_manager._subscription_container
60+
assert len(sub_container) == 1
61+
assert sub_container.subscriptions == [sub1]
62+
assert sub_container.subscriptions_by_id == {"0x0": sub1}
63+
assert sub_container.subscriptions_by_label == {"foo": sub1}
6164

6265

6366
@pytest.mark.asyncio
@@ -97,8 +100,13 @@ async def test_unsubscribe_all_clears_all_subscriptions(subscription_manager):
97100
sub1 = NewHeadsSubscription(label="foo")
98101
sub2 = PendingTxSubscription(label="bar")
99102
await subscription_manager.subscribe([sub1, sub2])
103+
assert subscription_manager.subscriptions == [sub1, sub2]
100104

101105
await subscription_manager.unsubscribe_all()
102106
assert subscription_manager.subscriptions == []
103-
assert subscription_manager._subscriptions_by_id == {}
104-
assert subscription_manager._subscriptions_by_label == {}
107+
108+
sub_container = subscription_manager._subscription_container
109+
assert len(sub_container) == 0
110+
assert sub_container.subscriptions == []
111+
assert sub_container.subscriptions_by_id == {}
112+
assert sub_container.subscriptions_by_label == {}

web3/_utils/module_testing/persistent_connection_provider.py

Lines changed: 108 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@
77
TYPE_CHECKING,
88
Any,
99
Dict,
10+
List,
1011
Tuple,
12+
Union,
1113
cast,
1214
)
1315

@@ -61,6 +63,9 @@
6163
AsyncContract,
6264
AsyncContractFunction,
6365
)
66+
from web3.providers.persistent.subscription_manager import (
67+
SubscriptionContainer,
68+
)
6469

6570

6671
# LogIndexedAndNotIndexed event args
@@ -158,6 +163,12 @@ async def logs_handler(
158163
assert await sub.unsubscribe()
159164

160165

166+
async def idle_handler(
167+
_handler_context: Any,
168+
) -> None:
169+
pass
170+
171+
161172
async def emit_contract_event(
162173
async_w3: AsyncWeb3,
163174
acct: ChecksumAddress,
@@ -194,6 +205,14 @@ async def log_indexed_and_non_indexed_args_task(
194205
)
195206

196207

208+
def assert_no_subscriptions_left(sub_container: "SubscriptionContainer") -> None:
209+
assert len(sub_container) == 0
210+
assert len(sub_container.subscriptions) == 0
211+
assert len(sub_container.subscriptions_by_id) == 0
212+
assert len(sub_container.subscriptions_by_label) == 0
213+
assert len(sub_container.handler_subscriptions) == 0
214+
215+
197216
class PersistentConnectionProviderTest:
198217
@staticmethod
199218
async def seed_transactions_to_geth(
@@ -295,6 +314,13 @@ async def test_async_eth_subscribe_syncing_mocked(
295314
await async_w3.eth.unsubscribe(sub_id)
296315
break
297316

317+
assert_no_subscriptions_left(
318+
async_w3.subscription_manager._subscription_container
319+
)
320+
321+
# cleanup
322+
async_w3.provider._request_processor.clear_caches()
323+
298324
@pytest.mark.asyncio
299325
async def test_async_eth_subscribe_new_heads(self, async_w3: AsyncWeb3) -> None:
300326
sub_id = await async_w3.eth.subscribe("newHeads")
@@ -308,14 +334,16 @@ async def test_async_eth_subscribe_new_heads(self, async_w3: AsyncWeb3) -> None:
308334
break
309335

310336
assert await async_w3.eth.unsubscribe(sub_id)
337+
assert_no_subscriptions_left(
338+
async_w3.subscription_manager._subscription_container
339+
)
311340

312341
@pytest.mark.asyncio
313342
async def test_async_eth_subscribe_creates_and_handles_new_heads_subscription_type(
314343
self,
315344
async_w3: AsyncWeb3,
316345
) -> None:
317346
sub_manager = async_w3.subscription_manager
318-
319347
new_heads_handler_test = SubscriptionHandlerTest()
320348

321349
sub_id = await async_w3.eth.subscribe(
@@ -376,9 +404,9 @@ async def test_async_eth_subscribe_new_and_process_pending_tx_true(
376404

377405
# cleanup
378406
assert await async_w3.eth.unsubscribe(sub_id)
379-
assert len(async_w3.subscription_manager.subscriptions) == 0
380-
assert len(async_w3.subscription_manager._subscriptions_by_id) == 0
381-
assert len(async_w3.subscription_manager._subscriptions_by_label) == 0
407+
assert_no_subscriptions_left(
408+
async_w3.subscription_manager._subscription_container
409+
)
382410
async_w3.provider._request_processor.clear_caches()
383411
await async_w3.eth.wait_for_transaction_receipt(tx_hash)
384412
tx_seeder_task.cancel()
@@ -414,9 +442,9 @@ async def test_async_eth_subscribe_and_process_pending_tx_false(
414442

415443
# cleanup
416444
await async_w3.eth.unsubscribe(sub_id)
417-
assert len(async_w3.subscription_manager.subscriptions) == 0
418-
assert len(async_w3.subscription_manager._subscriptions_by_id) == 0
419-
assert len(async_w3.subscription_manager._subscriptions_by_label) == 0
445+
assert_no_subscriptions_left(
446+
async_w3.subscription_manager._subscription_container
447+
)
420448
await async_w3.eth.wait_for_transaction_receipt(tx_hash)
421449
tx_seeder_task.cancel()
422450

@@ -762,3 +790,76 @@ async def test_sub_handler(
762790

763791
# cleanup
764792
sub_manager.total_handler_calls = 0
793+
794+
@pytest.mark.asyncio
795+
async def test_subscriptions_with_handler_and_without(
796+
self, async_w3: AsyncWeb3
797+
) -> None:
798+
handler_test = SubscriptionHandlerTest()
799+
stream_passed = False
800+
801+
async def test_sub_handler(
802+
handler_context: NewHeadsSubscriptionContext,
803+
) -> None:
804+
handler_context.handler_test.passed = True
805+
await handler_context.subscription.unsubscribe()
806+
807+
async def handle_subscription_stream() -> None:
808+
nonlocal stream_passed
809+
async for msg in async_w3.socket.process_subscriptions():
810+
response = cast(FormattedEthSubscriptionResponse, msg)
811+
assert sub_manager.get_by_id(response["subscription"]) is not None
812+
assert response["result"] is not None
813+
# wait for the handler to unsubscribe:
814+
stream_passed = True
815+
await async_w3.eth.unsubscribe(response["subscription"])
816+
break
817+
818+
await async_w3.eth.subscribe(
819+
"newHeads",
820+
handler=test_sub_handler,
821+
label="managed",
822+
handler_context={"handler_test": handler_test},
823+
)
824+
await async_w3.eth.subscribe("newHeads", label="streamed")
825+
826+
sub_manager = async_w3.subscription_manager
827+
assert len(sub_manager.subscriptions) == 2
828+
829+
await asyncio.gather(
830+
sub_manager.handle_subscriptions(),
831+
handle_subscription_stream(),
832+
)
833+
834+
assert len(sub_manager.subscriptions) == 0
835+
assert sub_manager.total_handler_calls == 1
836+
assert handler_test.passed
837+
assert stream_passed
838+
839+
# cleanup
840+
sub_manager.total_handler_calls = 0
841+
842+
@pytest.mark.asyncio
843+
async def test_handle_subscriptions_breaks_on_unsubscribe(
844+
self,
845+
async_w3: AsyncWeb3,
846+
) -> None:
847+
async def unsubscribe_subs(
848+
subs: List[Union[NewHeadsSubscription, LogsSubscription]]
849+
) -> None:
850+
for sub in subs:
851+
await sub.unsubscribe()
852+
853+
sub_manager = async_w3.subscription_manager
854+
sub1 = NewHeadsSubscription(label="foo", handler=idle_handler)
855+
sub2 = LogsSubscription(label="bar", handler=idle_handler)
856+
await sub_manager.subscribe([sub1, sub2])
857+
assert sub_manager.subscriptions == [sub1, sub2]
858+
859+
asyncio.create_task(unsubscribe_subs([sub1, sub2]))
860+
# With no subscriptions in the queue, ``handle_subscriptions`` should hang
861+
# indefinitely. Test that when the last subscription is unsubscribed from,
862+
# the method breaks out of the loop. This is done via a raised
863+
# ``SubscriptionProcessingFinished`` within the ``TaskReliantQueue``.
864+
await sub_manager.handle_subscriptions()
865+
assert_no_subscriptions_left(sub_manager._subscription_container)

web3/exceptions.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,13 @@ class PersistentConnectionClosedOK(PersistentConnectionError):
353353
"""
354354

355355

356+
class SubscriptionProcessingFinished(Web3Exception):
357+
"""
358+
Raised to alert the subscription manager that the processing of subscriptions
359+
has finished.
360+
"""
361+
362+
356363
class Web3RPCError(Web3Exception):
357364
"""
358365
Raised when a JSON-RPC response contains an error field.

0 commit comments

Comments
 (0)