Skip to content

Commit 3ab3827

Browse files
committed
v0.1.38 Multicall now supports per call state overwrites, to allow also splitting state overwrites if calls get an error
1 parent 18331dc commit 3ab3827

File tree

2 files changed

+71
-10
lines changed

2 files changed

+71
-10
lines changed

IceCreamSwapWeb3/Multicall.py

Lines changed: 70 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -59,20 +59,23 @@ def __init__(
5959
self.undeployed_contract_address = self.calculate_expected_contract_address(sender=self.CALLER_ADDRESS, nonce=0)
6060

6161
self.calls: list[ContractFunction] = []
62+
self.state_overwrites: list[StateOverride | None] = []
6263
self.undeployed_contract_constructor: Optional[ContractConstructor] = None
6364

64-
def add_call(self, contract_func: ContractFunction):
65+
def add_call(self, contract_func: ContractFunction, state_override: Optional[StateOverride] = None):
6566
self.calls.append(contract_func)
67+
self.state_overwrites.append(state_override)
6668

6769
def add_undeployed_contract(self, contract_constructor: ContractConstructor):
6870
assert self.undeployed_contract_constructor is None, "can only add one undeployed contract"
6971
self.undeployed_contract_constructor = contract_constructor
7072

71-
def add_undeployed_contract_call(self, contract_func: ContractFunction):
73+
def add_undeployed_contract_call(self, contract_func: ContractFunction, state_override: Optional[StateOverride] = None):
7274
assert self.undeployed_contract_constructor is not None, "No undeployed contract added yet"
7375
contract_func = copy.copy(contract_func)
7476
contract_func.address = 0 # self.undeployed_contract_address
7577
self.calls.append(contract_func)
78+
self.state_overwrites.append(state_override)
7679

7780
def call(
7881
self,
@@ -99,37 +102,43 @@ def call_with_gas(
99102
use_revert = self.w3.revert_reason_available
100103

101104
calls = self.calls
105+
state_overwrites = self.state_overwrites
102106
calls_with_calldata = self.add_calls_calldata(calls)
103107

104108
return self._inner_call(
105109
use_revert=use_revert,
106110
calls_with_calldata=calls_with_calldata,
107111
batch_size=batch_size,
108-
state_override=state_override
112+
state_overwrites=state_overwrites,
113+
global_state_override=state_override,
109114
)
110115

111116
def _inner_call(
112117
self,
113118
use_revert: bool,
114119
calls_with_calldata: list[tuple[ContractFunction, bytes]],
115120
batch_size: int,
116-
state_override: Optional[StateOverride] = None
121+
state_overwrites: list[StateOverride | None],
122+
global_state_override: StateOverride | None = None,
117123
) -> tuple[list[Exception | tuple[any, ...]], list[int]]:
124+
assert len(calls_with_calldata) == len(state_overwrites)
118125
if len(calls_with_calldata) == 0:
119126
return [], []
120127
kwargs = dict(
121128
use_revert=use_revert,
122129
batch_size=batch_size,
123-
state_override=state_override,
130+
global_state_override=global_state_override,
124131
)
125132
# make sure calls are not bigger than batch_size
126133
if len(calls_with_calldata) > batch_size:
127134
results_combined = []
128135
gas_usages_combined = []
129136
for start in range(0, len(calls_with_calldata), batch_size):
137+
end = min(start + batch_size, len(calls_with_calldata))
130138
results, gas_usages = self._inner_call(
131139
**kwargs,
132-
calls_with_calldata=calls_with_calldata[start: min(start + batch_size, len(calls_with_calldata))],
140+
calls_with_calldata=calls_with_calldata[start:end],
141+
state_overwrites=state_overwrites[start:end]
133142
)
134143
results_combined += results
135144
gas_usages_combined += gas_usages
@@ -146,6 +155,8 @@ def _inner_call(
146155
)
147156
use_revert = False
148157

158+
state_override = self.merge_state_overwrites([global_state_override] + state_overwrites)
159+
149160
try:
150161
raw_returns, gas_usages = self._call_multicall(
151162
multicall_call=multicall_call,
@@ -166,8 +177,17 @@ def _inner_call(
166177
raw_returns = [e]
167178
gas_usages = [None]
168179
else:
169-
left_results, left_gas_usages = self._inner_call(**kwargs, calls_with_calldata=calls_with_calldata[:len(calls_with_calldata) // 2])
170-
right_results, right_gas_usages = self._inner_call(**kwargs, calls_with_calldata=calls_with_calldata[len(calls_with_calldata) // 2:])
180+
middle = len(calls_with_calldata) // 2
181+
left_results, left_gas_usages = self._inner_call(
182+
**kwargs,
183+
calls_with_calldata=calls_with_calldata[:middle],
184+
state_overwrites=state_overwrites[:middle]
185+
)
186+
right_results, right_gas_usages = self._inner_call(
187+
**kwargs,
188+
calls_with_calldata=calls_with_calldata[middle:],
189+
state_overwrites=state_overwrites[middle:]
190+
)
171191
return left_results + right_results, left_gas_usages + right_gas_usages
172192
else:
173193
if len(raw_returns) != len(calls_with_calldata) and len(raw_returns) > 1:
@@ -180,9 +200,50 @@ def _inner_call(
180200
if len(results) == len(calls_with_calldata):
181201
return results, gas_usages
182202
# if not all calls were executed, recursively execute remaining calls and concatenate results
183-
right_results, right_gas_usages = self._inner_call(**kwargs, calls_with_calldata=calls_with_calldata[len(results):])
203+
right_results, right_gas_usages = self._inner_call(
204+
**kwargs,
205+
calls_with_calldata=calls_with_calldata[len(results):],
206+
state_overwrites=state_overwrites[len(results):]
207+
)
184208
return results + right_results, gas_usages + right_gas_usages
185209

210+
@staticmethod
211+
def merge_state_overwrites(state_overwrites: list[StateOverride | None]) -> StateOverride | None:
212+
if all(overwrite is None for overwrite in state_overwrites):
213+
return None
214+
merged_overwrite: StateOverride = {}
215+
for overwrites in state_overwrites:
216+
if overwrites is None:
217+
continue
218+
for contract_address, overwrite in overwrites.items():
219+
if contract_address not in merged_overwrite:
220+
merged_overwrite[contract_address] = copy.deepcopy(overwrite)
221+
continue
222+
prev_overwrite = merged_overwrite[contract_address]
223+
if "balance" in overwrite:
224+
assert "balance" not in prev_overwrite
225+
prev_overwrite["balance"] = overwrite["balance"]
226+
if "nonce" in overwrite:
227+
assert "nonce" not in prev_overwrite
228+
prev_overwrite["nonce"] = overwrite["nonce"]
229+
if "code" in overwrite:
230+
assert "code" not in prev_overwrite
231+
prev_overwrite["code"] = overwrite["code"]
232+
if "state" in overwrite:
233+
assert "state" not in prev_overwrite
234+
assert "stateDiff" not in prev_overwrite
235+
prev_overwrite["state"] = overwrite["state"]
236+
if "stateDiff" in overwrite:
237+
assert "state" not in prev_overwrite
238+
if "stateDiff" not in prev_overwrite:
239+
prev_overwrite["stateDiff"] = copy.deepcopy(overwrite["stateDiff"])
240+
else:
241+
prev_state_diff = prev_overwrite["stateDiff"]
242+
for slot, value in overwrite["stateDiff"].items():
243+
assert slot not in prev_state_diff
244+
prev_state_diff[slot] = value
245+
return merged_overwrite
246+
186247
@staticmethod
187248
def calculate_expected_contract_address(sender: str, nonce: int):
188249
undeployed_contract_runner_address = calculate_create_address(sender=sender, nonce=nonce)

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from setuptools import setup, find_packages
22

3-
VERSION = '0.1.37'
3+
VERSION = '0.1.38'
44
DESCRIPTION = 'IceCreamSwap Web3.py wrapper'
55
LONG_DESCRIPTION = 'IceCreamSwap Web3.py wrapper with automatic retries, multicall and other advanced functionality'
66

0 commit comments

Comments
 (0)