11"""Account-related types for Ethereum tests."""
22
3+ import json
34from dataclasses import dataclass , field
4- from typing import Dict , List , Literal , Optional , Tuple
5+ from enum import Enum , auto
6+ from typing import Any , Dict , List , Literal , Optional , Self , Tuple
57
68from coincurve .keys import PrivateKey
79from ethereum_types .bytes import Bytes20
810from ethereum_types .numeric import U256 , Bytes32 , Uint
911from pydantic import PrivateAttr
10- from typing_extensions import Self
1112
1213from ethereum_test_base_types import (
1314 Account ,
@@ -144,12 +145,6 @@ class UnexpectedAccountError(Exception):
144145 address : Address
145146 account : Account | None
146147
147- def __init__ (self , address : Address , account : Account | None , * args ):
148- """Initialize the exception."""
149- super ().__init__ (args )
150- self .address = address
151- self .account = account
152-
153148 def __str__ (self ):
154149 """Print exception string."""
155150 return f"unexpected account in allocation { self .address } : { self .account } "
@@ -160,25 +155,82 @@ class MissingAccountError(Exception):
160155
161156 address : Address
162157
163- def __init__ (self , address : Address , * args ):
164- """Initialize the exception."""
165- super ().__init__ (args )
166- self .address = address
167-
168158 def __str__ (self ):
169159 """Print exception string."""
170160 return f"Account missing from allocation { self .address } "
171161
162+ @dataclass (kw_only = True )
163+ class CollisionError (Exception ):
164+ """Different accounts at the same address."""
165+
166+ address : Address
167+ account_1 : Account | None
168+ account_2 : Account | None
169+
170+ def to_json (self ) -> Dict [str , Any ]:
171+ """Dump to json object."""
172+ return {
173+ "address" : self .address .hex (),
174+ "account_1" : self .account_1 .model_dump (mode = "json" )
175+ if self .account_1 is not None
176+ else None ,
177+ "account_2" : self .account_2 .model_dump (mode = "json" )
178+ if self .account_2 is not None
179+ else None ,
180+ }
181+
182+ @classmethod
183+ def from_json (cls , obj : Dict [str , Any ]) -> Self :
184+ """Parse from a json dict."""
185+ return cls (
186+ address = Address (obj ["address" ]),
187+ account_1 = Account .model_validate (obj ["account_1" ])
188+ if obj ["account_1" ] is not None
189+ else None ,
190+ account_2 = Account .model_validate (obj ["account_2" ])
191+ if obj ["account_2" ] is not None
192+ else None ,
193+ )
194+
195+ def __str__ (self ) -> str :
196+ """Print exception string."""
197+ return (
198+ "Overlapping key defining different accounts detected:\n "
199+ f"{ json .dumps (self .to_json (), indent = 2 )} "
200+ )
201+
202+ class KeyCollisionMode (Enum ):
203+ """Mode for handling key collisions when merging allocations."""
204+
205+ ERROR = auto ()
206+ OVERWRITE = auto ()
207+ ALLOW_IDENTICAL_ACCOUNTS = auto ()
208+
172209 @classmethod
173210 def merge (
174- cls , alloc_1 : "Alloc" , alloc_2 : "Alloc" , allow_key_collision : bool = True
211+ cls ,
212+ alloc_1 : "Alloc" ,
213+ alloc_2 : "Alloc" ,
214+ key_collision_mode : KeyCollisionMode = KeyCollisionMode .OVERWRITE ,
175215 ) -> "Alloc" :
176216 """Return merged allocation of two sources."""
177217 overlapping_keys = alloc_1 .root .keys () & alloc_2 .root .keys ()
178- if overlapping_keys and not allow_key_collision :
179- raise Exception (
180- f"Overlapping keys detected: { [key .hex () for key in overlapping_keys ]} "
181- )
218+ if overlapping_keys :
219+ if key_collision_mode == cls .KeyCollisionMode .ERROR :
220+ raise Exception (
221+ f"Overlapping keys detected: { [key .hex () for key in overlapping_keys ]} "
222+ )
223+ elif key_collision_mode == cls .KeyCollisionMode .ALLOW_IDENTICAL_ACCOUNTS :
224+ # The overlapping keys must point to the exact same account
225+ for key in overlapping_keys :
226+ account_1 = alloc_1 [key ]
227+ account_2 = alloc_2 [key ]
228+ if account_1 != account_2 :
229+ raise Alloc .CollisionError (
230+ address = key ,
231+ account_1 = account_1 ,
232+ account_2 = account_2 ,
233+ )
182234 merged = alloc_1 .model_dump ()
183235
184236 for address , other_account in alloc_2 .root .items ():
@@ -267,15 +319,17 @@ def verify_post_alloc(self, got_alloc: "Alloc"):
267319 if account is None :
268320 # Account must not exist
269321 if address in got_alloc .root and got_alloc .root [address ] is not None :
270- raise Alloc .UnexpectedAccountError (address , got_alloc .root [address ])
322+ raise Alloc .UnexpectedAccountError (
323+ address = address , account = got_alloc .root [address ]
324+ )
271325 else :
272326 if address in got_alloc .root :
273327 got_account = got_alloc .root [address ]
274328 assert isinstance (got_account , Account )
275329 assert isinstance (account , Account )
276330 account .check_alloc (address , got_account )
277331 else :
278- raise Alloc .MissingAccountError (address )
332+ raise Alloc .MissingAccountError (address = address )
279333
280334 def deploy_contract (
281335 self ,
0 commit comments