Skip to content

Commit 61c67db

Browse files
committed
chore(tests): add modexp length mismatch functionality.
1 parent 35829a4 commit 61c67db

File tree

2 files changed

+79
-9
lines changed

2 files changed

+79
-9
lines changed

tests/byzantium/eip198_modexp_precompile/helpers.py

Lines changed: 65 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,20 +25,62 @@ class ModExpInput(TestParameterGroup):
2525
extra_data: Bytes = Field(default_factory=Bytes)
2626
raw_input: Bytes | None = None
2727

28+
override_base_length: int | None = None
29+
override_exponent_length: int | None = None
30+
override_modulus_length: int | None = None
31+
2832
@property
2933
def length_base(self) -> Bytes:
3034
"""Return the length of the base."""
31-
return Bytes(len(self.base).to_bytes(32, "big"))
35+
length = (
36+
self.override_base_length if self.override_base_length is not None else len(self.base)
37+
)
38+
return Bytes(length.to_bytes(32, "big"))
3239

3340
@property
3441
def length_exponent(self) -> Bytes:
3542
"""Return the length of the exponent."""
36-
return Bytes(len(self.exponent).to_bytes(32, "big"))
43+
length = (
44+
self.override_exponent_length
45+
if self.override_exponent_length is not None
46+
else len(self.exponent)
47+
)
48+
return Bytes(length.to_bytes(32, "big"))
3749

3850
@property
3951
def length_modulus(self) -> Bytes:
4052
"""Return the length of the modulus."""
41-
return Bytes(len(self.modulus).to_bytes(32, "big"))
53+
length = (
54+
self.override_modulus_length
55+
if self.override_modulus_length is not None
56+
else len(self.modulus)
57+
)
58+
return Bytes(length.to_bytes(32, "big"))
59+
60+
@property
61+
def declared_base_length(self) -> int:
62+
"""Return the declared base length as int."""
63+
return (
64+
self.override_base_length if self.override_base_length is not None else len(self.base)
65+
)
66+
67+
@property
68+
def declared_exponent_length(self) -> int:
69+
"""Return the declared exponent length as int."""
70+
return (
71+
self.override_exponent_length
72+
if self.override_exponent_length is not None
73+
else len(self.exponent)
74+
)
75+
76+
@property
77+
def declared_modulus_length(self) -> int:
78+
"""Return the declared modulus length as int."""
79+
return (
80+
self.override_modulus_length
81+
if self.override_modulus_length is not None
82+
else len(self.modulus)
83+
)
4284

4385
def __bytes__(self):
4486
"""Generate input for the MODEXP precompile."""
@@ -86,6 +128,26 @@ def from_bytes(cls, input_data: Bytes | str) -> "ModExpInput":
86128

87129
return cls(base=base, exponent=exponent, modulus=modulus, raw_input=input_data)
88130

131+
@classmethod
132+
def create_mismatch(
133+
cls,
134+
base="",
135+
exponent="",
136+
modulus="",
137+
declared_base_length=None,
138+
declared_exponent_length=None,
139+
declared_modulus_length=None,
140+
):
141+
"""Create a ModExpInput with mismatched lengths."""
142+
return cls(
143+
base=Bytes(base),
144+
exponent=Bytes(exponent),
145+
modulus=Bytes(modulus),
146+
override_base_length=declared_base_length,
147+
override_exponent_length=declared_exponent_length,
148+
override_modulus_length=declared_modulus_length,
149+
)
150+
89151

90152
class ModExpOutput(TestParameterGroup):
91153
"""

tests/osaka/eip7823_modexp_upper_bounds/test_modexp_upper_bounds.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@ def precompile_gas(fork: Fork, mod_exp_input: ModExpInput) -> int:
2424
"""Calculate gas cost for the ModExp precompile and verify it matches expected gas."""
2525
spec = Spec if fork < Osaka else Spec7883
2626
calculated_gas = spec.calculate_gas_cost(
27-
len(mod_exp_input.base),
28-
len(mod_exp_input.modulus),
29-
len(mod_exp_input.exponent),
27+
mod_exp_input.declared_base_length,
28+
mod_exp_input.declared_modulus_length,
29+
mod_exp_input.declared_exponent_length,
3030
mod_exp_input.exponent,
3131
)
3232
return calculated_gas
@@ -125,6 +125,13 @@ def precompile_gas(fork: Fork, mod_exp_input: ModExpInput) -> int:
125125
),
126126
id="exp_0_base_1_mod_1025",
127127
),
128+
pytest.param(
129+
ModExpInput.create_mismatch(
130+
exponent="80",
131+
declared_exponent_length=2**64,
132+
),
133+
id="exp_2_pow_64_base_0_mod_0",
134+
),
128135
],
129136
)
130137
def test_modexp_upper_bounds(
@@ -174,10 +181,11 @@ def test_modexp_upper_bounds(
174181
protected=True,
175182
sender=sender,
176183
)
184+
177185
if (
178-
len(mod_exp_input.base) <= MAX_LENGTH_BYTES
179-
and len(mod_exp_input.exponent) <= MAX_LENGTH_BYTES
180-
and len(mod_exp_input.modulus) <= MAX_LENGTH_BYTES
186+
mod_exp_input.declared_base_length <= MAX_LENGTH_BYTES
187+
and mod_exp_input.declared_exponent_length <= MAX_LENGTH_BYTES
188+
and mod_exp_input.declared_modulus_length <= MAX_LENGTH_BYTES
181189
) or (fork < Osaka and not expensive):
182190
output = ModExpOutput(call_success=True, returned_data="0x01")
183191
else:

0 commit comments

Comments
 (0)