Skip to content

Commit 69bb225

Browse files
committed
WiP: Add flag to count opcodes during execution
1 parent 87fb278 commit 69bb225

File tree

3 files changed

+121
-3
lines changed

3 files changed

+121
-3
lines changed

src/ethereum_spec_tools/evm_tools/t8n/__init__.py

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import fnmatch
77
import json
88
import os
9-
from typing import Any, TextIO
9+
from typing import Any, Final, TextIO
1010

1111
from ethereum_rlp import rlp
1212
from ethereum_types.numeric import U64, U256, Uint
@@ -24,7 +24,10 @@
2424
parse_hex_or_int,
2525
)
2626
from .env import Env
27+
from .evm_trace.count import evm_trace as evm_trace_count
28+
from .evm_trace.count import results as count_results
2729
from .evm_trace.eip3155 import Eip3155Tracer
30+
from .evm_trace.group import GroupTracer
2831
from .t8n_types import Alloc, Result, Txs
2932

3033

@@ -72,12 +75,16 @@ def t8n_arguments(subparsers: argparse._SubParsersAction) -> None:
7275
t8n_parser.add_argument("--trace.nostack", action="store_true")
7376
t8n_parser.add_argument("--trace.returndata", action="store_true")
7477

78+
t8n_parser.add_argument("--opcode.count", dest="opcode_count", type=str)
79+
7580
t8n_parser.add_argument("--state-test", action="store_true")
7681

7782

7883
class T8N(Load):
7984
"""The class that carries out the transition"""
8085

86+
tracers: Final[GroupTracer | None]
87+
8188
def __init__(
8289
self, options: Any, out_file: TextIO, in_file: TextIO
8390
) -> None:
@@ -100,18 +107,33 @@ def __init__(
100107
)
101108
self.fork = ForkLoad(fork_module)
102109

110+
tracers = GroupTracer()
111+
103112
if self.options.trace:
104113
trace_memory = getattr(self.options, "trace.memory", False)
105114
trace_stack = not getattr(self.options, "trace.nostack", False)
106115
trace_return_data = getattr(self.options, "trace.returndata")
107-
trace.set_evm_trace(
116+
tracers.add(
108117
Eip3155Tracer(
109118
trace_memory=trace_memory,
110119
trace_stack=trace_stack,
111120
trace_return_data=trace_return_data,
112121
output_basedir=self.options.output_basedir,
113122
)
114123
)
124+
125+
if self.options.opcode_count is not None:
126+
tracers.add(evm_trace_count)
127+
128+
maybe_tracers: GroupTracer | None
129+
if tracers.tracers:
130+
trace.set_evm_trace(tracers)
131+
maybe_tracers = tracers
132+
else:
133+
maybe_tracers = None
134+
135+
self.tracers = maybe_tracers
136+
115137
self.logger = get_stream_logger("T8N")
116138

117139
super().__init__(
@@ -310,7 +332,7 @@ def run(self) -> int:
310332
json_state = self.alloc.to_json()
311333
json_result = self.result.to_json()
312334

313-
json_output = {}
335+
json_output: dict[str, object] = {}
314336

315337
if self.options.output_body == "stdout":
316338
txs_rlp = "0x" + rlp.encode(self.txs.all_txs).hex()
@@ -347,6 +369,18 @@ def run(self) -> int:
347369
json.dump(json_result, f, indent=4)
348370
self.logger.info(f"Wrote result to {result_output_path}")
349371

372+
opcode_count_results = count_results()
373+
if self.options.opcode_count == "stdout":
374+
json_output["opcodes"] = opcode_count_results
375+
elif self.options.opcode_count is not None:
376+
result_output_path = os.path.join(
377+
self.options.output_basedir,
378+
self.options.opcode_count,
379+
)
380+
with open(result_output_path, "w") as f:
381+
json.dump(opcode_count_results, f, indent=4)
382+
self.logger.info(f"Wrote opcode counts to {result_output_path}")
383+
350384
if json_output:
351385
json.dump(json_output, self.out_file, indent=4)
352386

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
"""
2+
EVM trace implementation that counts how many times each opcode is executed.
3+
"""
4+
from collections import defaultdict
5+
from typing import Any, TypeAlias
6+
7+
from ethereum.trace import OpStart, TraceEvent
8+
9+
from .protocols import Evm
10+
11+
_ActiveTraces: TypeAlias = tuple[object, dict[str, int]]
12+
_active_traces: _ActiveTraces | None = None
13+
14+
15+
def evm_trace(
16+
evm: Any,
17+
event: TraceEvent,
18+
) -> None:
19+
"""
20+
Create a trace of the event.
21+
"""
22+
global _active_traces
23+
24+
if not isinstance(event, OpStart):
25+
return
26+
27+
assert isinstance(evm, Evm)
28+
29+
if _active_traces and _active_traces[0] is evm.message.tx_env:
30+
traces = _active_traces[1]
31+
else:
32+
traces = defaultdict(lambda: 0)
33+
_active_traces = (evm.message.tx_env, traces)
34+
35+
traces[event.op.name] += 1
36+
37+
38+
def results() -> dict[str, int]:
39+
"""
40+
Take and clear the current opcode counts.
41+
"""
42+
global _active_traces
43+
44+
results = _active_traces
45+
_active_traces = None
46+
return {} if results is None else results[1]
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
"""
2+
EVM trace implementation that fans out to many concrete trace implementations.
3+
"""
4+
from typing import Final
5+
6+
from typing_extensions import override
7+
8+
from ethereum.trace import EvmTracer, TraceEvent
9+
10+
11+
class GroupTracer(EvmTracer):
12+
"""
13+
EVM trace implementation that fans out to many concrete trace
14+
implementations.
15+
"""
16+
17+
tracers: Final[set[EvmTracer]]
18+
19+
def __init__(self) -> None:
20+
self.tracers = set()
21+
22+
def add(self, tracer: EvmTracer) -> None:
23+
"""
24+
Insert a new tracer.
25+
"""
26+
self.tracers.add(tracer)
27+
28+
@override
29+
def __call__(
30+
self,
31+
evm: object,
32+
event: TraceEvent,
33+
) -> None:
34+
"""
35+
Record a trace event.
36+
"""
37+
for tracer in self.tracers:
38+
tracer(evm, event)

0 commit comments

Comments
 (0)