From d30b8ec0fc1b2dc80f60801904acddbb43cac735 Mon Sep 17 00:00:00 2001 From: spencer-tb Date: Fri, 29 Aug 2025 19:56:30 +0000 Subject: [PATCH] chore(tests): add modexp length mismatch functionality. --- .../eip198_modexp_precompile/helpers.py | 48 +++++++++++++++++-- .../test_modexp_upper_bounds.py | 25 ++++++---- .../eip7883_modexp_gas_increase/conftest.py | 12 ++--- .../osaka/eip7883_modexp_gas_increase/spec.py | 25 ++++++++-- 4 files changed, 84 insertions(+), 26 deletions(-) diff --git a/tests/byzantium/eip198_modexp_precompile/helpers.py b/tests/byzantium/eip198_modexp_precompile/helpers.py index 78fdc56fbad..9cc44fd25a1 100644 --- a/tests/byzantium/eip198_modexp_precompile/helpers.py +++ b/tests/byzantium/eip198_modexp_precompile/helpers.py @@ -1,5 +1,7 @@ """Helper functions for the EIP-198 ModExp precompile tests.""" +from typing import Tuple + from pydantic import Field from ethereum_test_tools import Bytes, TestParameterGroup @@ -24,21 +26,37 @@ class ModExpInput(TestParameterGroup): modulus: Bytes extra_data: Bytes = Field(default_factory=Bytes) raw_input: Bytes | None = None + declared_base_length: int | None = None + declared_exponent_length: int | None = None + declared_modulus_length: int | None = None @property def length_base(self) -> Bytes: """Return the length of the base.""" - return Bytes(len(self.base).to_bytes(32, "big")) + length = ( + self.declared_base_length if self.declared_base_length is not None else len(self.base) + ) + return Bytes(length.to_bytes(32, "big")) @property def length_exponent(self) -> Bytes: """Return the length of the exponent.""" - return Bytes(len(self.exponent).to_bytes(32, "big")) + length = ( + self.declared_exponent_length + if self.declared_exponent_length is not None + else len(self.exponent) + ) + return Bytes(length.to_bytes(32, "big")) @property def length_modulus(self) -> Bytes: """Return the length of the modulus.""" - return Bytes(len(self.modulus).to_bytes(32, "big")) + length = ( + self.declared_modulus_length + if self.declared_modulus_length is not None + else len(self.modulus) + ) + return Bytes(length.to_bytes(32, "big")) def __bytes__(self): """Generate input for the MODEXP precompile.""" @@ -86,6 +104,30 @@ def from_bytes(cls, input_data: Bytes | str) -> "ModExpInput": return cls(base=base, exponent=exponent, modulus=modulus, raw_input=input_data) + def get_declared_lengths(self) -> Tuple[int, int, int]: + """Extract declared lengths from the raw input bytes.""" + raw = self.raw_input if self.raw_input is not None else bytes(self) + if len(raw) < 96: + raw = raw.ljust(96, b"\0") + base_length = int.from_bytes(raw[0:32], byteorder="big") + exponent_length = int.from_bytes(raw[32:64], byteorder="big") + modulus_length = int.from_bytes(raw[64:96], byteorder="big") + return base_length, exponent_length, modulus_length + + def get_exponent_head(self) -> int: + """Get the first 32 bytes of the exponent as an integer.""" + raw = self.raw_input if self.raw_input is not None else bytes(self) + base_length, exponent_length, _ = self.get_declared_lengths() + exp_start = 96 + base_length + + # Extract up to 32 bytes of exponent data + exp_head_bytes = raw[exp_start : exp_start + min(32, exponent_length)] + + # Pad with zeros if less than 32 bytes + exp_head_bytes = exp_head_bytes.rjust(32, b"\0") + + return int.from_bytes(exp_head_bytes[:32], byteorder="big") + class ModExpOutput(TestParameterGroup): """ diff --git a/tests/osaka/eip7823_modexp_upper_bounds/test_modexp_upper_bounds.py b/tests/osaka/eip7823_modexp_upper_bounds/test_modexp_upper_bounds.py index 269e4792964..5ee9ac4b286 100644 --- a/tests/osaka/eip7823_modexp_upper_bounds/test_modexp_upper_bounds.py +++ b/tests/osaka/eip7823_modexp_upper_bounds/test_modexp_upper_bounds.py @@ -6,7 +6,7 @@ import pytest from ethereum_test_forks import Fork, Osaka -from ethereum_test_tools import Account, Alloc, Environment, StateTestFiller, Transaction +from ethereum_test_tools import Account, Alloc, Bytes, Environment, StateTestFiller, Transaction from ethereum_test_tools.vm.opcode import Opcodes as Op from ...byzantium.eip198_modexp_precompile.helpers import ModExpInput, ModExpOutput @@ -23,12 +23,7 @@ def precompile_gas(fork: Fork, mod_exp_input: ModExpInput) -> int: """Calculate gas cost for the ModExp precompile and verify it matches expected gas.""" spec = Spec if fork < Osaka else Spec7883 - calculated_gas = spec.calculate_gas_cost( - len(mod_exp_input.base), - len(mod_exp_input.modulus), - len(mod_exp_input.exponent), - mod_exp_input.exponent, - ) + calculated_gas = spec.calculate_gas_cost(mod_exp_input) return calculated_gas @@ -125,6 +120,15 @@ def precompile_gas(fork: Fork, mod_exp_input: ModExpInput) -> int: ), id="exp_0_base_1_mod_1025", ), + pytest.param( + ModExpInput( + base=b"", + exponent=Bytes("80"), + modulus=b"", + declared_exponent_length=2**64, + ), + id="exp_2_pow_64_base_0_mod_0", + ), ], ) def test_modexp_upper_bounds( @@ -174,10 +178,11 @@ def test_modexp_upper_bounds( protected=True, sender=sender, ) + base_length, exp_length, mod_length = mod_exp_input.get_declared_lengths() if ( - len(mod_exp_input.base) <= MAX_LENGTH_BYTES - and len(mod_exp_input.exponent) <= MAX_LENGTH_BYTES - and len(mod_exp_input.modulus) <= MAX_LENGTH_BYTES + base_length <= MAX_LENGTH_BYTES + and exp_length <= MAX_LENGTH_BYTES + and mod_length <= MAX_LENGTH_BYTES ) or (fork < Osaka and not expensive): output = ModExpOutput(call_success=True, returned_data="0x01") else: diff --git a/tests/osaka/eip7883_modexp_gas_increase/conftest.py b/tests/osaka/eip7883_modexp_gas_increase/conftest.py index df395b90213..366ba81141f 100644 --- a/tests/osaka/eip7883_modexp_gas_increase/conftest.py +++ b/tests/osaka/eip7883_modexp_gas_increase/conftest.py @@ -133,12 +133,7 @@ def precompile_gas( """Calculate gas cost for the ModExp precompile and verify it matches expected gas.""" spec = Spec if fork < Osaka else Spec7883 try: - calculated_gas = spec.calculate_gas_cost( - len(modexp_input.base), - len(modexp_input.modulus), - len(modexp_input.exponent), - modexp_input.exponent, - ) + calculated_gas = spec.calculate_gas_cost(modexp_input) if gas_old is not None and gas_new is not None: expected_gas = gas_old if fork < Osaka else gas_new assert calculated_gas == expected_gas, ( @@ -150,8 +145,9 @@ def precompile_gas( f"({int.from_bytes(modexp_input.exponent, byteorder='big')})" ) return calculated_gas - except Exception as e: - print(f"Warning: Error calculating gas, using minimum: {e}") + except Exception: + # Used for `test_modexp_invalid_inputs` we expect the call to not succeed. + # Return is for completeness. return 500 if fork >= Osaka else 200 diff --git a/tests/osaka/eip7883_modexp_gas_increase/spec.py b/tests/osaka/eip7883_modexp_gas_increase/spec.py index fd69cf1536e..7bda842052c 100644 --- a/tests/osaka/eip7883_modexp_gas_increase/spec.py +++ b/tests/osaka/eip7883_modexp_gas_increase/spec.py @@ -79,14 +79,29 @@ def calculate_iteration_count(cls, exponent_length: int, exponent: bytes) -> int return max(iteration_count, 1) @classmethod - def calculate_gas_cost( - cls, base_length: int, modulus_length: int, exponent_length: int, exponent: bytes - ) -> int: - """Calculate the ModExp gas cost according to EIP-2565 specification.""" + def calculate_gas_cost(cls, modexp_input: ModExpInput) -> int: + """ + Calculate the ModExp gas cost according to EIP-2565 specification, overridden by the + constants within `Spec7883` when calculating for the EIP-7883 specification. + + This handles length mismatch cases by using declared lengths from the raw input + and only the first 32 bytes of exponent data for iteration calculation. + """ + base_length, exponent_length, modulus_length = modexp_input.get_declared_lengths() + exponent_head = modexp_input.get_exponent_head() multiplication_complexity = cls.calculate_multiplication_complexity( base_length, modulus_length ) - iteration_count = cls.calculate_iteration_count(exponent_length, exponent) + if exponent_length <= cls.EXPONENT_THRESHOLD and exponent_head == 0: + iteration_count = 0 + elif exponent_length <= cls.EXPONENT_THRESHOLD: + iteration_count = exponent_head.bit_length() - 1 if exponent_head > 0 else 0 + else: + # For large exponents: length_part + bits from first 32 bytes + length_part = cls.EXPONENT_BYTE_MULTIPLIER * (exponent_length - 32) + bits_part = exponent_head.bit_length() - 1 if exponent_head > 0 else 0 + iteration_count = length_part + bits_part + iteration_count = max(iteration_count, 1) return max(cls.MIN_GAS, (multiplication_complexity * iteration_count // cls.GAS_DIVISOR))