Skip to content

Commit e5c67ee

Browse files
committed
multicall performance inprovements. abi encoding and decoding still is the most time intensive part of the multicall implementation...
1 parent 50a55f2 commit e5c67ee

File tree

1 file changed

+86
-43
lines changed

1 file changed

+86
-43
lines changed

IceCreamSwapWeb3/Multicall.py

Lines changed: 86 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import eth_utils
88
import rlp
99
from eth_utils import to_checksum_address, to_bytes
10-
from web3._utils.abi import get_abi_output_types
10+
from web3._utils.abi import get_abi_output_types, get_abi_input_types
1111
from web3.contract.contract import ContractFunction, ContractConstructor
1212
from web3.exceptions import ContractLogicError
1313

@@ -31,7 +31,7 @@ class MultiCall:
3131
CALLER_ADDRESS = "0x0000000000000000000000000000000000000123"
3232

3333
MULTICALL_DEPLOYMENTS: dict[int, str] = {
34-
1116: "0x2C310a21E21a3eaDF4e53E1118aeD4614c51B576"
34+
1116: "0x2cd05AcF9aBe54D57eb1E6B12f2129880fA4cF65",
3535
}
3636

3737
@classmethod
@@ -75,46 +75,59 @@ def call(self, use_revert: Optional[bool] = None, batch_size: int = 1_000):
7575
if use_revert is None:
7676
use_revert = self.w3.revert_reason_available
7777

78-
return self._inner_call(use_revert=use_revert, calls=self.calls, batch_size=batch_size)
78+
calls = self.calls
79+
calls_with_calldata = self.add_calls_calldata(calls)
7980

80-
def _inner_call(self, use_revert: bool, calls: list[ContractFunction], batch_size: int):
81+
return self._inner_call(use_revert=use_revert, calls_with_calldata=calls_with_calldata, batch_size=batch_size)
82+
83+
def _inner_call(
84+
self,
85+
use_revert: bool,
86+
calls_with_calldata: list[tuple[ContractFunction, bytes]],
87+
batch_size: int
88+
):
8189
kwargs = dict(
8290
use_revert=use_revert,
8391
batch_size=batch_size,
8492
)
8593
# make sure calls are not bigger than batch_size
86-
if len(calls) > batch_size:
94+
if len(calls_with_calldata) > batch_size:
8795
results = []
88-
for start in range(0, len(calls), batch_size):
96+
for start in range(0, len(calls_with_calldata), batch_size):
8997
results += self._inner_call(
9098
**kwargs,
91-
calls=calls[start: min(start + batch_size, len(calls))],
99+
calls_with_calldata=calls_with_calldata[start: min(start + batch_size, len(calls_with_calldata))],
92100
)
93101
return results
94102

95103
if self.multicall.address is None:
96-
multicall_call = self._build_constructor_calldata(calls=calls, use_revert=use_revert)
104+
multicall_call = self._build_constructor_calldata(
105+
calls_with_calldata=calls_with_calldata,
106+
use_revert=use_revert
107+
)
97108
else:
98-
multicall_call = self._build_calldata(calls=calls)
109+
multicall_call = self._build_calldata(
110+
calls_with_calldata=calls_with_calldata
111+
)
99112
try:
100113
raw_returns = self._call_multicall(
101114
multicall_call=multicall_call,
102-
retry=len(calls) == 1
115+
retry=len(calls_with_calldata) == 1
103116
)
104117
except Exception as e:
105-
if len(calls) == 1:
118+
if len(calls_with_calldata) == 1:
106119
print(f"Multicall with single call got Exception '{repr(e)}', retrying in 1 sec")
107120
sleep(1)
108-
return self._inner_call(**kwargs, calls=calls)
121+
return self._inner_call(**kwargs, calls_with_calldata=calls_with_calldata)
109122
print(f"Multicall got Exception '{repr(e)}', splitting and retrying")
110-
left_results = self._inner_call(**kwargs, calls=calls[:len(calls) // 2])
111-
right_results = self._inner_call(**kwargs, calls=calls[len(calls) // 2:])
123+
left_results = self._inner_call(**kwargs, calls_with_calldata=calls_with_calldata[:len(calls_with_calldata) // 2])
124+
right_results = self._inner_call(**kwargs, calls_with_calldata=calls_with_calldata[len(calls_with_calldata) // 2:])
112125
return left_results + right_results
113-
results = self.decode_contract_function_results(raw_returns=raw_returns, contract_functions=calls)
114-
if len(results) == len(calls):
126+
results = self.decode_contract_function_results(raw_returns=raw_returns, contract_functions=[call for call, _ in calls_with_calldata])
127+
if len(results) == len(calls_with_calldata):
115128
return results
116129
# if not all calls were executed, recursively execute remaining calls and concatenate results
117-
return results + self._inner_call(**kwargs, calls=calls[len(results):])
130+
return results + self._inner_call(**kwargs, calls_with_calldata=calls_with_calldata[len(results):])
118131

119132
@staticmethod
120133
def calculate_expected_contract_address(sender: str, nonce: int):
@@ -125,33 +138,47 @@ def calculate_expected_contract_address(sender: str, nonce: int):
125138
@staticmethod
126139
def calculate_create_address(sender: str, nonce: int) -> str:
127140
assert len(sender) == 42
128-
sender_bytes = eth_utils.to_bytes(hexstr=sender)
141+
sender_bytes = to_bytes(hexstr=sender)
129142
raw = rlp.encode([sender_bytes, nonce])
130143
h = eth_utils.keccak(raw)
131144
address_bytes = h[12:]
132145
return eth_utils.to_checksum_address(address_bytes)
133146

134-
def _build_calldata(self, calls: list[ContractFunction]) -> ContractFunction:
147+
@staticmethod
148+
def add_calls_calldata(calls: list[ContractFunction]) -> list[tuple[ContractFunction, bytes]]:
149+
calls_with_calldata = []
150+
for call in calls:
151+
function_abi = get_abi_input_types(call.abi)
152+
assert len(function_abi) == len(call.arguments)
153+
function_args = []
154+
for aby_type, arg in zip(function_abi, call.arguments):
155+
if aby_type == "bytes":
156+
arg = to_bytes(hexstr=arg)
157+
function_args.append(arg)
158+
call_data = to_bytes(hexstr=call.selector) + eth_abi.encode(function_abi, function_args)
159+
calls_with_calldata.append((call, call_data))
160+
assert len(calls_with_calldata) == len(calls)
161+
return calls_with_calldata
162+
163+
def _build_calldata(self, calls_with_calldata: list[tuple[ContractFunction, bytes]]) -> ContractFunction:
135164
assert self.multicall.address is not None
136165

137166
if self.undeployed_contract_constructor is not None:
138167
# deploy undeployed contract first and then call the other functions
139168
contract_deployment_call = self.multicall.functions.deployContract(
140-
contractBytecode=self.undeployed_contract_constructor.data_in_transaction
169+
contractBytecode=to_bytes(hexstr=self.undeployed_contract_constructor.data_in_transaction)
141170
)
142-
calls = [contract_deployment_call] + calls
171+
contract_deployment_calldata = to_bytes(hexstr=contract_deployment_call.selector) + \
172+
eth_abi.encode(
173+
get_abi_input_types(contract_deployment_call.abi),
174+
contract_deployment_call.arguments
175+
)
176+
# contract_deployment_calldata = to_bytes(hexstr=contract_deployment_call._encode_transaction_data())
177+
calls_with_calldata = [(contract_deployment_call, contract_deployment_calldata)] + calls_with_calldata
143178

144179
encoded_calls = []
145-
for call in calls:
146-
target = call.address
147-
call_data_hex = call._encode_transaction_data()
148-
call_data = to_bytes(hexstr=call_data_hex)
149-
150-
encoded_calls.append({
151-
"target": target,
152-
"gasLimit": 100_000_000,
153-
"callData": call_data,
154-
})
180+
for call, call_data in calls_with_calldata:
181+
encoded_calls.append((call.address, 100_000_000, call_data)) # target, gasLimit, callData
155182

156183
# build multicall transaction
157184
multicall_call = self.multicall.functions.multicallWithGasLimitation(
@@ -162,20 +189,22 @@ def _build_calldata(self, calls: list[ContractFunction]) -> ContractFunction:
162189
# return multicall address and calldata
163190
return multicall_call
164191

165-
def _build_constructor_calldata(self, calls: list[ContractFunction], use_revert: bool) -> ContractConstructor:
192+
def _build_constructor_calldata(
193+
self,
194+
calls_with_calldata: list[tuple[ContractFunction, bytes]],
195+
use_revert: bool
196+
) -> ContractConstructor:
166197
assert self.multicall.address is None
167198

168199
# Encode the number of calls as the first 32 bytes
169-
number_of_calls = len(calls)
200+
number_of_calls = len(calls_with_calldata)
170201
encoded_calls = eth_abi.encode(['uint256'], [number_of_calls]).hex()
171202

172203
previous_target = None
173204
previous_call_data = None
174205

175-
for call in calls:
206+
for call, call_data in calls_with_calldata:
176207
target = call.address
177-
call_data_hex = call._encode_transaction_data()
178-
call_data = to_bytes(hexstr=call_data_hex)
179208

180209
# Determine the flags
181210
flags = 0
@@ -197,7 +226,7 @@ def _build_constructor_calldata(self, calls: list[ContractFunction], use_revert:
197226
# Encode call data length (16 bits / 2 bytes)
198227
call_data_length_encoded = eth_abi.encode(['uint16'], [len(call_data)]).hex().zfill(4)[-4:]
199228
# Encode call data (variable length)
200-
call_data_encoded = call_data_hex[2:]
229+
call_data_encoded = call_data.hex()
201230
else:
202231
call_data_length_encoded = ""
203232
call_data_encoded = ""
@@ -215,7 +244,7 @@ def _build_constructor_calldata(self, calls: list[ContractFunction], use_revert:
215244
multicall_call = self.multicall.constructor(
216245
useRevert=use_revert,
217246
contractBytecode=contract_constructor_data,
218-
encodedCalls=bytes.fromhex(encoded_calls)
247+
encodedCalls=to_bytes(hexstr=encoded_calls)
219248
)
220249

221250
return multicall_call
@@ -224,7 +253,7 @@ def _build_constructor_calldata(self, calls: list[ContractFunction], use_revert:
224253
def _decode_muilticall(multicall_result: bytes | list[tuple[bool, int, bytes]]) -> list[str | Exception]:
225254
raw_returns: list[str or Exception] = []
226255

227-
if isinstance(multicall_result, list):
256+
if isinstance(multicall_result, list) or isinstance(multicall_result, tuple):
228257
# deployed multicall
229258
for sucess, _, raw_return in multicall_result:
230259
if not sucess:
@@ -281,24 +310,39 @@ def _call_multicall(self, multicall_call: ContractConstructor | ContractFunction
281310
})
282311
else:
283312
assert isinstance(multicall_call, ContractFunction)
284-
_, multicall_result, _ = multicall_call.call({
313+
# manually encoding and decoding call because web3.py is sooooo slow...
314+
# The simple but slow version is as below:
315+
# _, multicall_result, completed_calls = multicall_call.call({
316+
# "from": self.CALLER_ADDRESS,
317+
# "nonce": 0,
318+
# "no_retry": not retry,
319+
# })
320+
321+
calldata = to_bytes(hexstr=multicall_call.selector) + \
322+
eth_abi.encode(get_abi_input_types(multicall_call.abi), multicall_call.arguments)
323+
raw_response = self.w3.eth.call({
285324
"from": self.CALLER_ADDRESS,
325+
"to": multicall_call.address,
286326
"nonce": 0,
327+
"data": calldata,
287328
"no_retry": not retry,
288329
})
330+
_, multicall_result, completed_calls = eth_abi.decode(get_abi_output_types(multicall_call.abi), raw_response)
331+
289332
if self.undeployed_contract_constructor is not None:
290333
# remove first call result as that's the deployment of the undeployed contract
291334
success, _, address_encoded = multicall_result[0]
292335
assert success, "Undeployed contract constructor reverted"
293336
assert "0x" + address_encoded[-20:].hex() == self.undeployed_contract_address.lower(), "unexpected undeployed contract address"
294337
multicall_result = multicall_result[1:]
338+
multicall_result = multicall_result[:completed_calls]
295339
except ContractLogicError as e:
296340
if not e.message.startswith("execution reverted: "):
297341
raise
298342
result_str = e.message.removeprefix("execution reverted: ")
299343
if any((char not in HEX_CHARS for char in result_str)):
300344
raise
301-
multicall_result = bytes.fromhex(result_str)
345+
multicall_result = to_bytes(hexstr=result_str)
302346

303347
if len(multicall_result) == 0:
304348
raise ValueError("No data returned from multicall")
@@ -310,8 +354,7 @@ def decode_contract_function_result(raw_return: str | Exception, contract_functi
310354
if isinstance(raw_return, Exception):
311355
return raw_return
312356
try:
313-
output_types = get_abi_output_types(contract_function.abi)
314-
result = contract_function.w3.codec.decode(output_types, raw_return)
357+
result = eth_abi.decode(get_abi_output_types(contract_function.abi), raw_return)
315358
if hasattr(result, "__len__") and len(result) == 1:
316359
result = result[0]
317360
return result

0 commit comments

Comments
 (0)