1
1
"""Account-related types for Ethereum tests."""
2
2
3
+ import json
3
4
from dataclasses import dataclass , field
4
- from typing import Dict , List , Literal , Optional , Tuple
5
+ from enum import Enum , auto
6
+ from typing import Any , Dict , List , Literal , Optional , Self , Tuple
5
7
6
8
from coincurve .keys import PrivateKey
7
9
from ethereum_types .bytes import Bytes20
8
10
from ethereum_types .numeric import U256 , Bytes32 , Uint
9
11
from pydantic import PrivateAttr
10
- from typing_extensions import Self
11
12
12
13
from ethereum_test_base_types import (
13
14
Account ,
@@ -144,12 +145,6 @@ class UnexpectedAccountError(Exception):
144
145
address : Address
145
146
account : Account | None
146
147
147
- def __init__ (self , address : Address , account : Account | None , * args ):
148
- """Initialize the exception."""
149
- super ().__init__ (args )
150
- self .address = address
151
- self .account = account
152
-
153
148
def __str__ (self ):
154
149
"""Print exception string."""
155
150
return f"unexpected account in allocation { self .address } : { self .account } "
@@ -160,25 +155,82 @@ class MissingAccountError(Exception):
160
155
161
156
address : Address
162
157
163
- def __init__ (self , address : Address , * args ):
164
- """Initialize the exception."""
165
- super ().__init__ (args )
166
- self .address = address
167
-
168
158
def __str__ (self ):
169
159
"""Print exception string."""
170
160
return f"Account missing from allocation { self .address } "
171
161
162
+ @dataclass (kw_only = True )
163
+ class CollisionError (Exception ):
164
+ """Different accounts at the same address."""
165
+
166
+ address : Address
167
+ account_1 : Account | None
168
+ account_2 : Account | None
169
+
170
+ def to_json (self ) -> Dict [str , Any ]:
171
+ """Dump to json object."""
172
+ return {
173
+ "address" : self .address .hex (),
174
+ "account_1" : self .account_1 .model_dump (mode = "json" )
175
+ if self .account_1 is not None
176
+ else None ,
177
+ "account_2" : self .account_2 .model_dump (mode = "json" )
178
+ if self .account_2 is not None
179
+ else None ,
180
+ }
181
+
182
+ @classmethod
183
+ def from_json (cls , obj : Dict [str , Any ]) -> Self :
184
+ """Parse from a json dict."""
185
+ return cls (
186
+ address = Address (obj ["address" ]),
187
+ account_1 = Account .model_validate (obj ["account_1" ])
188
+ if obj ["account_1" ] is not None
189
+ else None ,
190
+ account_2 = Account .model_validate (obj ["account_2" ])
191
+ if obj ["account_2" ] is not None
192
+ else None ,
193
+ )
194
+
195
+ def __str__ (self ) -> str :
196
+ """Print exception string."""
197
+ return (
198
+ "Overlapping key defining different accounts detected:\n "
199
+ f"{ json .dumps (self .to_json (), indent = 2 )} "
200
+ )
201
+
202
+ class KeyCollisionMode (Enum ):
203
+ """Mode for handling key collisions when merging allocations."""
204
+
205
+ ERROR = auto ()
206
+ OVERWRITE = auto ()
207
+ ALLOW_IDENTICAL_ACCOUNTS = auto ()
208
+
172
209
@classmethod
173
210
def merge (
174
- cls , alloc_1 : "Alloc" , alloc_2 : "Alloc" , allow_key_collision : bool = True
211
+ cls ,
212
+ alloc_1 : "Alloc" ,
213
+ alloc_2 : "Alloc" ,
214
+ key_collision_mode : KeyCollisionMode = KeyCollisionMode .OVERWRITE ,
175
215
) -> "Alloc" :
176
216
"""Return merged allocation of two sources."""
177
217
overlapping_keys = alloc_1 .root .keys () & alloc_2 .root .keys ()
178
- if overlapping_keys and not allow_key_collision :
179
- raise Exception (
180
- f"Overlapping keys detected: { [key .hex () for key in overlapping_keys ]} "
181
- )
218
+ if overlapping_keys :
219
+ if key_collision_mode == cls .KeyCollisionMode .ERROR :
220
+ raise Exception (
221
+ f"Overlapping keys detected: { [key .hex () for key in overlapping_keys ]} "
222
+ )
223
+ elif key_collision_mode == cls .KeyCollisionMode .ALLOW_IDENTICAL_ACCOUNTS :
224
+ # The overlapping keys must point to the exact same account
225
+ for key in overlapping_keys :
226
+ account_1 = alloc_1 [key ]
227
+ account_2 = alloc_2 [key ]
228
+ if account_1 != account_2 :
229
+ raise Alloc .CollisionError (
230
+ address = key ,
231
+ account_1 = account_1 ,
232
+ account_2 = account_2 ,
233
+ )
182
234
merged = alloc_1 .model_dump ()
183
235
184
236
for address , other_account in alloc_2 .root .items ():
@@ -267,15 +319,17 @@ def verify_post_alloc(self, got_alloc: "Alloc"):
267
319
if account is None :
268
320
# Account must not exist
269
321
if address in got_alloc .root and got_alloc .root [address ] is not None :
270
- raise Alloc .UnexpectedAccountError (address , got_alloc .root [address ])
322
+ raise Alloc .UnexpectedAccountError (
323
+ address = address , account = got_alloc .root [address ]
324
+ )
271
325
else :
272
326
if address in got_alloc .root :
273
327
got_account = got_alloc .root [address ]
274
328
assert isinstance (got_account , Account )
275
329
assert isinstance (account , Account )
276
330
account .check_alloc (address , got_account )
277
331
else :
278
- raise Alloc .MissingAccountError (address )
332
+ raise Alloc .MissingAccountError (address = address )
279
333
280
334
def deploy_contract (
281
335
self ,
0 commit comments