Skip to content

Commit a891fb9

Browse files
authored
feat(fixtures,specs,types): Verify address collisions during pre-alloc grouping phase (#1922)
* feat: detect address collisions * refactor: Better exception def
1 parent c6c3a1f commit a891fb9

File tree

3 files changed

+100
-25
lines changed

3 files changed

+100
-25
lines changed

src/ethereum_test_fixtures/pre_alloc_groups.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Pre-allocation group models for test fixture generation."""
22

3+
import json
34
from pathlib import Path
45
from typing import Any, Dict, List
56

@@ -57,10 +58,26 @@ def to_file(self, file: Path) -> None:
5758
if file.exists():
5859
with open(file, "r") as f:
5960
previous_pre_alloc_group = PreAllocGroup.model_validate_json(f.read())
60-
for account in previous_pre_alloc_group.pre:
61-
if account not in self.pre:
62-
self.pre[account] = previous_pre_alloc_group.pre[account]
63-
self.test_ids.extend(previous_pre_alloc_group.test_ids)
61+
for account in previous_pre_alloc_group.pre:
62+
existing_account = previous_pre_alloc_group.pre[account]
63+
if account not in self.pre:
64+
self.pre[account] = existing_account
65+
else:
66+
new_account = self.pre[account]
67+
if new_account != existing_account:
68+
# This procedure fails during xdist worker's pytest_sessionfinish
69+
# and is not reported to the master thread.
70+
# We signal here that the groups created contain a collision.
71+
collision_file_path = file.with_suffix(".fail")
72+
collision_exception = Alloc.CollisionError(
73+
address=account,
74+
account_1=existing_account,
75+
account_2=new_account,
76+
)
77+
with open(collision_file_path, "w") as f:
78+
f.write(json.dumps(collision_exception.to_json()))
79+
raise collision_exception
80+
self.test_ids.extend(previous_pre_alloc_group.test_ids)
6481

6582
with open(file, "w") as f:
6683
f.write(self.model_dump_json(by_alias=True, exclude_none=True, indent=2))
@@ -78,6 +95,10 @@ def __setitem__(self, key: str, value: Any):
7895
@classmethod
7996
def from_folder(cls, folder: Path) -> "PreAllocGroups":
8097
"""Create PreAllocGroups from a folder of pre-allocation files."""
98+
# First check for collision failures
99+
for fail_file in folder.glob("*.fail"):
100+
with open(fail_file) as f:
101+
raise Alloc.CollisionError.from_json(json.loads(f.read()))
81102
data = {}
82103
for file in folder.glob("*.json"):
83104
with open(file) as f:

src/ethereum_test_specs/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ def update_pre_alloc_groups(
233233
group.pre = Alloc.merge(
234234
group.pre,
235235
self.pre,
236-
allow_key_collision=True,
236+
key_collision_mode=Alloc.KeyCollisionMode.ALLOW_IDENTICAL_ACCOUNTS,
237237
)
238238
group.fork = fork
239239
group.test_ids.append(str(test_id))

src/ethereum_test_types/account_types.py

Lines changed: 74 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
"""Account-related types for Ethereum tests."""
22

3+
import json
34
from dataclasses import dataclass, field
4-
from typing import Dict, List, Literal, Optional, Tuple
5+
from enum import Enum, auto
6+
from typing import Any, Dict, List, Literal, Optional, Self, Tuple
57

68
from coincurve.keys import PrivateKey
79
from ethereum_types.bytes import Bytes20
810
from ethereum_types.numeric import U256, Bytes32, Uint
911
from pydantic import PrivateAttr
10-
from typing_extensions import Self
1112

1213
from ethereum_test_base_types import (
1314
Account,
@@ -144,12 +145,6 @@ class UnexpectedAccountError(Exception):
144145
address: Address
145146
account: Account | None
146147

147-
def __init__(self, address: Address, account: Account | None, *args):
148-
"""Initialize the exception."""
149-
super().__init__(args)
150-
self.address = address
151-
self.account = account
152-
153148
def __str__(self):
154149
"""Print exception string."""
155150
return f"unexpected account in allocation {self.address}: {self.account}"
@@ -160,25 +155,82 @@ class MissingAccountError(Exception):
160155

161156
address: Address
162157

163-
def __init__(self, address: Address, *args):
164-
"""Initialize the exception."""
165-
super().__init__(args)
166-
self.address = address
167-
168158
def __str__(self):
169159
"""Print exception string."""
170160
return f"Account missing from allocation {self.address}"
171161

162+
@dataclass(kw_only=True)
163+
class CollisionError(Exception):
164+
"""Different accounts at the same address."""
165+
166+
address: Address
167+
account_1: Account | None
168+
account_2: Account | None
169+
170+
def to_json(self) -> Dict[str, Any]:
171+
"""Dump to json object."""
172+
return {
173+
"address": self.address.hex(),
174+
"account_1": self.account_1.model_dump(mode="json")
175+
if self.account_1 is not None
176+
else None,
177+
"account_2": self.account_2.model_dump(mode="json")
178+
if self.account_2 is not None
179+
else None,
180+
}
181+
182+
@classmethod
183+
def from_json(cls, obj: Dict[str, Any]) -> Self:
184+
"""Parse from a json dict."""
185+
return cls(
186+
address=Address(obj["address"]),
187+
account_1=Account.model_validate(obj["account_1"])
188+
if obj["account_1"] is not None
189+
else None,
190+
account_2=Account.model_validate(obj["account_2"])
191+
if obj["account_2"] is not None
192+
else None,
193+
)
194+
195+
def __str__(self) -> str:
196+
"""Print exception string."""
197+
return (
198+
"Overlapping key defining different accounts detected:\n"
199+
f"{json.dumps(self.to_json(), indent=2)}"
200+
)
201+
202+
class KeyCollisionMode(Enum):
203+
"""Mode for handling key collisions when merging allocations."""
204+
205+
ERROR = auto()
206+
OVERWRITE = auto()
207+
ALLOW_IDENTICAL_ACCOUNTS = auto()
208+
172209
@classmethod
173210
def merge(
174-
cls, alloc_1: "Alloc", alloc_2: "Alloc", allow_key_collision: bool = True
211+
cls,
212+
alloc_1: "Alloc",
213+
alloc_2: "Alloc",
214+
key_collision_mode: KeyCollisionMode = KeyCollisionMode.OVERWRITE,
175215
) -> "Alloc":
176216
"""Return merged allocation of two sources."""
177217
overlapping_keys = alloc_1.root.keys() & alloc_2.root.keys()
178-
if overlapping_keys and not allow_key_collision:
179-
raise Exception(
180-
f"Overlapping keys detected: {[key.hex() for key in overlapping_keys]}"
181-
)
218+
if overlapping_keys:
219+
if key_collision_mode == cls.KeyCollisionMode.ERROR:
220+
raise Exception(
221+
f"Overlapping keys detected: {[key.hex() for key in overlapping_keys]}"
222+
)
223+
elif key_collision_mode == cls.KeyCollisionMode.ALLOW_IDENTICAL_ACCOUNTS:
224+
# The overlapping keys must point to the exact same account
225+
for key in overlapping_keys:
226+
account_1 = alloc_1[key]
227+
account_2 = alloc_2[key]
228+
if account_1 != account_2:
229+
raise Alloc.CollisionError(
230+
address=key,
231+
account_1=account_1,
232+
account_2=account_2,
233+
)
182234
merged = alloc_1.model_dump()
183235

184236
for address, other_account in alloc_2.root.items():
@@ -267,15 +319,17 @@ def verify_post_alloc(self, got_alloc: "Alloc"):
267319
if account is None:
268320
# Account must not exist
269321
if address in got_alloc.root and got_alloc.root[address] is not None:
270-
raise Alloc.UnexpectedAccountError(address, got_alloc.root[address])
322+
raise Alloc.UnexpectedAccountError(
323+
address=address, account=got_alloc.root[address]
324+
)
271325
else:
272326
if address in got_alloc.root:
273327
got_account = got_alloc.root[address]
274328
assert isinstance(got_account, Account)
275329
assert isinstance(account, Account)
276330
account.check_alloc(address, got_account)
277331
else:
278-
raise Alloc.MissingAccountError(address)
332+
raise Alloc.MissingAccountError(address=address)
279333

280334
def deploy_contract(
281335
self,

0 commit comments

Comments
 (0)