diff --git a/tests/byzantium/eip198_modexp_precompile/helpers.py b/tests/byzantium/eip198_modexp_precompile/helpers.py index 78fdc56fbad..9cc44fd25a1 100644 --- a/tests/byzantium/eip198_modexp_precompile/helpers.py +++ b/tests/byzantium/eip198_modexp_precompile/helpers.py @@ -1,5 +1,7 @@ """Helper functions for the EIP-198 ModExp precompile tests.""" +from typing import Tuple + from pydantic import Field from ethereum_test_tools import Bytes, TestParameterGroup @@ -24,21 +26,37 @@ class ModExpInput(TestParameterGroup): modulus: Bytes extra_data: Bytes = Field(default_factory=Bytes) raw_input: Bytes | None = None + declared_base_length: int | None = None + declared_exponent_length: int | None = None + declared_modulus_length: int | None = None @property def length_base(self) -> Bytes: """Return the length of the base.""" - return Bytes(len(self.base).to_bytes(32, "big")) + length = ( + self.declared_base_length if self.declared_base_length is not None else len(self.base) + ) + return Bytes(length.to_bytes(32, "big")) @property def length_exponent(self) -> Bytes: """Return the length of the exponent.""" - return Bytes(len(self.exponent).to_bytes(32, "big")) + length = ( + self.declared_exponent_length + if self.declared_exponent_length is not None + else len(self.exponent) + ) + return Bytes(length.to_bytes(32, "big")) @property def length_modulus(self) -> Bytes: """Return the length of the modulus.""" - return Bytes(len(self.modulus).to_bytes(32, "big")) + length = ( + self.declared_modulus_length + if self.declared_modulus_length is not None + else len(self.modulus) + ) + return Bytes(length.to_bytes(32, "big")) def __bytes__(self): """Generate input for the MODEXP precompile.""" @@ -86,6 +104,30 @@ def from_bytes(cls, input_data: Bytes | str) -> "ModExpInput": return cls(base=base, exponent=exponent, modulus=modulus, raw_input=input_data) + def get_declared_lengths(self) -> Tuple[int, int, int]: + """Extract declared lengths from the raw input bytes.""" + raw = self.raw_input if self.raw_input is not None else bytes(self) + if len(raw) < 96: + raw = raw.ljust(96, b"\0") + base_length = int.from_bytes(raw[0:32], byteorder="big") + exponent_length = int.from_bytes(raw[32:64], byteorder="big") + modulus_length = int.from_bytes(raw[64:96], byteorder="big") + return base_length, exponent_length, modulus_length + + def get_exponent_head(self) -> int: + """Get the first 32 bytes of the exponent as an integer.""" + raw = self.raw_input if self.raw_input is not None else bytes(self) + base_length, exponent_length, _ = self.get_declared_lengths() + exp_start = 96 + base_length + + # Extract up to 32 bytes of exponent data + exp_head_bytes = raw[exp_start : exp_start + min(32, exponent_length)] + + # Pad with zeros if less than 32 bytes + exp_head_bytes = exp_head_bytes.rjust(32, b"\0") + + return int.from_bytes(exp_head_bytes[:32], byteorder="big") + class ModExpOutput(TestParameterGroup): """ diff --git a/tests/osaka/eip7823_modexp_upper_bounds/test_modexp_upper_bounds.py b/tests/osaka/eip7823_modexp_upper_bounds/test_modexp_upper_bounds.py index 269e4792964..5ee9ac4b286 100644 --- a/tests/osaka/eip7823_modexp_upper_bounds/test_modexp_upper_bounds.py +++ b/tests/osaka/eip7823_modexp_upper_bounds/test_modexp_upper_bounds.py @@ -6,7 +6,7 @@ import pytest from ethereum_test_forks import Fork, Osaka -from ethereum_test_tools import Account, Alloc, Environment, StateTestFiller, Transaction +from ethereum_test_tools import Account, Alloc, Bytes, Environment, StateTestFiller, Transaction from ethereum_test_tools.vm.opcode import Opcodes as Op from ...byzantium.eip198_modexp_precompile.helpers import ModExpInput, ModExpOutput @@ -23,12 +23,7 @@ def precompile_gas(fork: Fork, mod_exp_input: ModExpInput) -> int: """Calculate gas cost for the ModExp precompile and verify it matches expected gas.""" spec = Spec if fork < Osaka else Spec7883 - calculated_gas = spec.calculate_gas_cost( - len(mod_exp_input.base), - len(mod_exp_input.modulus), - len(mod_exp_input.exponent), - mod_exp_input.exponent, - ) + calculated_gas = spec.calculate_gas_cost(mod_exp_input) return calculated_gas @@ -125,6 +120,15 @@ def precompile_gas(fork: Fork, mod_exp_input: ModExpInput) -> int: ), id="exp_0_base_1_mod_1025", ), + pytest.param( + ModExpInput( + base=b"", + exponent=Bytes("80"), + modulus=b"", + declared_exponent_length=2**64, + ), + id="exp_2_pow_64_base_0_mod_0", + ), ], ) def test_modexp_upper_bounds( @@ -174,10 +178,11 @@ def test_modexp_upper_bounds( protected=True, sender=sender, ) + base_length, exp_length, mod_length = mod_exp_input.get_declared_lengths() if ( - len(mod_exp_input.base) <= MAX_LENGTH_BYTES - and len(mod_exp_input.exponent) <= MAX_LENGTH_BYTES - and len(mod_exp_input.modulus) <= MAX_LENGTH_BYTES + base_length <= MAX_LENGTH_BYTES + and exp_length <= MAX_LENGTH_BYTES + and mod_length <= MAX_LENGTH_BYTES ) or (fork < Osaka and not expensive): output = ModExpOutput(call_success=True, returned_data="0x01") else: diff --git a/tests/osaka/eip7883_modexp_gas_increase/conftest.py b/tests/osaka/eip7883_modexp_gas_increase/conftest.py index df395b90213..366ba81141f 100644 --- a/tests/osaka/eip7883_modexp_gas_increase/conftest.py +++ b/tests/osaka/eip7883_modexp_gas_increase/conftest.py @@ -133,12 +133,7 @@ def precompile_gas( """Calculate gas cost for the ModExp precompile and verify it matches expected gas.""" spec = Spec if fork < Osaka else Spec7883 try: - calculated_gas = spec.calculate_gas_cost( - len(modexp_input.base), - len(modexp_input.modulus), - len(modexp_input.exponent), - modexp_input.exponent, - ) + calculated_gas = spec.calculate_gas_cost(modexp_input) if gas_old is not None and gas_new is not None: expected_gas = gas_old if fork < Osaka else gas_new assert calculated_gas == expected_gas, ( @@ -150,8 +145,9 @@ def precompile_gas( f"({int.from_bytes(modexp_input.exponent, byteorder='big')})" ) return calculated_gas - except Exception as e: - print(f"Warning: Error calculating gas, using minimum: {e}") + except Exception: + # Used for `test_modexp_invalid_inputs` we expect the call to not succeed. + # Return is for completeness. return 500 if fork >= Osaka else 200 diff --git a/tests/osaka/eip7883_modexp_gas_increase/spec.py b/tests/osaka/eip7883_modexp_gas_increase/spec.py index fd69cf1536e..74908e60b9e 100644 --- a/tests/osaka/eip7883_modexp_gas_increase/spec.py +++ b/tests/osaka/eip7883_modexp_gas_increase/spec.py @@ -61,32 +61,36 @@ def calculate_multiplication_complexity(cls, base_length: int, modulus_length: i return cls.LARGE_BASE_MODULUS_MULTIPLIER * words**2 @classmethod - def calculate_iteration_count(cls, exponent_length: int, exponent: bytes) -> int: - """Calculate the iteration count of the ModExp precompile.""" - iteration_count = 0 - exponent_value = int.from_bytes(exponent, byteorder="big") - if exponent_length <= cls.EXPONENT_THRESHOLD and exponent_value == 0: + def calculate_iteration_count(cls, modexp_input: ModExpInput) -> int: + """ + Calculate the iteration count of the ModExp precompile. + This handles length mismatch cases by using declared lengths from the raw input + and only the first 32 bytes of exponent data for iteration calculation. + """ + _, exponent_length, _ = modexp_input.get_declared_lengths() + exponent_head = modexp_input.get_exponent_head() + if exponent_length <= cls.EXPONENT_THRESHOLD and exponent_head == 0: iteration_count = 0 elif exponent_length <= cls.EXPONENT_THRESHOLD: - iteration_count = exponent_value.bit_length() - 1 - elif exponent_length > cls.EXPONENT_THRESHOLD: - exponent_head = int.from_bytes(exponent[0:32], byteorder="big") + iteration_count = exponent_head.bit_length() - 1 if exponent_head > 0 else 0 + else: + # For large exponents: length_part + bits from first 32 bytes length_part = cls.EXPONENT_BYTE_MULTIPLIER * (exponent_length - 32) - bits_part = exponent_head.bit_length() - if bits_part > 0: - bits_part -= 1 + bits_part = exponent_head.bit_length() - 1 if exponent_head > 0 else 0 iteration_count = length_part + bits_part return max(iteration_count, 1) @classmethod - def calculate_gas_cost( - cls, base_length: int, modulus_length: int, exponent_length: int, exponent: bytes - ) -> int: - """Calculate the ModExp gas cost according to EIP-2565 specification.""" + def calculate_gas_cost(cls, modexp_input: ModExpInput) -> int: + """ + Calculate the ModExp gas cost according to EIP-2565 specification, overridden by the + constants within `Spec7883` when calculating for the EIP-7883 specification. + """ + base_length, _, modulus_length = modexp_input.get_declared_lengths() multiplication_complexity = cls.calculate_multiplication_complexity( base_length, modulus_length ) - iteration_count = cls.calculate_iteration_count(exponent_length, exponent) + iteration_count = cls.calculate_iteration_count(modexp_input) return max(cls.MIN_GAS, (multiplication_complexity * iteration_count // cls.GAS_DIVISOR))