Skip to content

Commit eebcd84

Browse files
committed
Better organize provider property assignment and superclass calls:
- Explicitly set instance propertiece in the constructor, rather than the class. - Order property assignments in the constructor so that the superclass is initialized first, with explicit mention when it's useful to initialize a particular property on the subclass first.
1 parent 2fd948f commit eebcd84

File tree

10 files changed

+24
-37
lines changed

10 files changed

+24
-37
lines changed

web3/main.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -513,7 +513,8 @@ def ens(self, new_ens: Union[AsyncENS, "Empty"]) -> None:
513513

514514
# -- persistent connection settings -- #
515515

516-
_subscription_manager: SubscriptionManager = None
516+
_subscription_manager: Optional[SubscriptionManager] = None
517+
_persistent_connection: Optional["PersistentConnection"] = None
517518

518519
@property
519520
@persistent_connection_provider_method()
@@ -523,13 +524,14 @@ def subscription_manager(self) -> SubscriptionManager:
523524
"""
524525
if not self._subscription_manager:
525526
self._subscription_manager = SubscriptionManager(self)
526-
527527
return self._subscription_manager
528528

529529
@property
530530
@persistent_connection_provider_method()
531531
def socket(self) -> PersistentConnection:
532-
return PersistentConnection(self)
532+
if self._persistent_connection is None:
533+
self._persistent_connection = PersistentConnection(self)
534+
return self._persistent_connection
533535

534536
# w3 = await AsyncWeb3(PersistentConnectionProvider(...))
535537
@persistent_connection_provider_method(

web3/manager.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -292,9 +292,6 @@ def __init__(
292292
provider = cast(PersistentConnectionProvider, self.provider)
293293
self._request_processor: RequestProcessor = provider._request_processor
294294

295-
w3: Union["AsyncWeb3", "Web3"] = None
296-
_provider = None
297-
298295
@property
299296
def provider(self) -> Union["BaseProvider", "AsyncBaseProvider"]:
300297
return self._provider

web3/providers/async_base.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -83,10 +83,6 @@ class AsyncBaseProvider:
8383
global_ccip_read_enabled: bool = True
8484
ccip_read_max_redirects: int = 4
8585

86-
# request caching
87-
_request_cache: SimpleCache
88-
_request_cache_lock: asyncio.Lock = asyncio.Lock()
89-
9086
def __init__(
9187
self,
9288
cache_allowed_requests: bool = False,
@@ -96,6 +92,8 @@ def __init__(
9692
] = empty,
9793
) -> None:
9894
self._request_cache = SimpleCache(1000)
95+
self._request_cache_lock: asyncio.Lock = asyncio.Lock()
96+
9997
self.cache_allowed_requests = cache_allowed_requests
10098
self.cacheable_requests = cacheable_requests or CACHEABLE_REQUESTS
10199
self.request_cache_validation_threshold = request_cache_validation_threshold
@@ -172,11 +170,9 @@ async def disconnect(self) -> None:
172170

173171

174172
class AsyncJSONBaseProvider(AsyncBaseProvider):
175-
logger = logging.getLogger("web3.providers.async_base.AsyncJSONBaseProvider")
176-
177173
def __init__(self, **kwargs: Any) -> None:
178-
self.request_counter = itertools.count()
179174
super().__init__(**kwargs)
175+
self.request_counter = itertools.count()
180176

181177
def encode_rpc_request(self, method: RPCEndpoint, params: Any) -> bytes:
182178
request_id = next(self.request_counter)

web3/providers/base.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -66,10 +66,6 @@ class BaseProvider:
6666
global_ccip_read_enabled: bool = True
6767
ccip_read_max_redirects: int = 4
6868

69-
# request caching
70-
_request_cache: SimpleCache
71-
_request_cache_lock: threading.Lock = threading.Lock()
72-
7369
def __init__(
7470
self,
7571
cache_allowed_requests: bool = False,
@@ -79,6 +75,8 @@ def __init__(
7975
] = empty,
8076
) -> None:
8177
self._request_cache = SimpleCache(1000)
78+
self._request_cache_lock: threading.Lock = threading.Lock()
79+
8280
self.cache_allowed_requests = cache_allowed_requests
8381
self.cacheable_requests = cacheable_requests or CACHEABLE_REQUESTS
8482
self.request_cache_validation_threshold = request_cache_validation_threshold
@@ -124,8 +122,8 @@ class JSONBaseProvider(BaseProvider):
124122
] = (None, None)
125123

126124
def __init__(self, **kwargs: Any) -> None:
127-
self.request_counter = itertools.count()
128125
super().__init__(**kwargs)
126+
self.request_counter = itertools.count()
129127

130128
def encode_rpc_request(self, method: RPCEndpoint, params: Any) -> bytes:
131129
rpc_dict = {

web3/providers/ipc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ def __init__(
146146
timeout: int = 30,
147147
**kwargs: Any,
148148
) -> None:
149+
super().__init__(**kwargs)
149150
if ipc_path is None:
150151
self.ipc_path = get_default_ipc_path()
151152
elif isinstance(ipc_path, str) or isinstance(ipc_path, Path):
@@ -156,7 +157,6 @@ def __init__(
156157
self.timeout = timeout
157158
self._lock = threading.Lock()
158159
self._socket = PersistantSocket(self.ipc_path)
159-
super().__init__(**kwargs)
160160

161161
def __str__(self) -> str:
162162
return f"<{self.__class__.__name__} {self.ipc_path}>"

web3/providers/legacy_websocket.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ def __init__(
102102
websocket_timeout: int = DEFAULT_WEBSOCKET_TIMEOUT,
103103
**kwargs: Any,
104104
) -> None:
105+
super().__init__(**kwargs)
105106
self.endpoint_uri = URI(endpoint_uri)
106107
self.websocket_timeout = websocket_timeout
107108
if self.endpoint_uri is None:
@@ -120,7 +121,6 @@ def __init__(
120121
f"in websocket_kwargs, found: {found_restricted_keys}"
121122
)
122123
self.conn = PersistentWebSocket(self.endpoint_uri, websocket_kwargs)
123-
super().__init__(**kwargs)
124124

125125
def __str__(self) -> str:
126126
return f"WS connection {self.endpoint_uri}"

web3/providers/persistent/async_ipc.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,15 +59,15 @@ def __init__(
5959
# `PersistentConnectionProvider` kwargs can be passed through
6060
**kwargs: Any,
6161
) -> None:
62+
# initialize the ipc_path before calling the super constructor
6263
if ipc_path is None:
6364
self.ipc_path = get_default_ipc_path()
6465
elif isinstance(ipc_path, str) or isinstance(ipc_path, Path):
6566
self.ipc_path = str(Path(ipc_path).expanduser().resolve())
6667
else:
6768
raise Web3TypeError("ipc_path must be of type string or pathlib.Path")
68-
69-
self.read_buffer_limit = read_buffer_limit
7069
super().__init__(**kwargs)
70+
self.read_buffer_limit = read_buffer_limit
7171

7272
def __str__(self) -> str:
7373
return f"<{self.__class__.__name__} {self.ipc_path}>"

web3/providers/persistent/persistent.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -53,30 +53,26 @@ class PersistentConnectionProvider(AsyncJSONBaseProvider, ABC):
5353
logger = logging.getLogger("web3.providers.PersistentConnectionProvider")
5454
has_persistent_connection = True
5555

56-
_request_processor: RequestProcessor
57-
_message_listener_task: Optional["asyncio.Task[None]"] = None
58-
_listen_event: asyncio.Event = asyncio.Event()
59-
60-
_batch_request_counter: Optional[int] = None
61-
6256
def __init__(
6357
self,
6458
request_timeout: float = DEFAULT_PERSISTENT_CONNECTION_TIMEOUT,
6559
subscription_response_queue_size: int = 500,
6660
silence_listener_task_exceptions: bool = False,
6761
max_connection_retries: int = 5,
68-
label: Optional[str] = None,
6962
**kwargs: Any,
7063
) -> None:
7164
super().__init__(**kwargs)
7265
self._request_processor = RequestProcessor(
7366
self,
7467
subscription_response_queue_size=subscription_response_queue_size,
7568
)
69+
self._message_listener_task: Optional["asyncio.Task[None]"] = None
70+
self._batch_request_counter: Optional[int] = None
71+
self._listen_event: asyncio.Event = asyncio.Event()
72+
self._max_connection_retries = max_connection_retries
73+
7674
self.request_timeout = request_timeout
7775
self.silence_listener_task_exceptions = silence_listener_task_exceptions
78-
self._max_connection_retries = max_connection_retries
79-
self.label = label or self.get_endpoint_uri_or_ipc_path()
8076

8177
def get_endpoint_uri_or_ipc_path(self) -> str:
8278
if hasattr(self, "endpoint_uri"):

web3/providers/persistent/websocket.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,18 +59,19 @@ class WebSocketProvider(PersistentConnectionProvider):
5959
logger = logging.getLogger("web3.providers.WebSocketProvider")
6060
is_async: bool = True
6161

62-
_ws: Optional[WebSocketClientProtocol] = None
63-
6462
def __init__(
6563
self,
6664
endpoint_uri: Optional[Union[URI, str]] = None,
6765
websocket_kwargs: Optional[Dict[str, Any]] = None,
6866
# `PersistentConnectionProvider` kwargs can be passed through
6967
**kwargs: Any,
7068
) -> None:
69+
# initialize the endpoint_uri before calling the super constructor
7170
self.endpoint_uri = (
7271
URI(endpoint_uri) if endpoint_uri is not None else get_default_endpoint()
7372
)
73+
super().__init__(**kwargs)
74+
self._ws: Optional[WebSocketClientProtocol] = None
7475

7576
if not any(
7677
self.endpoint_uri.startswith(prefix)
@@ -93,8 +94,6 @@ def __init__(
9394

9495
self.websocket_kwargs = merge(DEFAULT_WEBSOCKET_KWARGS, websocket_kwargs or {})
9596

96-
super().__init__(**kwargs)
97-
9897
def __str__(self) -> str:
9998
return f"WebSocket connection: {self.endpoint_uri}"
10099

web3/providers/rpc/rpc.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ def __init__(
7171
] = empty,
7272
**kwargs: Any,
7373
) -> None:
74+
super().__init__(**kwargs)
7475
self._request_session_manager = HTTPSessionManager()
7576

7677
if endpoint_uri is None:
@@ -88,8 +89,6 @@ def __init__(
8889
self.endpoint_uri, session
8990
)
9091

91-
super().__init__(**kwargs)
92-
9392
def __str__(self) -> str:
9493
return f"RPC connection {self.endpoint_uri}"
9594

0 commit comments

Comments
 (0)