@@ -59,20 +59,23 @@ def __init__(
59
59
self .undeployed_contract_address = self .calculate_expected_contract_address (sender = self .CALLER_ADDRESS , nonce = 0 )
60
60
61
61
self .calls : list [ContractFunction ] = []
62
+ self .state_overwrites : list [StateOverride | None ] = []
62
63
self .undeployed_contract_constructor : Optional [ContractConstructor ] = None
63
64
64
- def add_call (self , contract_func : ContractFunction ):
65
+ def add_call (self , contract_func : ContractFunction , state_override : Optional [ StateOverride ] = None ):
65
66
self .calls .append (contract_func )
67
+ self .state_overwrites .append (state_override )
66
68
67
69
def add_undeployed_contract (self , contract_constructor : ContractConstructor ):
68
70
assert self .undeployed_contract_constructor is None , "can only add one undeployed contract"
69
71
self .undeployed_contract_constructor = contract_constructor
70
72
71
- def add_undeployed_contract_call (self , contract_func : ContractFunction ):
73
+ def add_undeployed_contract_call (self , contract_func : ContractFunction , state_override : Optional [ StateOverride ] = None ):
72
74
assert self .undeployed_contract_constructor is not None , "No undeployed contract added yet"
73
75
contract_func = copy .copy (contract_func )
74
76
contract_func .address = 0 # self.undeployed_contract_address
75
77
self .calls .append (contract_func )
78
+ self .state_overwrites .append (state_override )
76
79
77
80
def call (
78
81
self ,
@@ -99,37 +102,43 @@ def call_with_gas(
99
102
use_revert = self .w3 .revert_reason_available
100
103
101
104
calls = self .calls
105
+ state_overwrites = self .state_overwrites
102
106
calls_with_calldata = self .add_calls_calldata (calls )
103
107
104
108
return self ._inner_call (
105
109
use_revert = use_revert ,
106
110
calls_with_calldata = calls_with_calldata ,
107
111
batch_size = batch_size ,
108
- state_override = state_override
112
+ state_overwrites = state_overwrites ,
113
+ global_state_override = state_override ,
109
114
)
110
115
111
116
def _inner_call (
112
117
self ,
113
118
use_revert : bool ,
114
119
calls_with_calldata : list [tuple [ContractFunction , bytes ]],
115
120
batch_size : int ,
116
- state_override : Optional [StateOverride ] = None
121
+ state_overwrites : list [StateOverride | None ],
122
+ global_state_override : StateOverride | None = None ,
117
123
) -> tuple [list [Exception | tuple [any , ...]], list [int ]]:
124
+ assert len (calls_with_calldata ) == len (state_overwrites )
118
125
if len (calls_with_calldata ) == 0 :
119
126
return [], []
120
127
kwargs = dict (
121
128
use_revert = use_revert ,
122
129
batch_size = batch_size ,
123
- state_override = state_override ,
130
+ global_state_override = global_state_override ,
124
131
)
125
132
# make sure calls are not bigger than batch_size
126
133
if len (calls_with_calldata ) > batch_size :
127
134
results_combined = []
128
135
gas_usages_combined = []
129
136
for start in range (0 , len (calls_with_calldata ), batch_size ):
137
+ end = min (start + batch_size , len (calls_with_calldata ))
130
138
results , gas_usages = self ._inner_call (
131
139
** kwargs ,
132
- calls_with_calldata = calls_with_calldata [start : min (start + batch_size , len (calls_with_calldata ))],
140
+ calls_with_calldata = calls_with_calldata [start :end ],
141
+ state_overwrites = state_overwrites [start :end ]
133
142
)
134
143
results_combined += results
135
144
gas_usages_combined += gas_usages
@@ -146,6 +155,8 @@ def _inner_call(
146
155
)
147
156
use_revert = False
148
157
158
+ state_override = self .merge_state_overwrites ([global_state_override ] + state_overwrites )
159
+
149
160
try :
150
161
raw_returns , gas_usages = self ._call_multicall (
151
162
multicall_call = multicall_call ,
@@ -166,8 +177,17 @@ def _inner_call(
166
177
raw_returns = [e ]
167
178
gas_usages = [None ]
168
179
else :
169
- left_results , left_gas_usages = self ._inner_call (** kwargs , calls_with_calldata = calls_with_calldata [:len (calls_with_calldata ) // 2 ])
170
- right_results , right_gas_usages = self ._inner_call (** kwargs , calls_with_calldata = calls_with_calldata [len (calls_with_calldata ) // 2 :])
180
+ middle = len (calls_with_calldata ) // 2
181
+ left_results , left_gas_usages = self ._inner_call (
182
+ ** kwargs ,
183
+ calls_with_calldata = calls_with_calldata [:middle ],
184
+ state_overwrites = state_overwrites [:middle ]
185
+ )
186
+ right_results , right_gas_usages = self ._inner_call (
187
+ ** kwargs ,
188
+ calls_with_calldata = calls_with_calldata [middle :],
189
+ state_overwrites = state_overwrites [middle :]
190
+ )
171
191
return left_results + right_results , left_gas_usages + right_gas_usages
172
192
else :
173
193
if len (raw_returns ) != len (calls_with_calldata ) and len (raw_returns ) > 1 :
@@ -180,9 +200,50 @@ def _inner_call(
180
200
if len (results ) == len (calls_with_calldata ):
181
201
return results , gas_usages
182
202
# if not all calls were executed, recursively execute remaining calls and concatenate results
183
- right_results , right_gas_usages = self ._inner_call (** kwargs , calls_with_calldata = calls_with_calldata [len (results ):])
203
+ right_results , right_gas_usages = self ._inner_call (
204
+ ** kwargs ,
205
+ calls_with_calldata = calls_with_calldata [len (results ):],
206
+ state_overwrites = state_overwrites [len (results ):]
207
+ )
184
208
return results + right_results , gas_usages + right_gas_usages
185
209
210
+ @staticmethod
211
+ def merge_state_overwrites (state_overwrites : list [StateOverride | None ]) -> StateOverride | None :
212
+ if all (overwrite is None for overwrite in state_overwrites ):
213
+ return None
214
+ merged_overwrite : StateOverride = {}
215
+ for overwrites in state_overwrites :
216
+ if overwrites is None :
217
+ continue
218
+ for contract_address , overwrite in overwrites .items ():
219
+ if contract_address not in merged_overwrite :
220
+ merged_overwrite [contract_address ] = copy .deepcopy (overwrite )
221
+ continue
222
+ prev_overwrite = merged_overwrite [contract_address ]
223
+ if "balance" in overwrite :
224
+ assert "balance" not in prev_overwrite
225
+ prev_overwrite ["balance" ] = overwrite ["balance" ]
226
+ if "nonce" in overwrite :
227
+ assert "nonce" not in prev_overwrite
228
+ prev_overwrite ["nonce" ] = overwrite ["nonce" ]
229
+ if "code" in overwrite :
230
+ assert "code" not in prev_overwrite
231
+ prev_overwrite ["code" ] = overwrite ["code" ]
232
+ if "state" in overwrite :
233
+ assert "state" not in prev_overwrite
234
+ assert "stateDiff" not in prev_overwrite
235
+ prev_overwrite ["state" ] = overwrite ["state" ]
236
+ if "stateDiff" in overwrite :
237
+ assert "state" not in prev_overwrite
238
+ if "stateDiff" not in prev_overwrite :
239
+ prev_overwrite ["stateDiff" ] = copy .deepcopy (overwrite ["stateDiff" ])
240
+ else :
241
+ prev_state_diff = prev_overwrite ["stateDiff" ]
242
+ for slot , value in overwrite ["stateDiff" ].items ():
243
+ assert slot not in prev_state_diff
244
+ prev_state_diff [slot ] = value
245
+ return merged_overwrite
246
+
186
247
@staticmethod
187
248
def calculate_expected_contract_address (sender : str , nonce : int ):
188
249
undeployed_contract_runner_address = calculate_create_address (sender = sender , nonce = nonce )
0 commit comments