diff --git a/src/ethereum_spec_tools/evm_tools/daemon.py b/src/ethereum_spec_tools/evm_tools/daemon.py index 1e42aaa3a5..2acf6fd5a0 100644 --- a/src/ethereum_spec_tools/evm_tools/daemon.py +++ b/src/ethereum_spec_tools/evm_tools/daemon.py @@ -42,56 +42,68 @@ def log_request( def do_POST(self) -> None: from . import main - content_length = int(self.headers["Content-Length"]) - content_bytes = self.rfile.read(content_length) - content = json.loads(content_bytes) - - input_string = json.dumps(content["input"]) - input = StringIO(input_string) - - args = [ - "t8n", - "--input.env=stdin", - "--input.alloc=stdin", - "--input.txs=stdin", - "--output.result=stdout", - "--output.body=stdout", - "--output.alloc=stdout", - f"--state.fork={content['state']['fork']}", - f"--state.chainid={content['state']['chainid']}", - f"--state.reward={content['state']['reward']}", - ] - - trace = content.get("trace", False) - output_basedir = content.get("output-basedir") - if trace: - if not output_basedir: - raise ValueError( - "`output-basedir` should be provided when `--trace` " - "is enabled." + try: + content_length = int(self.headers["Content-Length"]) + content_bytes = self.rfile.read(content_length) + content = json.loads(content_bytes) + + input_string = json.dumps(content["input"]) + input = StringIO(input_string) + + args = [ + "t8n", + "--input.env=stdin", + "--input.alloc=stdin", + "--input.txs=stdin", + "--output.result=stdout", + "--output.body=stdout", + "--output.alloc=stdout", + f"--state.fork={content['state']['fork']}", + f"--state.chainid={content['state']['chainid']}", + f"--state.reward={content['state']['reward']}", + ] + + trace = content.get("trace", False) + output_basedir = content.get("output-basedir") + if trace: + if not output_basedir: + raise ValueError( + "`output-basedir` should be provided when `--trace` " + "is enabled." + ) + # send full trace output if ``trace`` is ``True`` + args.extend( + [ + "--trace", + "--trace.memory", + "--trace.returndata", + f"--output.basedir={output_basedir}", + ] ) - # send full trace output if ``trace`` is ``True`` - args.extend( - [ - "--trace", - "--trace.memory", - "--trace.returndata", - f"--output.basedir={output_basedir}", - ] - ) - query_string = urlparse(self.path).query - if query_string: - query = parse_qs( - query_string, - keep_blank_values=True, - strict_parsing=True, - errors="strict", - ) - args += query.get("arg", []) + count_opcodes = content.get("count-opcodes", False) + if count_opcodes: + # send opcode counts if ``count-opcodes`` is ``True`` + args.extend(["--opcode.count", "stdout"]) + + query_string = urlparse(self.path).query + if query_string: + query = parse_qs( + query_string, + keep_blank_values=True, + strict_parsing=True, + errors="strict", + ) + args += query.get("arg", []) + except Exception as e: + self.send_response(500) + self.send_header("Content-Type", "text/plain") + self.end_headers() + self.wfile.write(str(e).encode("utf-8")) + raise self.send_response(200) - self.send_header("Content-type", "application/octet-stream") + self.send_header("Content-Type", "application/octet-stream") self.end_headers() # `self.wfile` is missing the `name` attribute so it doesn't strictly diff --git a/src/ethereum_spec_tools/evm_tools/t8n/__init__.py b/src/ethereum_spec_tools/evm_tools/t8n/__init__.py index f00e049357..2fad6816df 100644 --- a/src/ethereum_spec_tools/evm_tools/t8n/__init__.py +++ b/src/ethereum_spec_tools/evm_tools/t8n/__init__.py @@ -6,7 +6,7 @@ import fnmatch import json import os -from typing import Any, TextIO +from typing import Any, Final, TextIO, Type, TypeVar from ethereum_rlp import rlp from ethereum_types.numeric import U64, U256, Uint @@ -24,9 +24,13 @@ parse_hex_or_int, ) from .env import Env +from .evm_trace.count import CountTracer from .evm_trace.eip3155 import Eip3155Tracer +from .evm_trace.group import GroupTracer from .t8n_types import Alloc, Result, Txs +T = TypeVar("T") + def t8n_arguments(subparsers: argparse._SubParsersAction) -> None: """ @@ -72,12 +76,16 @@ def t8n_arguments(subparsers: argparse._SubParsersAction) -> None: t8n_parser.add_argument("--trace.nostack", action="store_true") t8n_parser.add_argument("--trace.returndata", action="store_true") + t8n_parser.add_argument("--opcode.count", dest="opcode_count", type=str) + t8n_parser.add_argument("--state-test", action="store_true") class T8N(Load): """The class that carries out the transition""" + tracers: Final[GroupTracer | None] + def __init__( self, options: Any, out_file: TextIO, in_file: TextIO ) -> None: @@ -100,11 +108,13 @@ def __init__( ) self.fork = ForkLoad(fork_module) + tracers = GroupTracer() + if self.options.trace: trace_memory = getattr(self.options, "trace.memory", False) trace_stack = not getattr(self.options, "trace.nostack", False) trace_return_data = getattr(self.options, "trace.returndata") - trace.set_evm_trace( + tracers.add( Eip3155Tracer( trace_memory=trace_memory, trace_stack=trace_stack, @@ -112,6 +122,19 @@ def __init__( output_basedir=self.options.output_basedir, ) ) + + if self.options.opcode_count is not None: + tracers.add(CountTracer()) + + maybe_tracers: GroupTracer | None + if tracers.tracers: + trace.set_evm_trace(tracers) + maybe_tracers = tracers + else: + maybe_tracers = None + + self.tracers = maybe_tracers + self.logger = get_stream_logger("T8N") super().__init__( @@ -127,6 +150,15 @@ def __init__( self.env.block_difficulty, self.env.base_fee_per_gas ) + def _tracer(self, type_: Type[T]) -> T: + group = self.tracers + if group is None: + raise Exception("no tracer configured") + found = next((x for x in group.tracers if isinstance(x, type_)), None) + if found is None: + raise Exception(f"no tracer of type `{type_}` found") + return found + def block_environment(self) -> Any: """ Create the environment for the transaction. The keyword @@ -310,7 +342,7 @@ def run(self) -> int: json_state = self.alloc.to_json() json_result = self.result.to_json() - json_output = {} + json_output: dict[str, object] = {} if self.options.output_body == "stdout": txs_rlp = "0x" + rlp.encode(self.txs.all_txs).hex() @@ -347,6 +379,19 @@ def run(self) -> int: json.dump(json_result, f, indent=4) self.logger.info(f"Wrote result to {result_output_path}") + if self.options.opcode_count == "stdout": + opcode_count_results = self._tracer(CountTracer).results() + json_output["opcodeCount"] = opcode_count_results + elif self.options.opcode_count is not None: + opcode_count_results = self._tracer(CountTracer).results() + result_output_path = os.path.join( + self.options.output_basedir, + self.options.opcode_count, + ) + with open(result_output_path, "w") as f: + json.dump(opcode_count_results, f, indent=4) + self.logger.info(f"Wrote opcode counts to {result_output_path}") + if json_output: json.dump(json_output, self.out_file, indent=4) diff --git a/src/ethereum_spec_tools/evm_tools/t8n/evm_trace/count.py b/src/ethereum_spec_tools/evm_tools/t8n/evm_trace/count.py new file mode 100644 index 0000000000..2d953b99fb --- /dev/null +++ b/src/ethereum_spec_tools/evm_tools/t8n/evm_trace/count.py @@ -0,0 +1,46 @@ +""" +EVM trace implementation that counts how many times each opcode is executed. +""" +from collections import defaultdict + +from ethereum.trace import EvmTracer, OpStart, TraceEvent + +from .protocols import Evm + + +class CountTracer(EvmTracer): + """ + EVM trace implementation that counts how many times each opcode is + executed. + """ + + transaction_environment: object | None + active_traces: defaultdict[str, int] + + def __init__(self) -> None: + self.transaction_environment = None + self.active_traces = defaultdict(lambda: 0) + + def __call__(self, evm: object, event: TraceEvent) -> None: + """ + Create a trace of the event. + """ + if not isinstance(event, OpStart): + return + + assert isinstance(evm, Evm) + + if self.transaction_environment is not evm.message.tx_env: + self.active_traces = defaultdict(lambda: 0) + self.transaction_environment = evm.message.tx_env + + self.active_traces[event.op.name] += 1 + + def results(self) -> dict[str, int]: + """ + Return and clear the current opcode counts. + """ + results = self.active_traces + self.active_traces = defaultdict(lambda: 0) + self.transaction_environment = None + return results diff --git a/src/ethereum_spec_tools/evm_tools/t8n/evm_trace/group.py b/src/ethereum_spec_tools/evm_tools/t8n/evm_trace/group.py new file mode 100644 index 0000000000..e4f0e34242 --- /dev/null +++ b/src/ethereum_spec_tools/evm_tools/t8n/evm_trace/group.py @@ -0,0 +1,38 @@ +""" +EVM trace implementation that fans out to many concrete trace implementations. +""" +from typing import Final + +from typing_extensions import override + +from ethereum.trace import EvmTracer, TraceEvent + + +class GroupTracer(EvmTracer): + """ + EVM trace implementation that fans out to many concrete trace + implementations. + """ + + tracers: Final[set[EvmTracer]] + + def __init__(self) -> None: + self.tracers = set() + + def add(self, tracer: EvmTracer) -> None: + """ + Insert a new tracer. + """ + self.tracers.add(tracer) + + @override + def __call__( + self, + evm: object, + event: TraceEvent, + ) -> None: + """ + Record a trace event. + """ + for tracer in self.tracers: + tracer(evm, event) diff --git a/tests/evm_tools/test_count_opcodes.py b/tests/evm_tools/test_count_opcodes.py new file mode 100644 index 0000000000..f89d28a0a7 --- /dev/null +++ b/tests/evm_tools/test_count_opcodes.py @@ -0,0 +1,49 @@ +import json +from io import StringIO +from pathlib import Path +from typing import Callable + +import pytest + +from ethereum_spec_tools.evm_tools import create_parser +from ethereum_spec_tools.evm_tools.t8n import T8N + +parser = create_parser() + + +@pytest.mark.evm_tools +def test_count_opcodes(root_relative: Callable[[str | Path], Path]) -> None: + base_path = root_relative( + "fixtures/evm_tools_testdata/t8n/fixtures/testdata/2" + ) + + options = parser.parse_args( + [ + "t8n", + f"--input.env={base_path / 'env.json'}", + f"--input.alloc={base_path / 'alloc.json'}", + f"--input.txs={base_path / 'txs.json'}", + "--output.result=stdout", + "--output.body=stdout", + "--output.alloc=stdout", + "--opcode.count=stdout", + "--state-test", + ] + ) + + in_file = StringIO() + out_file = StringIO() + + t8n_tool = T8N(options, out_file=out_file, in_file=in_file) + exit_code = t8n_tool.run() + assert 0 == exit_code + + results = json.loads(out_file.getvalue()) + + assert results["opcodeCount"] == { + "PUSH1": 5, + "MSTORE8": 1, + "CREATE": 1, + "ADD": 1, + "SELFDESTRUCT": 1, + } diff --git a/tests/json_infra/conftest.py b/tests/json_infra/conftest.py index b770ed67c1..a070499365 100644 --- a/tests/json_infra/conftest.py +++ b/tests/json_infra/conftest.py @@ -2,7 +2,7 @@ import shutil import tarfile from pathlib import Path -from typing import Final, Optional, Set +from typing import Callable, Final, Optional, Set import git import requests_cache @@ -11,7 +11,7 @@ from _pytest.nodes import Item from filelock import FileLock from git.exc import GitCommandError, InvalidGitRepositoryError -from pytest import Session, StashKey +from pytest import Session, StashKey, fixture from requests_cache import CachedSession from requests_cache.backends.sqlite import SQLiteCache from typing_extensions import Self @@ -27,6 +27,19 @@ def get_xdist_worker_id(request_or_session: object) -> str: # noqa: U100 return "master" +@fixture() +def root_relative() -> Callable[[str | Path], Path]: + """ + A fixture that provides a function to resolve a path relative to + `conftest.py`. + """ + + def _(path: str | Path) -> Path: + return Path(__file__).parent / path + + return _ + + def pytest_addoption(parser: Parser) -> None: """ Accept --evm-trace option in pytest.