Skip to content

Commit 1b66d3c

Browse files
darwintreereedsa
andauthored
Refactor for duplicate code in contract internal implementations (#3579)
* refactor: merge ContractFunction and AsyncContractFunction code * refactor: remove functions and events duplicate code * doc: refactor changelog * Update web3/contract/base_contract.py Co-authored-by: Stuart Reed <[email protected]> * Update newsfragments/3579.internal.rst Co-authored-by: Stuart Reed <[email protected]> * chore: remove legacy code and improve comments * Move __call__ into BaseContractEvent --------- Co-authored-by: Stuart Reed <[email protected]>
1 parent fe791ed commit 1b66d3c

File tree

5 files changed

+218
-432
lines changed

5 files changed

+218
-432
lines changed

newsfragments/3579.internal.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Move duplicate code into ``BaseContract`` class from ``Contract`` and ``AsyncContract``. (1) ``ContractFunction`` and ``AsyncContractFunction`` (2) ``ContractFunctions`` and ``AsyncContractFunctions``, (3) ``ContractEvent`` and ``AsyncContractEvent``, and (4) ``ContractEvents`` and ``AsyncContractEvents``.

web3/contract/async_contract.py

Lines changed: 2 additions & 206 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,13 @@
1414

1515
from eth_typing import (
1616
ABI,
17-
ABIFunction,
1817
ChecksumAddress,
1918
)
2019
from eth_utils import (
2120
combomethod,
2221
)
2322
from eth_utils.abi import (
2423
abi_to_signature,
25-
filter_abi_by_type,
2624
get_abi_input_names,
2725
)
2826
from eth_utils.toolz import (
@@ -34,7 +32,6 @@
3432

3533
from web3._utils.abi import (
3634
fallback_func_abi_exists,
37-
get_name_from_abi_element_identifier,
3835
receive_func_abi_exists,
3936
)
4037
from web3._utils.abi_element_identifiers import (
@@ -49,8 +46,6 @@
4946
)
5047
from web3._utils.contracts import (
5148
async_parse_block_identifier,
52-
copy_contract_event,
53-
copy_contract_function,
5449
)
5550
from web3._utils.datatypes import (
5651
PropertyCheckingFactory,
@@ -89,12 +84,6 @@
8984
get_function_by_identifier,
9085
)
9186
from web3.exceptions import (
92-
ABIEventNotFound,
93-
ABIFunctionNotFound,
94-
MismatchedABI,
95-
NoABIEventsFound,
96-
NoABIFound,
97-
NoABIFunctionsFound,
9887
Web3AttributeError,
9988
Web3TypeError,
10089
Web3ValidationError,
@@ -106,12 +95,6 @@
10695
StateOverride,
10796
TxParams,
10897
)
109-
from web3.utils.abi import (
110-
_filter_by_argument_count,
111-
_get_any_abi_signature_with_name,
112-
_mismatched_abi_error_diagnosis,
113-
get_abi_element,
114-
)
11598

11699
if TYPE_CHECKING:
117100
from ens import AsyncENS # noqa: F401
@@ -122,9 +105,6 @@ class AsyncContractEvent(BaseContractEvent):
122105
# mypy types
123106
w3: "AsyncWeb3"
124107

125-
def __call__(self, *args: Any, **kwargs: Any) -> "AsyncContractEvent":
126-
return copy_contract_event(self, *args, **kwargs)
127-
128108
@combomethod
129109
async def get_logs(
130110
self,
@@ -255,162 +235,18 @@ def build_filter(self) -> AsyncEventFilterBuilder:
255235
builder.address = self.address
256236
return builder
257237

258-
@classmethod
259-
def factory(cls, class_name: str, **kwargs: Any) -> Self:
260-
return PropertyCheckingFactory(class_name, (cls,), kwargs)()
261238

262-
263-
class AsyncContractEvents(BaseContractEvents):
239+
class AsyncContractEvents(BaseContractEvents[AsyncContractEvent]):
264240
def __init__(
265241
self, abi: ABI, w3: "AsyncWeb3", address: Optional[ChecksumAddress] = None
266242
) -> None:
267243
super().__init__(abi, w3, AsyncContractEvent, address)
268244

269-
def __iter__(self) -> Iterable["AsyncContractEvent"]:
270-
if not hasattr(self, "_events") or not self._events:
271-
return
272-
273-
for event in self._events:
274-
yield self[abi_to_signature(event)]
275-
276-
def __getattr__(self, event_name: str) -> "AsyncContractEvent":
277-
if super().__getattribute__("abi") is None:
278-
raise NoABIFound(
279-
"There is no ABI found for this contract.",
280-
)
281-
elif "_events" not in self.__dict__ or len(self._events) == 0:
282-
raise NoABIEventsFound(
283-
"The abi for this contract contains no event definitions. ",
284-
"Are you sure you provided the correct contract abi?",
285-
)
286-
elif get_name_from_abi_element_identifier(event_name) not in [
287-
get_name_from_abi_element_identifier(event["name"])
288-
for event in self._events
289-
]:
290-
raise ABIEventNotFound(
291-
f"The event '{event_name}' was not found in this contract's abi. ",
292-
"Are you sure you provided the correct contract abi?",
293-
)
294-
295-
if "(" not in event_name:
296-
event_name = _get_any_abi_signature_with_name(event_name, self._events)
297-
else:
298-
event_name = f"_{event_name}"
299-
300-
return super().__getattribute__(event_name)
301-
302-
def __getitem__(self, event_name: str) -> "AsyncContractEvent":
303-
return getattr(self, event_name)
304-
305245

306246
class AsyncContractFunction(BaseContractFunction):
307247
# mypy types
308248
w3: "AsyncWeb3"
309249

310-
def __call__(self, *args: Any, **kwargs: Any) -> "AsyncContractFunction":
311-
# When a function is called, check arguments to obtain the correct function
312-
# in the contract. self will be used if all args and kwargs are
313-
# encodable to self.abi, otherwise the correct function is obtained from
314-
# the contract.
315-
if (
316-
self.abi_element_identifier in [FallbackFn, ReceiveFn]
317-
or self.abi_element_identifier == "constructor"
318-
):
319-
return copy_contract_function(self, *args, **kwargs)
320-
321-
all_functions = cast(
322-
List[ABIFunction],
323-
filter_abi_by_type(
324-
"function",
325-
self.contract_abi,
326-
),
327-
)
328-
# Filter functions by name to obtain function signatures
329-
function_name = get_name_from_abi_element_identifier(
330-
self.abi_element_identifier
331-
)
332-
function_abis = [
333-
function for function in all_functions if function["name"] == function_name
334-
]
335-
num_args = len(args) + len(kwargs)
336-
function_abis_with_arg_count = cast(
337-
List[ABIFunction],
338-
_filter_by_argument_count(
339-
num_args,
340-
function_abis,
341-
),
342-
)
343-
344-
if not len(function_abis_with_arg_count):
345-
# Build an ABI without arguments to determine if one exists
346-
function_abis_with_arg_count = [
347-
ABIFunction({"type": "function", "name": function_name})
348-
]
349-
350-
# Check that arguments in call match a function ABI
351-
num_attempts = 0
352-
function_abi_matches = []
353-
contract_function = None
354-
for abi in function_abis_with_arg_count:
355-
try:
356-
num_attempts += 1
357-
358-
# Search for a function ABI that matches the arguments used
359-
function_abi_matches.append(
360-
cast(
361-
ABIFunction,
362-
get_abi_element(
363-
function_abis,
364-
abi_to_signature(abi),
365-
*args,
366-
abi_codec=self.w3.codec,
367-
**kwargs,
368-
),
369-
)
370-
)
371-
except MismatchedABI:
372-
# ignore exceptions
373-
continue
374-
375-
if len(function_abi_matches) == 1:
376-
function_abi = function_abi_matches[0]
377-
if abi_to_signature(self.abi) == abi_to_signature(function_abi):
378-
contract_function = self
379-
else:
380-
# Found a match that is not self
381-
contract_function = AsyncContractFunction.factory(
382-
abi_to_signature(function_abi),
383-
w3=self.w3,
384-
contract_abi=self.contract_abi,
385-
address=self.address,
386-
abi_element_identifier=abi_to_signature(function_abi),
387-
abi=function_abi,
388-
)
389-
else:
390-
for abi in function_abi_matches:
391-
if abi_to_signature(self.abi) == abi_to_signature(abi):
392-
contract_function = self
393-
break
394-
else:
395-
# Raise exception if multiple found
396-
raise MismatchedABI(
397-
_mismatched_abi_error_diagnosis(
398-
function_name,
399-
self.contract_abi,
400-
len(function_abi_matches),
401-
num_args,
402-
*args,
403-
abi_codec=self.w3.codec,
404-
**kwargs,
405-
)
406-
)
407-
408-
return copy_contract_function(contract_function, *args, **kwargs)
409-
410-
@classmethod
411-
def factory(cls, class_name: str, **kwargs: Any) -> Self:
412-
return PropertyCheckingFactory(class_name, (cls,), kwargs)()
413-
414250
async def call(
415251
self,
416252
transaction: Optional[TxParams] = None,
@@ -551,7 +387,7 @@ def get_receive_function(
551387
return cast(AsyncContractFunction, NonExistentReceiveFunction())
552388

553389

554-
class AsyncContractFunctions(BaseContractFunctions):
390+
class AsyncContractFunctions(BaseContractFunctions[AsyncContractFunction]):
555391
def __init__(
556392
self,
557393
abi: ABI,
@@ -561,46 +397,6 @@ def __init__(
561397
) -> None:
562398
super().__init__(abi, w3, AsyncContractFunction, address, decode_tuples)
563399

564-
def __iter__(self) -> Iterable["AsyncContractFunction"]:
565-
if not hasattr(self, "_functions") or not self._functions:
566-
return
567-
568-
for func in self._functions:
569-
yield self[abi_to_signature(func)]
570-
571-
def __getattr__(self, function_name: str) -> "AsyncContractFunction":
572-
if super().__getattribute__("abi") is None:
573-
raise NoABIFound(
574-
"There is no ABI found for this contract.",
575-
)
576-
elif "_functions" not in self.__dict__ or len(self._functions) == 0:
577-
raise NoABIFunctionsFound(
578-
"The abi for this contract contains no function definitions. ",
579-
"Are you sure you provided the correct contract abi?",
580-
)
581-
elif get_name_from_abi_element_identifier(function_name) not in [
582-
get_name_from_abi_element_identifier(function["name"])
583-
for function in self._functions
584-
]:
585-
raise ABIFunctionNotFound(
586-
f"The function '{function_name}' was not found in this ",
587-
"contract's abi.",
588-
)
589-
590-
if "(" not in function_name:
591-
function_name = _get_any_abi_signature_with_name(
592-
function_name, self._functions
593-
)
594-
else:
595-
function_name = f"_{function_name}"
596-
597-
return super().__getattribute__(
598-
function_name,
599-
)
600-
601-
def __getitem__(self, function_name: str) -> "AsyncContractFunction":
602-
return getattr(self, function_name)
603-
604400

605401
class AsyncContract(BaseContract):
606402
functions: AsyncContractFunctions = None

0 commit comments

Comments
 (0)