Skip to content

Commit c8ba2ee

Browse files
committed
initial implementation of multicall. Both with and without deployed multicall contract
1 parent cac5a74 commit c8ba2ee

File tree

7 files changed

+1329
-0
lines changed

7 files changed

+1329
-0
lines changed

IceCreamSwapWeb3/Multicall.py

Lines changed: 359 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,359 @@
1+
import copy
2+
from time import sleep
3+
from typing import Optional
4+
5+
import eth_abi
6+
import eth_utils
7+
import rlp
8+
from eth_utils import to_checksum_address, to_bytes
9+
from web3._utils.abi import get_abi_output_types
10+
from web3.contract.contract import ContractFunction, ContractConstructor
11+
from web3.exceptions import ContractLogicError
12+
13+
from IceCreamSwapWeb3 import Web3Advanced
14+
15+
16+
# load multicall abi
17+
with open("abi/Multicall.abi") as f:
18+
MULTICALL_ABI = f.read()
19+
20+
# load undeployed multicall abi and bytecode
21+
with open("abi/UndeployedMulticall.abi") as f:
22+
UNDEPLOYED_MULTICALL_ABI = f.read()
23+
with open("bytecode/UndeployedMulticall.bytecode") as f:
24+
UNDEPLOYED_MULTICALL_BYTECODE = f.read()
25+
26+
# allowed chars in HEX string
27+
HEX_CHARS = set("0123456789abcdef")
28+
29+
30+
class MultiCall:
31+
CALLER_ADDRESS = "0x0000000000000000000000000000000000000123"
32+
33+
MULTICALL_DEPLOYMENTS: dict[int, str] = {
34+
1116: "0x2C310a21E21a3eaDF4e53E1118aeD4614c51B576"
35+
}
36+
37+
@classmethod
38+
def register_multicall_contract(cls, chain_id: int, contract_address: str):
39+
cls.MULTICALL_DEPLOYMENTS[chain_id] = to_checksum_address(contract_address)
40+
41+
def __init__(
42+
self,
43+
w3: Web3Advanced
44+
):
45+
self.w3 = copy.deepcopy(w3)
46+
self.chain_id = self.w3.eth.chain_id
47+
48+
if self.chain_id in self.MULTICALL_DEPLOYMENTS:
49+
self.multicall = self.w3.eth.contract(
50+
abi=MULTICALL_ABI,
51+
address=to_checksum_address(self.MULTICALL_DEPLOYMENTS[self.chain_id])
52+
)
53+
self.undeployed_contract_address = self.calculate_create_address(sender=self.multicall.address, nonce=1)
54+
else:
55+
self.multicall = self.w3.eth.contract(abi=UNDEPLOYED_MULTICALL_ABI, bytecode=UNDEPLOYED_MULTICALL_BYTECODE)
56+
self.undeployed_contract_address = self.calculate_expected_contract_address(sender=self.CALLER_ADDRESS, nonce=0)
57+
58+
self.calls: list[ContractFunction] = []
59+
self.undeployed_contract_constructor: Optional[ContractConstructor] = None
60+
61+
def add_call(self, contract_func: ContractFunction):
62+
self.calls.append(contract_func)
63+
64+
def add_undeployed_contract(self, contract_constructor: ContractConstructor):
65+
assert self.undeployed_contract_constructor is None, "can only add one undeployed contract"
66+
self.undeployed_contract_constructor = contract_constructor
67+
68+
def add_undeployed_contract_call(self, contract_func: ContractFunction):
69+
assert self.undeployed_contract_constructor is not None, "No undeployed contract added yet"
70+
contract_func = copy.copy(contract_func)
71+
contract_func.address = self.undeployed_contract_address
72+
self.calls.append(contract_func)
73+
74+
def call(self, use_revert: Optional[bool] = None, batch_size: int = 1_000):
75+
if use_revert is None:
76+
use_revert = self.w3.revert_reason_available
77+
78+
return self._inner_call(use_revert=use_revert, calls=self.calls, batch_size=batch_size)
79+
80+
def _inner_call(self, use_revert: bool, calls: list[ContractFunction], batch_size: int):
81+
# make sure calls are not bigger than batch_size
82+
if len(calls) > batch_size:
83+
results = []
84+
for start in range(0, len(calls), batch_size):
85+
results += self._inner_call(
86+
use_revert=use_revert,
87+
calls=calls[start: min(start + batch_size, len(calls))],
88+
batch_size=batch_size
89+
)
90+
return results
91+
92+
if self.multicall.address is None:
93+
multicall_call = self._build_constructor_calldata(calls=calls, use_revert=use_revert)
94+
else:
95+
multicall_call = self._build_calldata(calls=calls)
96+
try:
97+
raw_returns = self._call_multicall(
98+
multicall_call=multicall_call,
99+
retry=len(calls) == 1
100+
)
101+
except Exception as e:
102+
if len(calls) == 1:
103+
print(f"Multicall with single call got Exception '{repr(e)}', retrying in 1 sec")
104+
sleep(1)
105+
return self._inner_call(use_revert=use_revert, calls=calls)
106+
print(f"Multicall got Exception '{repr(e)}', splitting and retrying")
107+
left_results = self._inner_call(use_revert=use_revert, calls=calls[:len(calls) // 2])
108+
right_results = self._inner_call(use_revert=use_revert, calls=calls[len(calls) // 2:])
109+
return left_results + right_results
110+
results = self.decode_contract_function_results(raw_returns=raw_returns, contract_functions=calls)
111+
if len(results) == len(calls):
112+
return results
113+
# if not all calls were executed, recursively execute remaining calls and concatenate results
114+
return results + self._inner_call(use_revert=use_revert, calls=calls[len(results):])
115+
116+
@staticmethod
117+
def calculate_expected_contract_address(sender: str, nonce: int):
118+
undeployed_contract_runner_address = MultiCall.calculate_create_address(sender=sender, nonce=nonce)
119+
contract_address = MultiCall.calculate_create_address(sender=undeployed_contract_runner_address, nonce=1)
120+
return contract_address
121+
122+
@staticmethod
123+
def calculate_create_address(sender: str, nonce: int) -> str:
124+
assert len(sender) == 42
125+
sender_bytes = eth_utils.to_bytes(hexstr=sender)
126+
raw = rlp.encode([sender_bytes, nonce])
127+
h = eth_utils.keccak(raw)
128+
address_bytes = h[12:]
129+
return eth_utils.to_checksum_address(address_bytes)
130+
131+
def _build_calldata(self, calls: list[ContractFunction]) -> ContractFunction:
132+
assert self.multicall.address is not None
133+
134+
if self.undeployed_contract_constructor is not None:
135+
# deploy undeployed contract first and then call the other functions
136+
contract_deployment_call = self.multicall.functions.deployContract(
137+
contractBytecode=self.undeployed_contract_constructor.data_in_transaction
138+
)
139+
calls = [contract_deployment_call] + calls
140+
141+
encoded_calls = []
142+
for call in calls:
143+
target = call.address
144+
call_data_hex = call._encode_transaction_data()
145+
call_data = to_bytes(hexstr=call_data_hex)
146+
147+
encoded_calls.append({
148+
"target": target,
149+
"gasLimit": 100_000_000,
150+
"callData": call_data,
151+
})
152+
153+
# build multicall transaction
154+
multicall_call = self.multicall.functions.multicallWithGasLimitation(
155+
calls=encoded_calls,
156+
gasBuffer=10_000_000,
157+
)
158+
159+
# return multicall address and calldata
160+
return multicall_call
161+
162+
def _build_constructor_calldata(self, calls: list[ContractFunction], use_revert: bool) -> ContractConstructor:
163+
assert self.multicall.address is None
164+
165+
# Encode the number of calls as the first 32 bytes
166+
number_of_calls = len(calls)
167+
encoded_calls = eth_abi.encode(['uint256'], [number_of_calls]).hex()
168+
169+
previous_target = None
170+
previous_call_data = None
171+
172+
for call in calls:
173+
target = call.address
174+
call_data_hex = call._encode_transaction_data()
175+
call_data = to_bytes(hexstr=call_data_hex)
176+
177+
# Determine the flags
178+
flags = 0
179+
if target == previous_target:
180+
flags |= 1 # Set bit 0 if target is the same as previous
181+
if call_data == previous_call_data:
182+
flags |= 2 # Set bit 1 if calldata is the same as previous
183+
184+
# Encode the flag byte (1 byte)
185+
flags_encoded = format(flags, '02x')
186+
187+
if flags & 1 == 0: # If target is different
188+
# Encode target address (20 bytes, padded to 32 bytes)
189+
target_encoded = eth_abi.encode(['address'], [target]).hex()[24:] # remove leading zeros
190+
else:
191+
target_encoded = ""
192+
193+
if flags & 2 == 0: # If calldata is different
194+
# Encode call data length (16 bits / 2 bytes)
195+
call_data_length_encoded = eth_abi.encode(['uint16'], [len(call_data)]).hex().zfill(4)[-4:]
196+
# Encode call data (variable length)
197+
call_data_encoded = call_data_hex[2:]
198+
else:
199+
call_data_length_encoded = ""
200+
call_data_encoded = ""
201+
202+
encoded_calls += flags_encoded + target_encoded + call_data_length_encoded + call_data_encoded
203+
204+
# Update previous values
205+
previous_target = target
206+
previous_call_data = call_data
207+
208+
# build multicall transaction
209+
contract_constructor_data = bytes()
210+
if self.undeployed_contract_constructor is not None:
211+
contract_constructor_data = self.undeployed_contract_constructor.data_in_transaction
212+
multicall_call = self.multicall.constructor(
213+
useRevert=use_revert,
214+
contractBytecode=contract_constructor_data,
215+
encodedCalls=bytes.fromhex(encoded_calls)
216+
)
217+
218+
return multicall_call
219+
220+
@staticmethod
221+
def _decode_muilticall(multicall_result: bytes | list[tuple[bool, int, bytes]]) -> list[str | Exception]:
222+
raw_returns: list[str or Exception] = []
223+
224+
if isinstance(multicall_result, list):
225+
# deployed multicall
226+
for sucess, _, raw_return in multicall_result:
227+
if not sucess:
228+
decoded = MultiCall.get_revert_reason(raw_return)
229+
raw_return = ContractLogicError(f"execution reverted: {decoded}")
230+
raw_returns.append(raw_return)
231+
return raw_returns
232+
233+
# undeployed multicall
234+
# decode returned data into segments
235+
multicall_result_copy = multicall_result[:]
236+
raw_returns_encoded = []
237+
while len(multicall_result_copy) != 0:
238+
data_len = int.from_bytes(multicall_result_copy[:2], byteorder='big')
239+
raw_returns_encoded.append(multicall_result_copy[2:data_len+2])
240+
multicall_result_copy = multicall_result_copy[data_len+2:]
241+
242+
# decode returned data for each call
243+
for raw_return_encoded in raw_returns_encoded:
244+
try:
245+
# we are using packed encoding to decrease size of return data, if not we could have used
246+
# success, raw_return = eth_abi.decode(['bool', 'bytes'], raw_return_encoded)
247+
success = raw_return_encoded[0] == 1
248+
raw_return = raw_return_encoded[1:]
249+
if not success:
250+
decoded = MultiCall.get_revert_reason(raw_return)
251+
raw_return = ContractLogicError(f"execution reverted: {decoded}")
252+
except Exception as e:
253+
raw_return = e
254+
raw_returns.append(raw_return)
255+
return raw_returns
256+
257+
@staticmethod
258+
def get_revert_reason(revert_bytes: bytes) -> str:
259+
if len(revert_bytes) == 0:
260+
return "unknown"
261+
else:
262+
# first 4 bytes of revert code should be function selector for function Error(string)
263+
revert_bytes = revert_bytes[4:]
264+
try:
265+
return eth_abi.decode(['string'], revert_bytes)
266+
except Exception:
267+
return revert_bytes
268+
269+
def _call_multicall(self, multicall_call: ContractConstructor | ContractFunction, retry: bool = False):
270+
# call transaction
271+
try:
272+
if isinstance(multicall_call, ContractConstructor):
273+
multicall_result = self.w3.eth.call({
274+
"from": self.CALLER_ADDRESS,
275+
"nonce": 0,
276+
"data": multicall_call.data_in_transaction,
277+
"no_retry": not retry,
278+
})
279+
else:
280+
assert isinstance(multicall_call, ContractFunction)
281+
_, multicall_result, _ = multicall_call.call({
282+
"from": self.CALLER_ADDRESS,
283+
"nonce": 0,
284+
"no_retry": not retry,
285+
})
286+
if self.undeployed_contract_constructor is not None:
287+
# remove first call result as that's the deployment of the undeployed contract
288+
success, _, address_encoded = multicall_result[0]
289+
assert success, "Undeployed contract constructor reverted"
290+
assert "0x" + address_encoded[-20:].hex() == self.undeployed_contract_address.lower(), "unexpected undeployed contract address"
291+
multicall_result = multicall_result[1:]
292+
except ContractLogicError as e:
293+
if not e.message.startswith("execution reverted: "):
294+
raise
295+
result_str = e.message.removeprefix("execution reverted: ")
296+
if any((char not in HEX_CHARS for char in result_str)):
297+
raise
298+
multicall_result = bytes.fromhex(result_str)
299+
300+
if len(multicall_result) == 0:
301+
raise ValueError("No data returned from multicall")
302+
303+
return self._decode_muilticall(multicall_result)
304+
305+
@staticmethod
306+
def decode_contract_function_result(raw_return: str | Exception, contract_function: ContractFunction):
307+
if isinstance(raw_return, Exception):
308+
return raw_return
309+
try:
310+
output_types = get_abi_output_types(contract_function.abi)
311+
result = contract_function.w3.codec.decode(output_types, raw_return)
312+
if hasattr(result, "__len__") and len(result) == 1:
313+
result = result[0]
314+
return result
315+
except Exception as e:
316+
return e
317+
318+
@staticmethod
319+
def decode_contract_function_results(raw_returns: list[str | Exception], contract_functions: list[ContractFunction]):
320+
return [MultiCall.decode_contract_function_result(raw_return, contract_function) for raw_return, contract_function in zip(raw_returns, contract_functions)]
321+
322+
323+
def main(
324+
node_url="https://rpc-core.icecreamswap.com",
325+
usdt_address=to_checksum_address("0x900101d06A7426441Ae63e9AB3B9b0F63Be145F1"),
326+
):
327+
w3 = Web3Advanced(node_url=node_url)
328+
multicall = MultiCall(w3=w3)
329+
330+
with open("abi/Counter.abi") as f:
331+
counter_contract_abi = f.read()
332+
with open("bytecode/Counter.bytecode") as f:
333+
counter_contract_bytecode = f.read()
334+
with open("abi/ERC20.abi") as f:
335+
erc20_abi = f.read()
336+
337+
counter_contract = w3.eth.contract(bytecode=counter_contract_bytecode, abi=counter_contract_abi)
338+
usdt_contract = w3.eth.contract(address=usdt_address, abi=erc20_abi)
339+
340+
# calling an undeployed contract
341+
# '''
342+
multicall.add_undeployed_contract(counter_contract.constructor(initialCounter=13))
343+
multicall.add_undeployed_contract_call(counter_contract.functions.counter())
344+
multicall.add_undeployed_contract_call(counter_contract.functions.updateCounter(newCounter=7))
345+
multicall.add_undeployed_contract_call(counter_contract.functions.counter())
346+
# '''
347+
348+
# '''
349+
for _ in range(10_000):
350+
# calling a deployed contract
351+
multicall.add_call(usdt_contract.functions.decimals())
352+
# '''
353+
354+
multicall_result = multicall.call()
355+
print(multicall_result)
356+
357+
358+
if __name__ == "__main__":
359+
main()

0 commit comments

Comments
 (0)