Skip to content

Commit 38e11f6

Browse files
committed
feat: Add a subscription manager to manage subscriptions
- Create a SubscriptionManager class to manage subscriptions. - Implement subscribing to one or many subscriptions. - Implement unsubscribing from one or many subscriptions.
1 parent 8a90577 commit 38e11f6

File tree

7 files changed

+321
-16
lines changed

7 files changed

+321
-16
lines changed

web3/eth/async_eth.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -99,8 +99,12 @@
9999
_Hash32,
100100
)
101101
from web3.utils import (
102+
EthSubscription,
102103
async_handle_offchain_lookup,
103104
)
105+
from web3.utils.subscriptions import (
106+
EthSubscriptionHandler,
107+
)
104108

105109
if TYPE_CHECKING:
106110
from web3 import AsyncWeb3 # noqa: F401
@@ -714,17 +718,22 @@ async def subscribe(
714718
bool, # newPendingTransactions, full_transactions
715719
]
716720
] = None,
721+
handler: Optional[EthSubscriptionHandler] = None,
722+
label: Optional[str] = None,
717723
) -> HexStr:
718724
if not isinstance(self.w3.provider, PersistentConnectionProvider):
719725
raise MethodNotSupported(
720726
"eth_subscribe is only supported with providers that support "
721727
"persistent connections."
722728
)
723729

724-
if subscription_arg is None:
725-
return await self._subscribe(subscription_type)
726-
727-
return await self._subscribe_with_args(subscription_type, subscription_arg)
730+
params = (
731+
(subscription_type, subscription_arg)
732+
if subscription_arg is not None
733+
else (subscription_type,)
734+
)
735+
sx = EthSubscription(subscription_params=params, handler=handler, label=label)
736+
return await self.w3.subscription_manager.subscribe(sx)
728737

729738
_unsubscribe: Method[Callable[[HexStr], Awaitable[bool]]] = Method(
730739
RPC.eth_unsubscribe,
@@ -738,7 +747,14 @@ async def unsubscribe(self, subscription_id: HexStr) -> bool:
738747
"persistent connections."
739748
)
740749

741-
return await self._unsubscribe(subscription_id)
750+
for sx in self.w3.subscription_manager.subscriptions:
751+
if sx._id == subscription_id:
752+
return await sx.unsubscribe()
753+
754+
raise Web3ValueError(
755+
f"Cannot unsubscribe subscription with id `{subscription_id}`. "
756+
"Subscription not found."
757+
)
742758

743759
# -- contract methods -- #
744760

web3/main.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,9 @@
148148
from web3.types import (
149149
Wei,
150150
)
151+
from web3.providers.persistent.subscription_manager import (
152+
SubscriptionManager,
153+
)
151154

152155
if TYPE_CHECKING:
153156
from web3._utils.batching import RequestBatcher # noqa: F401
@@ -508,7 +511,20 @@ def ens(self, new_ens: Union[AsyncENS, "Empty"]) -> None:
508511
new_ens.w3 = self # set self object reference for ``AsyncENS.w3``
509512
self._ens = new_ens
510513

511-
# -- persistent connection methods -- #
514+
# -- persistent connection settings -- #
515+
516+
_subscription_manager: SubscriptionManager = None
517+
518+
@property
519+
@persistent_connection_provider_method()
520+
def subscription_manager(self) -> SubscriptionManager:
521+
"""
522+
Access the subscription manager for the current PersistentConnectionProvider.
523+
"""
524+
if not self._subscription_manager:
525+
self._subscription_manager = SubscriptionManager(self)
526+
527+
return self._subscription_manager
512528

513529
@property
514530
@persistent_connection_provider_method()

web3/manager.py

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -543,15 +543,18 @@ async def recv(self) -> RPCResponse:
543543
def _persistent_message_stream(self) -> "_AsyncPersistentMessageStream":
544544
return _AsyncPersistentMessageStream(self)
545545

546-
async def _get_next_message(self) -> RPCResponse:
546+
async def _get_next_message(self) -> Optional[RPCResponse]:
547547
return await self._message_stream().__anext__()
548548

549-
async def _message_stream(self) -> AsyncGenerator[RPCResponse, None]:
549+
async def _message_stream(
550+
self,
551+
) -> AsyncGenerator[Optional[RPCResponse], None]:
550552
if not isinstance(self._provider, PersistentConnectionProvider):
551553
raise Web3TypeError(
552554
"Only providers that maintain an open, persistent connection "
553555
"can listen to streams."
554556
)
557+
async_w3 = cast("AsyncWeb3", self.w3)
555558

556559
if self._provider._message_listener_task is None:
557560
raise ProviderConnectionError(
@@ -563,13 +566,25 @@ async def _message_stream(self) -> AsyncGenerator[RPCResponse, None]:
563566
response = await self._request_processor.pop_raw_response(
564567
subscription=True
565568
)
566-
if (
567-
response is not None
568-
and response.get("params", {}).get("subscription")
569-
in self._request_processor.active_subscriptions
570-
):
571-
# if response is an active subscription response, process it
572-
yield await self._process_response(response)
569+
if response is not None:
570+
sx_response = await self._process_response(response)
571+
572+
# call the handler is there is one, else yield the response to any
573+
# listeners
574+
for sx in async_w3.subscription_manager.subscriptions:
575+
if (
576+
sx_response.get("subscription") == sx.id
577+
and sx._handler is not None
578+
):
579+
await sx._handler(async_w3, sx, sx_response["result"])
580+
yield sx_response
581+
else:
582+
# if response is not an active subscription response, log it
583+
self.logger.debug(
584+
"Received inactive subscription from socket:\n"
585+
f" {self._provider.get_endpoint_uri_or_ipc_path()}, "
586+
f" response: {response}"
587+
)
573588
except TaskNotRunning:
574589
await asyncio.sleep(0)
575590
self._provider._handle_listener_task_exceptions()
@@ -642,5 +657,5 @@ def __init__(self, manager: RequestManager, *args: Any, **kwargs: Any) -> None:
642657
def __aiter__(self) -> Self:
643658
return self
644659

645-
async def __anext__(self) -> RPCResponse:
660+
async def __anext__(self) -> Optional[RPCResponse]:
646661
return await self.manager._get_next_message()

web3/providers/persistent/persistent.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ def __init__(
6565
subscription_response_queue_size: int = 500,
6666
silence_listener_task_exceptions: bool = False,
6767
max_connection_retries: int = 5,
68+
label: Optional[str] = None,
6869
**kwargs: Any,
6970
) -> None:
7071
super().__init__(**kwargs)
@@ -75,6 +76,7 @@ def __init__(
7576
self.request_timeout = request_timeout
7677
self.silence_listener_task_exceptions = silence_listener_task_exceptions
7778
self._max_connection_retries = max_connection_retries
79+
self.label = label or self.get_endpoint_uri_or_ipc_path()
7880

7981
def get_endpoint_uri_or_ipc_path(self) -> str:
8082
if hasattr(self, "endpoint_uri"):
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
from typing import (
2+
TYPE_CHECKING,
3+
List,
4+
Sequence,
5+
Union,
6+
cast,
7+
overload,
8+
)
9+
10+
from eth_typing import (
11+
HexStr,
12+
)
13+
from typing_extensions import (
14+
TypeVar,
15+
)
16+
17+
from web3.exceptions import (
18+
Web3TypeError,
19+
Web3ValueError,
20+
)
21+
from web3.utils import (
22+
EthSubscription,
23+
)
24+
25+
if TYPE_CHECKING:
26+
from web3 import AsyncWeb3 # noqa: F401
27+
from web3.providers.persistent import PersistentConnectionProvider # noqa: F401
28+
29+
30+
T = TypeVar("T", bound="EthSubscription")
31+
32+
33+
class SubscriptionManager:
34+
def __init__(self, w3: "AsyncWeb3") -> None:
35+
self._w3 = w3
36+
self._provider = cast("PersistentConnectionProvider", w3.provider)
37+
self.subscriptions: List[EthSubscription] = []
38+
39+
@overload
40+
async def subscribe(self, subscriptions: T) -> HexStr:
41+
...
42+
43+
@overload
44+
async def subscribe(self, subscriptions: Sequence[T]) -> List[HexStr]:
45+
...
46+
47+
async def subscribe(
48+
self, subscriptions: Union[T, Sequence[T]]
49+
) -> Union[HexStr, List[HexStr]]:
50+
if isinstance(subscriptions, EthSubscription):
51+
subscriptions._manager = self
52+
sx_id = await self._w3.eth._subscribe(*subscriptions.subscription_params)
53+
subscriptions._id = sx_id
54+
self.subscriptions.append(subscriptions)
55+
self._provider.logger.info(
56+
f"Successfully subscribed to subscription:\n "
57+
f"label: {subscriptions.label}\n id: {sx_id}"
58+
)
59+
return sx_id
60+
elif isinstance(subscriptions, Sequence):
61+
if len(subscriptions) == 0:
62+
raise Web3ValueError("No subscriptions provided.")
63+
64+
sx_ids = []
65+
for sx in subscriptions:
66+
await self.subscribe(sx)
67+
sx_ids.append(sx._id)
68+
69+
return sx_ids
70+
else:
71+
raise Web3TypeError(
72+
"Expected a Subscription or a sequence of Subscriptions."
73+
)
74+
75+
async def unsubscribe(self, subscription: EthSubscription) -> bool:
76+
if subscription not in self.subscriptions:
77+
raise Web3ValueError(
78+
f"Subscription not found or is not being managed by the subscription "
79+
f"manager.\n label: {subscription.label}\n id: {subscription._id}"
80+
)
81+
82+
if await self._w3.eth._unsubscribe(subscription.id):
83+
self.subscriptions.remove(subscription)
84+
return True
85+
return False
86+
87+
async def unsubscribe_all(self) -> None:
88+
for sx in self.subscriptions:
89+
await self.unsubscribe(sx)
90+
91+
self._provider.logger.info("Successfully unsubscribed from all subscriptions.")
92+
93+
async def _handle_subscriptions(self, run_forever: bool = False) -> None:
94+
self._provider.logger.info("Subscription manager processing started.")
95+
while True:
96+
if not run_forever and len(self.subscriptions) == 0:
97+
break
98+
await self._w3.manager._get_next_message()
99+
100+
self._provider.logger.info("Subscription manager processing ended.")
101+
102+
async def handle_subscriptions(self, run_forever: bool = False) -> None:
103+
await self._handle_subscriptions(run_forever=run_forever)

web3/utils/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,9 @@
4242
from .exception_handling import (
4343
handle_offchain_lookup,
4444
)
45+
from .subscriptions import (
46+
EthSubscription,
47+
)
4548

4649
__all__ = [
4750
"abi_to_signature",
@@ -70,5 +73,6 @@
7073
"async_handle_offchain_lookup",
7174
"RequestCacheValidationThreshold",
7275
"SimpleCache",
76+
"EthSubscription",
7377
"handle_offchain_lookup",
7478
]

0 commit comments

Comments
 (0)