Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 45 additions & 3 deletions tests/byzantium/eip198_modexp_precompile/helpers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Helper functions for the EIP-198 ModExp precompile tests."""

from typing import Tuple

from pydantic import Field

from ethereum_test_tools import Bytes, TestParameterGroup
Expand All @@ -24,21 +26,37 @@ class ModExpInput(TestParameterGroup):
modulus: Bytes
extra_data: Bytes = Field(default_factory=Bytes)
raw_input: Bytes | None = None
declared_base_length: int | None = None
declared_exponent_length: int | None = None
declared_modulus_length: int | None = None

@property
def length_base(self) -> Bytes:
"""Return the length of the base."""
return Bytes(len(self.base).to_bytes(32, "big"))
length = (
self.declared_base_length if self.declared_base_length is not None else len(self.base)
)
return Bytes(length.to_bytes(32, "big"))

@property
def length_exponent(self) -> Bytes:
"""Return the length of the exponent."""
return Bytes(len(self.exponent).to_bytes(32, "big"))
length = (
self.declared_exponent_length
if self.declared_exponent_length is not None
else len(self.exponent)
)
return Bytes(length.to_bytes(32, "big"))

@property
def length_modulus(self) -> Bytes:
"""Return the length of the modulus."""
return Bytes(len(self.modulus).to_bytes(32, "big"))
length = (
self.declared_modulus_length
if self.declared_modulus_length is not None
else len(self.modulus)
)
return Bytes(length.to_bytes(32, "big"))

def __bytes__(self):
"""Generate input for the MODEXP precompile."""
Expand Down Expand Up @@ -86,6 +104,30 @@ def from_bytes(cls, input_data: Bytes | str) -> "ModExpInput":

return cls(base=base, exponent=exponent, modulus=modulus, raw_input=input_data)

def get_declared_lengths(self) -> Tuple[int, int, int]:
"""Extract declared lengths from the raw input bytes."""
raw = self.raw_input if self.raw_input is not None else bytes(self)
if len(raw) < 96:
raw = raw.ljust(96, b"\0")
base_length = int.from_bytes(raw[0:32], byteorder="big")
exponent_length = int.from_bytes(raw[32:64], byteorder="big")
modulus_length = int.from_bytes(raw[64:96], byteorder="big")
return base_length, exponent_length, modulus_length

def get_exponent_head(self) -> int:
"""Get the first 32 bytes of the exponent as an integer."""
raw = self.raw_input if self.raw_input is not None else bytes(self)
base_length, exponent_length, _ = self.get_declared_lengths()
exp_start = 96 + base_length

# Extract up to 32 bytes of exponent data
exp_head_bytes = raw[exp_start : exp_start + min(32, exponent_length)]

# Pad with zeros if less than 32 bytes
exp_head_bytes = exp_head_bytes.rjust(32, b"\0")

return int.from_bytes(exp_head_bytes[:32], byteorder="big")


class ModExpOutput(TestParameterGroup):
"""
Expand Down
25 changes: 15 additions & 10 deletions tests/osaka/eip7823_modexp_upper_bounds/test_modexp_upper_bounds.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import pytest

from ethereum_test_forks import Fork, Osaka
from ethereum_test_tools import Account, Alloc, Environment, StateTestFiller, Transaction
from ethereum_test_tools import Account, Alloc, Bytes, Environment, StateTestFiller, Transaction
from ethereum_test_tools.vm.opcode import Opcodes as Op

from ...byzantium.eip198_modexp_precompile.helpers import ModExpInput, ModExpOutput
Expand All @@ -23,12 +23,7 @@
def precompile_gas(fork: Fork, mod_exp_input: ModExpInput) -> int:
"""Calculate gas cost for the ModExp precompile and verify it matches expected gas."""
spec = Spec if fork < Osaka else Spec7883
calculated_gas = spec.calculate_gas_cost(
len(mod_exp_input.base),
len(mod_exp_input.modulus),
len(mod_exp_input.exponent),
mod_exp_input.exponent,
)
calculated_gas = spec.calculate_gas_cost(mod_exp_input)
return calculated_gas


Expand Down Expand Up @@ -125,6 +120,15 @@ def precompile_gas(fork: Fork, mod_exp_input: ModExpInput) -> int:
),
id="exp_0_base_1_mod_1025",
),
pytest.param(
ModExpInput(
base=b"",
exponent=Bytes("80"),
modulus=b"",
declared_exponent_length=2**64,
),
id="exp_2_pow_64_base_0_mod_0",
),
],
)
def test_modexp_upper_bounds(
Expand Down Expand Up @@ -174,10 +178,11 @@ def test_modexp_upper_bounds(
protected=True,
sender=sender,
)
base_length, exp_length, mod_length = mod_exp_input.get_declared_lengths()
if (
len(mod_exp_input.base) <= MAX_LENGTH_BYTES
and len(mod_exp_input.exponent) <= MAX_LENGTH_BYTES
and len(mod_exp_input.modulus) <= MAX_LENGTH_BYTES
base_length <= MAX_LENGTH_BYTES
and exp_length <= MAX_LENGTH_BYTES
and mod_length <= MAX_LENGTH_BYTES
) or (fork < Osaka and not expensive):
output = ModExpOutput(call_success=True, returned_data="0x01")
else:
Expand Down
12 changes: 4 additions & 8 deletions tests/osaka/eip7883_modexp_gas_increase/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,12 +133,7 @@ def precompile_gas(
"""Calculate gas cost for the ModExp precompile and verify it matches expected gas."""
spec = Spec if fork < Osaka else Spec7883
try:
calculated_gas = spec.calculate_gas_cost(
len(modexp_input.base),
len(modexp_input.modulus),
len(modexp_input.exponent),
modexp_input.exponent,
)
calculated_gas = spec.calculate_gas_cost(modexp_input)
if gas_old is not None and gas_new is not None:
expected_gas = gas_old if fork < Osaka else gas_new
assert calculated_gas == expected_gas, (
Expand All @@ -150,8 +145,9 @@ def precompile_gas(
f"({int.from_bytes(modexp_input.exponent, byteorder='big')})"
)
return calculated_gas
except Exception as e:
print(f"Warning: Error calculating gas, using minimum: {e}")
except Exception:
# Used for `test_modexp_invalid_inputs` we expect the call to not succeed.
# Return is for completeness.
return 500 if fork >= Osaka else 200


Expand Down
25 changes: 20 additions & 5 deletions tests/osaka/eip7883_modexp_gas_increase/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,14 +79,29 @@ def calculate_iteration_count(cls, exponent_length: int, exponent: bytes) -> int
return max(iteration_count, 1)

@classmethod
def calculate_gas_cost(
cls, base_length: int, modulus_length: int, exponent_length: int, exponent: bytes
) -> int:
"""Calculate the ModExp gas cost according to EIP-2565 specification."""
def calculate_gas_cost(cls, modexp_input: ModExpInput) -> int:
"""
Calculate the ModExp gas cost according to EIP-2565 specification, overridden by the
constants within `Spec7883` when calculating for the EIP-7883 specification.

This handles length mismatch cases by using declared lengths from the raw input
and only the first 32 bytes of exponent data for iteration calculation.
"""
base_length, exponent_length, modulus_length = modexp_input.get_declared_lengths()
exponent_head = modexp_input.get_exponent_head()
multiplication_complexity = cls.calculate_multiplication_complexity(
base_length, modulus_length
)
iteration_count = cls.calculate_iteration_count(exponent_length, exponent)
if exponent_length <= cls.EXPONENT_THRESHOLD and exponent_head == 0:
iteration_count = 0
elif exponent_length <= cls.EXPONENT_THRESHOLD:
iteration_count = exponent_head.bit_length() - 1 if exponent_head > 0 else 0
else:
# For large exponents: length_part + bits from first 32 bytes
length_part = cls.EXPONENT_BYTE_MULTIPLIER * (exponent_length - 32)
bits_part = exponent_head.bit_length() - 1 if exponent_head > 0 else 0
iteration_count = length_part + bits_part
iteration_count = max(iteration_count, 1)
return max(cls.MIN_GAS, (multiplication_complexity * iteration_count // cls.GAS_DIVISOR))


Expand Down
Loading