Skip to content

Commit 14b6b64

Browse files
committed
enhance(test-benchmark): use config file for fixed opcode count scenarios
1 parent 8f2639b commit 14b6b64

File tree

5 files changed

+502
-39
lines changed

5 files changed

+502
-39
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@
99
# AI
1010
.claude/
1111

12+
# Benchmark fixed opcode counts
13+
.fixed_opcode_counts.json
14+
1215
# C extensions
1316
*.so
1417

packages/testing/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ extract_config = "execution_testing.cli.extract_config:extract_config"
101101
compare_fixtures = "execution_testing.cli.compare_fixtures:main"
102102
modify_static_test_gas_limits = "execution_testing.cli.modify_static_test_gas_limits:main"
103103
validate_changelog = "execution_testing.cli.tox_helpers:validate_changelog"
104+
benchmark_parser = "execution_testing.cli.benchmark_parser:main"
104105

105106
[tool.setuptools.packages.find]
106107
where = ["src"]
Lines changed: 366 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,366 @@
1+
"""
2+
Parser to analyze benchmark tests and maintain the opcode counts mapping.
3+
4+
This script uses Python's AST to analyze benchmark tests and generate/update
5+
the scenario configs in `.fixed_opcode_counts.json`.
6+
7+
Usage:
8+
uv run benchmark_parser # Update `.fixed_opcode_counts.json`
9+
uv run benchmark_parser --check # Check for new/missing entries (CI)
10+
"""
11+
12+
import argparse
13+
import ast
14+
import json
15+
import sys
16+
from pathlib import Path
17+
18+
19+
def get_repo_root() -> Path:
20+
"""Get the repository root directory."""
21+
current = Path.cwd()
22+
while current != current.parent:
23+
if (current / "tests" / "benchmark").exists():
24+
return current
25+
current = current.parent
26+
raise FileNotFoundError("Could not find repository root")
27+
28+
29+
def get_benchmark_dir() -> Path:
30+
"""Get the benchmark tests directory."""
31+
return get_repo_root() / "tests" / "benchmark"
32+
33+
34+
def get_config_file() -> Path:
35+
"""Get the .fixed_opcode_counts.json config file path."""
36+
return get_repo_root() / ".fixed_opcode_counts.json"
37+
38+
39+
class OpcodeExtractor(ast.NodeVisitor):
40+
"""Extract opcode parametrizations from benchmark test functions."""
41+
42+
def __init__(self, source_code: str):
43+
self.source_code = source_code
44+
self.patterns: list[str] = []
45+
46+
def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
47+
"""Visit function definitions and extract opcode patterns."""
48+
if not node.name.startswith("test_"):
49+
return
50+
51+
# Check if function has benchmark_test parameter
52+
if not self._has_benchmark_test_param(node):
53+
return
54+
55+
# Filter for code generator usage (required for fixed-opcode-count mode)
56+
if not self._uses_code_generator(node):
57+
return
58+
59+
# Extract opcode parametrizations
60+
test_name = node.name
61+
opcodes = self._extract_opcodes(node)
62+
63+
if opcodes:
64+
# Test parametrizes on opcodes - create pattern for each
65+
for opcode in opcodes:
66+
pattern = f"{test_name}.*{opcode}.*"
67+
self.patterns.append(pattern)
68+
else:
69+
# Test doesn't parametrize on opcodes - use test name only
70+
pattern = f"{test_name}.*"
71+
self.patterns.append(pattern)
72+
73+
self.generic_visit(node)
74+
75+
def _has_benchmark_test_param(self, node: ast.FunctionDef) -> bool:
76+
"""Check if function has benchmark_test parameter."""
77+
return any(arg.arg == "benchmark_test" for arg in node.args.args)
78+
79+
def _uses_code_generator(self, node: ast.FunctionDef) -> bool:
80+
"""Check if function body uses code_generator parameter."""
81+
func_start = node.lineno - 1
82+
func_end = node.end_lineno
83+
if func_end is None:
84+
return False
85+
func_source = "\n".join(
86+
self.source_code.splitlines()[func_start:func_end]
87+
)
88+
return "code_generator=" in func_source
89+
90+
def _extract_opcodes(self, node: ast.FunctionDef) -> list[str]:
91+
"""Extract opcode values from @pytest.mark.parametrize decorators."""
92+
opcodes: list[str] = []
93+
94+
for decorator in node.decorator_list:
95+
if not self._is_parametrize_decorator(decorator):
96+
continue
97+
98+
if not isinstance(decorator, ast.Call) or len(decorator.args) < 2:
99+
continue
100+
101+
# Get parameter names (first arg)
102+
param_names = decorator.args[0]
103+
if isinstance(param_names, ast.Constant):
104+
param_str = str(param_names.value).lower()
105+
else:
106+
continue
107+
108+
# Check if "opcode" is in parameter names
109+
if "opcode" not in param_str:
110+
continue
111+
112+
# Extract opcode values from second arg (the list)
113+
param_values = decorator.args[1]
114+
opcodes.extend(self._parse_opcode_values(param_values))
115+
116+
return opcodes
117+
118+
def _is_parametrize_decorator(self, decorator: ast.expr) -> bool:
119+
"""Check if decorator is @pytest.mark.parametrize."""
120+
if isinstance(decorator, ast.Call):
121+
if isinstance(decorator.func, ast.Attribute):
122+
if (
123+
isinstance(decorator.func.value, ast.Attribute)
124+
and decorator.func.value.attr == "mark"
125+
and decorator.func.attr == "parametrize"
126+
):
127+
return True
128+
return False
129+
130+
def _parse_opcode_values(self, values_node: ast.expr) -> list[str]:
131+
"""Parse opcode values from the parametrize list."""
132+
opcodes: list[str] = []
133+
134+
if not isinstance(values_node, (ast.List, ast.Tuple)):
135+
return opcodes
136+
137+
for element in values_node.elts:
138+
opcode_name = self._extract_opcode_name(element)
139+
if opcode_name:
140+
opcodes.append(opcode_name)
141+
142+
return opcodes
143+
144+
def _extract_opcode_name(self, node: ast.expr) -> str | None:
145+
"""
146+
Extract opcode name from various AST node types.
147+
148+
Supported patterns (opcode must be first element):
149+
150+
Case 1 - Direct opcode reference:
151+
@pytest.mark.parametrize("opcode", [Op.ADD, Op.MUL])
152+
Result: ["ADD", "MUL"]
153+
154+
Case 2a - pytest.param with direct opcode:
155+
@pytest.mark.parametrize("opcode", [pytest.param(Op.ADD, id="add")])
156+
Result: ["ADD"]
157+
158+
Case 2b - pytest.param with tuple (opcode first):
159+
@pytest.mark.parametrize("opcode,arg", [pytest.param((Op.ADD, 123))])
160+
Result: ["ADD"]
161+
162+
Case 3 - Plain tuple (opcode first):
163+
@pytest.mark.parametrize("opcode,arg", [(Op.ADD, 123), (Op.MUL, 456)])
164+
Result: ["ADD", "MUL"]
165+
"""
166+
# Case 1: Direct opcode - Op.ADD
167+
if isinstance(node, ast.Attribute):
168+
return node.attr
169+
170+
# Case 2: pytest.param(Op.ADD, ...) or pytest.param((Op.ADD, x), ...)
171+
if isinstance(node, ast.Call):
172+
if len(node.args) > 0:
173+
first_arg = node.args[0]
174+
# Case 2a: pytest.param(Op.ADD, ...)
175+
if isinstance(first_arg, ast.Attribute):
176+
return first_arg.attr
177+
# Case 2b: pytest.param((Op.ADD, x), ...)
178+
elif isinstance(first_arg, ast.Tuple) and first_arg.elts:
179+
first_elem = first_arg.elts[0]
180+
if isinstance(first_elem, ast.Attribute):
181+
return first_elem.attr
182+
183+
# Case 3: Plain tuple - (Op.ADD, args)
184+
if isinstance(node, ast.Tuple) and node.elts:
185+
first_elem = node.elts[0]
186+
if isinstance(first_elem, ast.Attribute):
187+
return first_elem.attr
188+
189+
return None
190+
191+
192+
def scan_benchmark_tests(
193+
base_path: Path,
194+
) -> tuple[dict[str, list[int]], dict[str, Path]]:
195+
"""
196+
Scan benchmark test files and extract opcode patterns.
197+
198+
Returns:
199+
Tuple of (config, pattern_sources) where:
200+
- config: mapping of pattern -> opcode counts
201+
- pattern_sources: mapping of pattern -> source file path
202+
"""
203+
config: dict[str, list[int]] = {}
204+
pattern_sources: dict[str, Path] = {}
205+
default_counts = [1]
206+
207+
test_files = [
208+
f
209+
for f in base_path.rglob("test_*.py")
210+
if "configs" not in str(f) and "stateful" not in str(f)
211+
]
212+
213+
for test_file in test_files:
214+
try:
215+
source = test_file.read_text()
216+
tree = ast.parse(source)
217+
218+
extractor = OpcodeExtractor(source)
219+
extractor.visit(tree)
220+
221+
for pattern in extractor.patterns:
222+
if pattern not in config:
223+
config[pattern] = default_counts
224+
pattern_sources[pattern] = test_file
225+
except Exception as e:
226+
print(f"Warning: Failed to parse {test_file}: {e}")
227+
continue
228+
229+
return config, pattern_sources
230+
231+
232+
def load_existing_config(config_file: Path) -> dict[str, list[int]]:
233+
"""Load existing config from .fixed_opcode_counts.json."""
234+
if not config_file.exists():
235+
return {}
236+
237+
try:
238+
data = json.loads(config_file.read_text())
239+
return data.get("scenario_configs", {})
240+
except (json.JSONDecodeError, KeyError):
241+
return {}
242+
243+
244+
def categorize_patterns(
245+
config: dict[str, list[int]], pattern_sources: dict[str, Path]
246+
) -> dict[str, list[str]]:
247+
"""
248+
Categorize patterns by deriving category from source file name.
249+
250+
Example: test_arithmetic.py -> ARITHMETIC
251+
"""
252+
categories: dict[str, list[str]] = {}
253+
254+
for pattern in config.keys():
255+
if pattern in pattern_sources:
256+
source_file = pattern_sources[pattern]
257+
file_name = source_file.stem
258+
if file_name.startswith("test_"):
259+
category = file_name[5:].upper() # Remove "test_" prefix
260+
else:
261+
category = "OTHER"
262+
else:
263+
category = "OTHER"
264+
265+
if category not in categories:
266+
categories[category] = []
267+
categories[category].append(pattern)
268+
269+
return {k: sorted(v) for k, v in sorted(categories.items())}
270+
271+
272+
def generate_config_json(
273+
config: dict[str, list[int]],
274+
pattern_sources: dict[str, Path],
275+
) -> str:
276+
"""Generate the JSON config file content."""
277+
categories = categorize_patterns(config, pattern_sources)
278+
279+
scenario_configs: dict[str, list[int]] = {}
280+
for _, patterns in categories.items():
281+
for pattern in patterns:
282+
scenario_configs[pattern] = config[pattern]
283+
284+
output = {"scenario_configs": scenario_configs}
285+
286+
return json.dumps(output, indent=2) + "\n"
287+
288+
289+
def main() -> int:
290+
"""Main entry point."""
291+
parser = argparse.ArgumentParser(
292+
description="Analyze benchmark tests and maintain opcode count mapping"
293+
)
294+
parser.add_argument(
295+
"--check",
296+
action="store_true",
297+
help="Check for new/missing entries (CI mode, exits 1 if out of sync)",
298+
)
299+
args = parser.parse_args()
300+
301+
try:
302+
benchmark_dir = get_benchmark_dir()
303+
config_file = get_config_file()
304+
except FileNotFoundError as e:
305+
print(f"Error: {e}", file=sys.stderr)
306+
return 1
307+
308+
print(f"Scanning benchmark tests in {benchmark_dir}...")
309+
detected, pattern_sources = scan_benchmark_tests(benchmark_dir)
310+
print(f"Detected {len(detected)} opcode patterns")
311+
312+
existing = load_existing_config(config_file)
313+
print(f"Loaded {len(existing)} existing entries")
314+
315+
detected_keys = set(detected.keys())
316+
existing_keys = set(existing.keys())
317+
new_patterns = sorted(detected_keys - existing_keys)
318+
obsolete_patterns = sorted(existing_keys - detected_keys)
319+
320+
merged = detected.copy()
321+
for pattern, counts in existing.items():
322+
if pattern in detected_keys:
323+
merged[pattern] = counts
324+
325+
print("\n" + "=" * 60)
326+
print(f"Detected {len(detected)} patterns in tests")
327+
print(f"Existing entries: {len(existing)}")
328+
329+
if new_patterns:
330+
print(f"\n+ Found {len(new_patterns)} NEW patterns:")
331+
for p in new_patterns[:15]:
332+
print(f" {p}")
333+
if len(new_patterns) > 15:
334+
print(f" ... and {len(new_patterns) - 15} more")
335+
336+
if obsolete_patterns:
337+
print(f"\n- Found {len(obsolete_patterns)} OBSOLETE patterns:")
338+
for p in obsolete_patterns[:15]:
339+
print(f" {p}")
340+
if len(obsolete_patterns) > 15:
341+
print(f" ... and {len(obsolete_patterns) - 15} more")
342+
343+
if not new_patterns and not obsolete_patterns:
344+
print("\nConfiguration is up to date!")
345+
346+
print("=" * 60)
347+
348+
if args.check:
349+
if new_patterns or obsolete_patterns:
350+
print("\nRun 'uv run benchmark_parser' (without --check) to sync.")
351+
return 1
352+
return 0
353+
354+
for pattern in obsolete_patterns:
355+
print(f"Removing obsolete: {pattern}")
356+
if pattern in merged:
357+
del merged[pattern]
358+
359+
content = generate_config_json(merged, pattern_sources)
360+
config_file.write_text(content)
361+
print(f"\nUpdated {config_file}")
362+
return 0
363+
364+
365+
if __name__ == "__main__":
366+
raise SystemExit(main())

0 commit comments

Comments
 (0)