Skip to content

Commit e8eb20e

Browse files
committed
Validate provider is correct type for the w3 class on init:
- Disallow sync providers on `AsyncWeb3` init; disallow async providers on `Web3` init. This will provide better messaging to users when they try to use the wrong provider type.
1 parent 4833e25 commit e8eb20e

File tree

4 files changed

+86
-2
lines changed

4 files changed

+86
-2
lines changed

newsfragments/3490.misc.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
2+
Validate against sync providers on ``AsyncWeb3`` init; validate against async providers on ``Web3`` init. This will provide better messaging when a web3 class is instantiated with an incompatible provider type.
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
import pytest
2+
3+
from web3 import (
4+
AsyncBaseProvider,
5+
AsyncEthereumTesterProvider,
6+
AsyncHTTPProvider,
7+
AsyncIPCProvider,
8+
AsyncWeb3,
9+
BaseProvider,
10+
EthereumTesterProvider,
11+
HTTPProvider,
12+
IPCProvider,
13+
LegacyWebSocketProvider,
14+
Web3,
15+
WebSocketProvider,
16+
)
17+
from web3.exceptions import (
18+
Web3ValidationError,
19+
)
20+
21+
22+
class ExtendsAsyncBaseProvider(AsyncBaseProvider):
23+
pass
24+
25+
26+
class ExtendsBaseProvider(BaseProvider):
27+
pass
28+
29+
30+
@pytest.mark.parametrize(
31+
"provider_class",
32+
(
33+
AsyncBaseProvider,
34+
ExtendsAsyncBaseProvider,
35+
AsyncHTTPProvider,
36+
AsyncIPCProvider,
37+
WebSocketProvider,
38+
AsyncEthereumTesterProvider,
39+
),
40+
)
41+
def test_init_web3_with_async_provider(provider_class):
42+
with pytest.raises(Web3ValidationError):
43+
Web3(provider_class())
44+
45+
46+
@pytest.mark.parametrize(
47+
"provider_class",
48+
(
49+
BaseProvider,
50+
ExtendsBaseProvider,
51+
HTTPProvider,
52+
LegacyWebSocketProvider,
53+
IPCProvider,
54+
EthereumTesterProvider,
55+
),
56+
)
57+
def test_init_async_web3_with_sync_provider(provider_class):
58+
with pytest.raises(Web3ValidationError):
59+
AsyncWeb3(provider_class())

tests/core/utilities/test_attach_modules.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def test_attach_modules_multiple_levels_deep(module1):
8383

8484
def test_attach_modules_with_wrong_module_format():
8585
mods = {"eth": (MockEth, MockEth, MockEth)}
86-
w3 = Web3(EthereumTesterProvider, modules={})
86+
w3 = Web3(EthereumTesterProvider(), modules={})
8787
with pytest.raises(
8888
Web3ValidationError, match="Module definitions can only have 1 or 2 elements"
8989
):
@@ -94,7 +94,7 @@ def test_attach_modules_with_existing_modules():
9494
mods = {
9595
"eth": MockEth,
9696
}
97-
w3 = Web3(EthereumTesterProvider, modules=mods)
97+
w3 = Web3(EthereumTesterProvider(), modules=mods)
9898
with pytest.raises(
9999
Web3AttributeError,
100100
match=("The web3 object already has an attribute with that name"),

web3/main.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@
8787
)
8888
from web3.exceptions import (
8989
Web3TypeError,
90+
Web3ValidationError,
9091
Web3ValueError,
9192
)
9293
from web3.geth import (
@@ -350,6 +351,24 @@ def batch_requests(
350351
return self.manager._batch_requests()
351352

352353

354+
def _validate_provider(
355+
w3: Union["Web3", "AsyncWeb3"],
356+
provider: Optional[Union[BaseProvider, AsyncBaseProvider]],
357+
) -> None:
358+
if provider is not None:
359+
if isinstance(w3, AsyncWeb3) and not isinstance(provider, AsyncBaseProvider):
360+
raise Web3ValidationError(
361+
"Provider must be an instance of `AsyncBaseProvider` for "
362+
f"`AsyncWeb3`, got {type(provider)}."
363+
)
364+
365+
if isinstance(w3, Web3) and not isinstance(provider, BaseProvider):
366+
raise Web3ValidationError(
367+
"Provider must be an instance of `BaseProvider` for `Web3`, "
368+
f"got {type(provider)}."
369+
)
370+
371+
353372
class Web3(BaseWeb3):
354373
# mypy types
355374
eth: Eth
@@ -372,6 +391,8 @@ def __init__(
372391
] = None,
373392
ens: Union[ENS, "Empty"] = empty,
374393
) -> None:
394+
_validate_provider(self, provider)
395+
375396
self.manager = self.RequestManager(self, provider, middleware)
376397
self.codec = ABICodec(build_strict_registry())
377398

@@ -440,6 +461,8 @@ def __init__(
440461
] = None,
441462
ens: Union[AsyncENS, "Empty"] = empty,
442463
) -> None:
464+
_validate_provider(self, provider)
465+
443466
self.manager = self.RequestManager(self, provider, middleware)
444467
self.codec = ABICodec(build_strict_registry())
445468

0 commit comments

Comments
 (0)