Skip to content

Commit d30b8ec

Browse files
committed
chore(tests): add modexp length mismatch functionality.
1 parent dd5efdd commit d30b8ec

File tree

4 files changed

+84
-26
lines changed

4 files changed

+84
-26
lines changed

tests/byzantium/eip198_modexp_precompile/helpers.py

Lines changed: 45 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
"""Helper functions for the EIP-198 ModExp precompile tests."""
22

3+
from typing import Tuple
4+
35
from pydantic import Field
46

57
from ethereum_test_tools import Bytes, TestParameterGroup
@@ -24,21 +26,37 @@ class ModExpInput(TestParameterGroup):
2426
modulus: Bytes
2527
extra_data: Bytes = Field(default_factory=Bytes)
2628
raw_input: Bytes | None = None
29+
declared_base_length: int | None = None
30+
declared_exponent_length: int | None = None
31+
declared_modulus_length: int | None = None
2732

2833
@property
2934
def length_base(self) -> Bytes:
3035
"""Return the length of the base."""
31-
return Bytes(len(self.base).to_bytes(32, "big"))
36+
length = (
37+
self.declared_base_length if self.declared_base_length is not None else len(self.base)
38+
)
39+
return Bytes(length.to_bytes(32, "big"))
3240

3341
@property
3442
def length_exponent(self) -> Bytes:
3543
"""Return the length of the exponent."""
36-
return Bytes(len(self.exponent).to_bytes(32, "big"))
44+
length = (
45+
self.declared_exponent_length
46+
if self.declared_exponent_length is not None
47+
else len(self.exponent)
48+
)
49+
return Bytes(length.to_bytes(32, "big"))
3750

3851
@property
3952
def length_modulus(self) -> Bytes:
4053
"""Return the length of the modulus."""
41-
return Bytes(len(self.modulus).to_bytes(32, "big"))
54+
length = (
55+
self.declared_modulus_length
56+
if self.declared_modulus_length is not None
57+
else len(self.modulus)
58+
)
59+
return Bytes(length.to_bytes(32, "big"))
4260

4361
def __bytes__(self):
4462
"""Generate input for the MODEXP precompile."""
@@ -86,6 +104,30 @@ def from_bytes(cls, input_data: Bytes | str) -> "ModExpInput":
86104

87105
return cls(base=base, exponent=exponent, modulus=modulus, raw_input=input_data)
88106

107+
def get_declared_lengths(self) -> Tuple[int, int, int]:
108+
"""Extract declared lengths from the raw input bytes."""
109+
raw = self.raw_input if self.raw_input is not None else bytes(self)
110+
if len(raw) < 96:
111+
raw = raw.ljust(96, b"\0")
112+
base_length = int.from_bytes(raw[0:32], byteorder="big")
113+
exponent_length = int.from_bytes(raw[32:64], byteorder="big")
114+
modulus_length = int.from_bytes(raw[64:96], byteorder="big")
115+
return base_length, exponent_length, modulus_length
116+
117+
def get_exponent_head(self) -> int:
118+
"""Get the first 32 bytes of the exponent as an integer."""
119+
raw = self.raw_input if self.raw_input is not None else bytes(self)
120+
base_length, exponent_length, _ = self.get_declared_lengths()
121+
exp_start = 96 + base_length
122+
123+
# Extract up to 32 bytes of exponent data
124+
exp_head_bytes = raw[exp_start : exp_start + min(32, exponent_length)]
125+
126+
# Pad with zeros if less than 32 bytes
127+
exp_head_bytes = exp_head_bytes.rjust(32, b"\0")
128+
129+
return int.from_bytes(exp_head_bytes[:32], byteorder="big")
130+
89131

90132
class ModExpOutput(TestParameterGroup):
91133
"""

tests/osaka/eip7823_modexp_upper_bounds/test_modexp_upper_bounds.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import pytest
77

88
from ethereum_test_forks import Fork, Osaka
9-
from ethereum_test_tools import Account, Alloc, Environment, StateTestFiller, Transaction
9+
from ethereum_test_tools import Account, Alloc, Bytes, Environment, StateTestFiller, Transaction
1010
from ethereum_test_tools.vm.opcode import Opcodes as Op
1111

1212
from ...byzantium.eip198_modexp_precompile.helpers import ModExpInput, ModExpOutput
@@ -23,12 +23,7 @@
2323
def precompile_gas(fork: Fork, mod_exp_input: ModExpInput) -> int:
2424
"""Calculate gas cost for the ModExp precompile and verify it matches expected gas."""
2525
spec = Spec if fork < Osaka else Spec7883
26-
calculated_gas = spec.calculate_gas_cost(
27-
len(mod_exp_input.base),
28-
len(mod_exp_input.modulus),
29-
len(mod_exp_input.exponent),
30-
mod_exp_input.exponent,
31-
)
26+
calculated_gas = spec.calculate_gas_cost(mod_exp_input)
3227
return calculated_gas
3328

3429

@@ -125,6 +120,15 @@ def precompile_gas(fork: Fork, mod_exp_input: ModExpInput) -> int:
125120
),
126121
id="exp_0_base_1_mod_1025",
127122
),
123+
pytest.param(
124+
ModExpInput(
125+
base=b"",
126+
exponent=Bytes("80"),
127+
modulus=b"",
128+
declared_exponent_length=2**64,
129+
),
130+
id="exp_2_pow_64_base_0_mod_0",
131+
),
128132
],
129133
)
130134
def test_modexp_upper_bounds(
@@ -174,10 +178,11 @@ def test_modexp_upper_bounds(
174178
protected=True,
175179
sender=sender,
176180
)
181+
base_length, exp_length, mod_length = mod_exp_input.get_declared_lengths()
177182
if (
178-
len(mod_exp_input.base) <= MAX_LENGTH_BYTES
179-
and len(mod_exp_input.exponent) <= MAX_LENGTH_BYTES
180-
and len(mod_exp_input.modulus) <= MAX_LENGTH_BYTES
183+
base_length <= MAX_LENGTH_BYTES
184+
and exp_length <= MAX_LENGTH_BYTES
185+
and mod_length <= MAX_LENGTH_BYTES
181186
) or (fork < Osaka and not expensive):
182187
output = ModExpOutput(call_success=True, returned_data="0x01")
183188
else:

tests/osaka/eip7883_modexp_gas_increase/conftest.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -133,12 +133,7 @@ def precompile_gas(
133133
"""Calculate gas cost for the ModExp precompile and verify it matches expected gas."""
134134
spec = Spec if fork < Osaka else Spec7883
135135
try:
136-
calculated_gas = spec.calculate_gas_cost(
137-
len(modexp_input.base),
138-
len(modexp_input.modulus),
139-
len(modexp_input.exponent),
140-
modexp_input.exponent,
141-
)
136+
calculated_gas = spec.calculate_gas_cost(modexp_input)
142137
if gas_old is not None and gas_new is not None:
143138
expected_gas = gas_old if fork < Osaka else gas_new
144139
assert calculated_gas == expected_gas, (
@@ -150,8 +145,9 @@ def precompile_gas(
150145
f"({int.from_bytes(modexp_input.exponent, byteorder='big')})"
151146
)
152147
return calculated_gas
153-
except Exception as e:
154-
print(f"Warning: Error calculating gas, using minimum: {e}")
148+
except Exception:
149+
# Used for `test_modexp_invalid_inputs` we expect the call to not succeed.
150+
# Return is for completeness.
155151
return 500 if fork >= Osaka else 200
156152

157153

tests/osaka/eip7883_modexp_gas_increase/spec.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -79,14 +79,29 @@ def calculate_iteration_count(cls, exponent_length: int, exponent: bytes) -> int
7979
return max(iteration_count, 1)
8080

8181
@classmethod
82-
def calculate_gas_cost(
83-
cls, base_length: int, modulus_length: int, exponent_length: int, exponent: bytes
84-
) -> int:
85-
"""Calculate the ModExp gas cost according to EIP-2565 specification."""
82+
def calculate_gas_cost(cls, modexp_input: ModExpInput) -> int:
83+
"""
84+
Calculate the ModExp gas cost according to EIP-2565 specification, overridden by the
85+
constants within `Spec7883` when calculating for the EIP-7883 specification.
86+
87+
This handles length mismatch cases by using declared lengths from the raw input
88+
and only the first 32 bytes of exponent data for iteration calculation.
89+
"""
90+
base_length, exponent_length, modulus_length = modexp_input.get_declared_lengths()
91+
exponent_head = modexp_input.get_exponent_head()
8692
multiplication_complexity = cls.calculate_multiplication_complexity(
8793
base_length, modulus_length
8894
)
89-
iteration_count = cls.calculate_iteration_count(exponent_length, exponent)
95+
if exponent_length <= cls.EXPONENT_THRESHOLD and exponent_head == 0:
96+
iteration_count = 0
97+
elif exponent_length <= cls.EXPONENT_THRESHOLD:
98+
iteration_count = exponent_head.bit_length() - 1 if exponent_head > 0 else 0
99+
else:
100+
# For large exponents: length_part + bits from first 32 bytes
101+
length_part = cls.EXPONENT_BYTE_MULTIPLIER * (exponent_length - 32)
102+
bits_part = exponent_head.bit_length() - 1 if exponent_head > 0 else 0
103+
iteration_count = length_part + bits_part
104+
iteration_count = max(iteration_count, 1)
90105
return max(cls.MIN_GAS, (multiplication_complexity * iteration_count // cls.GAS_DIVISOR))
91106

92107

0 commit comments

Comments
 (0)