diff --git a/eth/precompiles/ecadd.py b/eth/precompiles/ecadd.py index 27938b0e24..a12cb8d7c0 100644 --- a/eth/precompiles/ecadd.py +++ b/eth/precompiles/ecadd.py @@ -32,7 +32,7 @@ def ecadd(computation: BaseComputation) -> BaseComputation: computation.consume_gas(constants.GAS_ECADD, reason='ECADD Precompile') try: - result = _ecadd(computation.msg.data) + result = _ecadd(computation.msg.data_as_bytes) except ValidationError: raise VMError("Invalid ECADD parameters") diff --git a/eth/precompiles/ecmul.py b/eth/precompiles/ecmul.py index 0df5238c5c..05366695c8 100644 --- a/eth/precompiles/ecmul.py +++ b/eth/precompiles/ecmul.py @@ -32,7 +32,7 @@ def ecmul(computation: BaseComputation) -> BaseComputation: computation.consume_gas(constants.GAS_ECMUL, reason='ECMUL Precompile') try: - result = _ecmull(computation.msg.data) + result = _ecmull(computation.msg.data_as_bytes) except ValidationError: raise VMError("Invalid ECMUL parameters") diff --git a/eth/precompiles/ecpairing.py b/eth/precompiles/ecpairing.py index d00b23ae52..5a5874ca02 100644 --- a/eth/precompiles/ecpairing.py +++ b/eth/precompiles/ecpairing.py @@ -19,6 +19,10 @@ VMError, ) +from eth.typing import ( + BytesOrView, +) + from eth.utils.bn128 import ( validate_point, FQP_point_to_FQ2_point, @@ -60,7 +64,7 @@ def ecpairing(computation: BaseComputation) -> BaseComputation: return computation -def _ecpairing(data: bytes) -> bool: +def _ecpairing(data: BytesOrView) -> bool: exponent = bn128.FQ12.one() processing_pipeline = ( diff --git a/eth/precompiles/ecrecover.py b/eth/precompiles/ecrecover.py index d04ec02b24..263bf54c3c 100644 --- a/eth/precompiles/ecrecover.py +++ b/eth/precompiles/ecrecover.py @@ -27,16 +27,17 @@ def ecrecover(computation: BaseComputation) -> BaseComputation: computation.consume_gas(constants.GAS_ECRECOVER, reason="ECRecover Precompile") - raw_message_hash = computation.msg.data[:32] + data = computation.msg.data_as_bytes + raw_message_hash = data[:32] message_hash = pad32r(raw_message_hash) - v_bytes = pad32r(computation.msg.data[32:64]) + v_bytes = pad32r(data[32:64]) v = big_endian_to_int(v_bytes) - r_bytes = pad32r(computation.msg.data[64:96]) + r_bytes = pad32r(data[64:96]) r = big_endian_to_int(r_bytes) - s_bytes = pad32r(computation.msg.data[96:128]) + s_bytes = pad32r(data[96:128]) s = big_endian_to_int(s_bytes) try: diff --git a/eth/precompiles/identity.py b/eth/precompiles/identity.py index a2661559c0..9ad722c0c7 100644 --- a/eth/precompiles/identity.py +++ b/eth/precompiles/identity.py @@ -14,5 +14,5 @@ def identity(computation: BaseComputation) -> BaseComputation: computation.consume_gas(gas_fee, reason="Identity Precompile") - computation.output = computation.msg.data + computation.output = computation.msg.data_as_bytes return computation diff --git a/eth/precompiles/modexp.py b/eth/precompiles/modexp.py index d6d0772653..e3b3e66485 100644 --- a/eth/precompiles/modexp.py +++ b/eth/precompiles/modexp.py @@ -125,12 +125,14 @@ def modexp(computation: BaseComputation) -> BaseComputation: """ https://github.com/ethereum/EIPs/pull/198 """ - gas_fee = _compute_modexp_gas_fee(computation.msg.data) + data = computation.msg.data_as_bytes + + gas_fee = _compute_modexp_gas_fee(data) computation.consume_gas(gas_fee, reason='MODEXP Precompile') - result = _modexp(computation.msg.data) + result = _modexp(data) - _, _, modulus_length = _extract_lengths(computation.msg.data) + _, _, modulus_length = _extract_lengths(data) # Modulo 0 is undefined, return zero # https://math.stackexchange.com/questions/516251/why-is-n-mod-0-undefined diff --git a/eth/typing.py b/eth/typing.py index af3f075432..b19358eb1f 100644 --- a/eth/typing.py +++ b/eth/typing.py @@ -57,6 +57,8 @@ GenesisDict = Dict[str, Union[int, BlockNumber, bytes, Hash32]] +BytesOrView = Union[bytes, memoryview] + Normalizer = Callable[[Dict[Any, Any]], Dict[str, Any]] RawAccountDetails = TypedDict('RawAccountDetails', diff --git a/eth/validation.py b/eth/validation.py index bdec8fe8fc..8cf75a17d5 100644 --- a/eth/validation.py +++ b/eth/validation.py @@ -40,6 +40,10 @@ UINT_256_MAX, ) +from eth.typing import ( + BytesOrView, +) + if TYPE_CHECKING: from eth.vm.base import BaseVM # noqa: F401 @@ -51,6 +55,14 @@ def validate_is_bytes(value: bytes, title: str="Value") -> None: ) +def validate_is_bytes_or_view(value: BytesOrView, title: str="Value") -> None: + if isinstance(value, (bytes, memoryview)): + return + raise ValidationError( + "{title} must be bytes or memoryview. Got {0}".format(type(value), title=title) + ) + + def validate_is_integer(value: Union[int, bool], title: str="Value") -> None: if not isinstance(value, int) or isinstance(value, bool): raise ValidationError( diff --git a/eth/vm/computation.py b/eth/vm/computation.py index 603d0a0f2f..e4298d1ded 100644 --- a/eth/vm/computation.py +++ b/eth/vm/computation.py @@ -29,6 +29,9 @@ Halt, VMError, ) +from eth.typing import ( + BytesOrView, +) from eth.tools.logging import ( ExtendedDebugLogger, ) @@ -243,12 +246,18 @@ def memory_write(self, start_position: int, size: int, value: bytes) -> None: """ return self._memory.write(start_position, size, value) - def memory_read(self, start_position: int, size: int) -> bytes: + def memory_read(self, start_position: int, size: int) -> memoryview: """ - Read and return ``size`` bytes from memory starting at ``start_position``. + Read and return a view of ``size`` bytes from memory starting at ``start_position``. """ return self._memory.read(start_position, size) + def memory_read_bytes(self, start_position: int, size: int) -> bytes: + """ + Read and return ``size`` bytes from memory starting at ``start_position``. + """ + return self._memory.read_bytes(start_position, size) + # # Gas Consumption # @@ -360,7 +369,7 @@ def prepare_child_message(self, gas: int, to: Address, value: int, - data: bytes, + data: BytesOrView, code: bytes, **kwargs: Any) -> Message: """ diff --git a/eth/vm/logic/context.py b/eth/vm/logic/context.py index 117da4a74e..8bc9c574a0 100644 --- a/eth/vm/logic/context.py +++ b/eth/vm/logic/context.py @@ -42,7 +42,7 @@ def calldataload(computation: BaseComputation) -> None: """ start_position = computation.stack_pop(type_hint=constants.UINT256) - value = computation.msg.data[start_position:start_position + 32] + value = computation.msg.data_as_bytes[start_position:start_position + 32] padded_value = value.ljust(32, b'\x00') normalized_value = padded_value.lstrip(b'\x00') @@ -68,7 +68,9 @@ def calldatacopy(computation: BaseComputation) -> None: computation.consume_gas(copy_gas_cost, reason="CALLDATACOPY fee") - value = computation.msg.data[calldata_start_position: calldata_start_position + size] + value = computation.msg.data_as_bytes[ + calldata_start_position: calldata_start_position + size + ] padded_value = value.ljust(size, b'\x00') computation.memory_write(mem_start_position, size, padded_value) diff --git a/eth/vm/logic/logging.py b/eth/vm/logic/logging.py index ce941ef5b7..1c6ce515d4 100644 --- a/eth/vm/logic/logging.py +++ b/eth/vm/logic/logging.py @@ -30,7 +30,7 @@ def log_XX(computation: BaseComputation, topic_count: int) -> None: ) computation.extend_memory(mem_start_position, size) - log_data = computation.memory_read(mem_start_position, size) + log_data = computation.memory_read_bytes(mem_start_position, size) computation.add_log_entry( account=computation.msg.storage_address, diff --git a/eth/vm/logic/memory.py b/eth/vm/logic/memory.py index 30a395ff22..e10d1df832 100644 --- a/eth/vm/logic/memory.py +++ b/eth/vm/logic/memory.py @@ -32,7 +32,7 @@ def mload(computation: BaseComputation) -> None: computation.extend_memory(start_position, 32) - value = computation.memory_read(start_position, 32) + value = computation.memory_read_bytes(start_position, 32) computation.stack_push(value) diff --git a/eth/vm/logic/sha3.py b/eth/vm/logic/sha3.py index aee319e28f..7a77907df5 100644 --- a/eth/vm/logic/sha3.py +++ b/eth/vm/logic/sha3.py @@ -12,7 +12,7 @@ def sha3(computation: BaseComputation) -> None: computation.extend_memory(start_position, size) - sha3_bytes = computation.memory_read(start_position, size) + sha3_bytes = computation.memory_read_bytes(start_position, size) word_count = ceil32(len(sha3_bytes)) // 32 gas_cost = constants.GAS_SHA3WORD * word_count diff --git a/eth/vm/logic/system.py b/eth/vm/logic/system.py index e3f12123ef..9e97e8a777 100644 --- a/eth/vm/logic/system.py +++ b/eth/vm/logic/system.py @@ -32,8 +32,7 @@ def return_op(computation: BaseComputation) -> None: computation.extend_memory(start_position, size) - output = computation.memory_read(start_position, size) - computation.output = bytes(output) + computation.output = computation.memory_read_bytes(start_position, size) raise Halt('RETURN') @@ -42,8 +41,7 @@ def revert(computation: BaseComputation) -> None: computation.extend_memory(start_position, size) - output = computation.memory_read(start_position, size) - computation.output = bytes(output) + computation.output = computation.memory_read_bytes(start_position, size) raise Revert(computation.output) @@ -163,7 +161,9 @@ def __call__(self, computation: BaseComputation) -> None: computation.stack_push(0) return - call_data = computation.memory_read(stack_data.memory_start, stack_data.memory_length) + call_data = computation.memory_read_bytes( + stack_data.memory_start, stack_data.memory_length + ) create_msg_gas = self.max_child_gas_modifier( computation.get_gas_remaining() diff --git a/eth/vm/memory.py b/eth/vm/memory.py index 5e08e7147b..29daacb9e1 100644 --- a/eth/vm/memory.py +++ b/eth/vm/memory.py @@ -32,7 +32,16 @@ def extend(self, start_position: int, size: int) -> None: return size_to_extend = new_size - len(self) - self._bytes.extend(itertools.repeat(0, size_to_extend)) + try: + self._bytes.extend(itertools.repeat(0, size_to_extend)) + except BufferError: + # we can't extend the buffer (which might involve relocating it) if a + # memoryview (which stores a pointer into the buffer) has been created by + # read() and not released. Callers of read() will never try to write to the + # buffer so we're not missing anything by making a new buffer and forgetting + # about the old one. We're keeping too much memory around but this is still a + # net savings over having read() return a new bytes() object every time. + self._bytes = self._bytes + bytearray(size_to_extend) def __len__(self) -> int: return len(self._bytes) @@ -51,8 +60,14 @@ def write(self, start_position: int, size: int, value: bytes) -> None: for idx, v in enumerate(value): self._bytes[start_position + idx] = v - def read(self, start_position: int, size: int) -> bytes: + def read(self, start_position: int, size: int) -> memoryview: """ - Read a value from memory. + Return a view into the memory + """ + return memoryview(self._bytes)[start_position:start_position + size] + + def read_bytes(self, start_position: int, size: int) -> bytes: + """ + Read a value from memory and return a fresh bytes instance """ return bytes(self._bytes[start_position:start_position + size]) diff --git a/eth/vm/message.py b/eth/vm/message.py index 4969646582..281b955f22 100644 --- a/eth/vm/message.py +++ b/eth/vm/message.py @@ -5,9 +5,13 @@ from eth.constants import ( CREATE_CONTRACT_ADDRESS, ) +from eth.typing import ( + BytesOrView, +) from eth.validation import ( validate_canonical_address, validate_is_bytes, + validate_is_bytes_or_view, validate_is_integer, validate_gte, validate_uint256, @@ -31,7 +35,7 @@ def __init__(self, to: Address, sender: Address, value: int, - data: bytes, + data: BytesOrView, code: bytes, depth: int=0, create_address: Address=None, @@ -51,7 +55,7 @@ def __init__(self, validate_uint256(value, title="Message.value") self.value = value - validate_is_bytes(data, title="Message.data") + validate_is_bytes_or_view(data, title="Message.data") self.data = data validate_is_integer(depth, title="Message.depth") @@ -100,3 +104,7 @@ def storage_address(self, value: Address) -> None: @property def is_create(self) -> bool: return self.to == CREATE_CONTRACT_ADDRESS + + @property + def data_as_bytes(self) -> bytes: + return bytes(self.data)