3 changes: 3 additions & 0 deletions .gitignore
@@ -9,6 +9,9 @@
# AI
.claude/

# Benchmark fixed opcode counts
.fixed_opcode_counts.json

# C extensions
*.so

1 change: 1 addition & 0 deletions packages/testing/pyproject.toml
@@ -101,6 +101,7 @@ extract_config = "execution_testing.cli.extract_config:extract_config"
compare_fixtures = "execution_testing.cli.compare_fixtures:main"
modify_static_test_gas_limits = "execution_testing.cli.modify_static_test_gas_limits:main"
validate_changelog = "execution_testing.cli.tox_helpers:validate_changelog"
benchmark_parser = "execution_testing.cli.benchmark_parser:main"

[tool.setuptools.packages.find]
where = ["src"]
366 changes: 366 additions & 0 deletions packages/testing/src/execution_testing/cli/benchmark_parser.py
@@ -0,0 +1,366 @@
"""
Parser to analyze benchmark tests and maintain the opcode counts mapping.
This script uses Python's AST to analyze benchmark tests and generate/update
the scenario configs in `.fixed_opcode_counts.json`.
Usage:
uv run benchmark_parser # Update `.fixed_opcode_counts.json`
uv run benchmark_parser --check # Check for new/missing entries (CI)
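
Example of the generated file (the pattern names below are illustrative;
new entries default to a count of [1], and existing counts are preserved
when the file is regenerated):

    {
      "scenario_configs": {
        "test_arithmetic_opcodes.*ADD.*": [1],
        "test_arithmetic_opcodes.*MUL.*": [1]
      }
    }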
"""

import argparse
import ast
import json
import sys
from pathlib import Path


def get_repo_root() -> Path:
"""Get the repository root directory."""
current = Path.cwd()
while current != current.parent:
if (current / "tests" / "benchmark").exists():
return current
current = current.parent
raise FileNotFoundError("Could not find repository root")


def get_benchmark_dir() -> Path:
"""Get the benchmark tests directory."""
return get_repo_root() / "tests" / "benchmark"


def get_config_file() -> Path:
"""Get the .fixed_opcode_counts.json config file path."""
return get_repo_root() / ".fixed_opcode_counts.json"


class OpcodeExtractor(ast.NodeVisitor):
"""Extract opcode parametrizations from benchmark test functions."""

def __init__(self, source_code: str):
self.source_code = source_code
self.patterns: list[str] = []

def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
"""Visit function definitions and extract opcode patterns."""
if not node.name.startswith("test_"):
return

# Check if function has benchmark_test parameter
if not self._has_benchmark_test_param(node):
return

# Filter for code generator usage (required for fixed-opcode-count mode)
if not self._uses_code_generator(node):
return

# Extract opcode parametrizations
test_name = node.name
opcodes = self._extract_opcodes(node)

if opcodes:
# Test parametrizes on opcodes - create pattern for each
for opcode in opcodes:
pattern = f"{test_name}.*{opcode}.*"
self.patterns.append(pattern)
else:
# Test doesn't parametrize on opcodes - use test name only
pattern = f"{test_name}.*"
self.patterns.append(pattern)

self.generic_visit(node)

def _has_benchmark_test_param(self, node: ast.FunctionDef) -> bool:
"""Check if function has benchmark_test parameter."""
return any(arg.arg == "benchmark_test" for arg in node.args.args)

def _uses_code_generator(self, node: ast.FunctionDef) -> bool:
"""Check if function body uses code_generator parameter."""
func_start = node.lineno - 1
func_end = node.end_lineno
if func_end is None:
return False
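        # Textual heuristic: scan the function's own source lines for a
        # ``code_generator=`` keyword argument instead of walking the AST
        # for keyword nodes.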
func_source = "\n".join(
self.source_code.splitlines()[func_start:func_end]
)
return "code_generator=" in func_source

def _extract_opcodes(self, node: ast.FunctionDef) -> list[str]:
"""Extract opcode values from @pytest.mark.parametrize decorators."""
opcodes: list[str] = []

for decorator in node.decorator_list:
if not self._is_parametrize_decorator(decorator):
continue

if not isinstance(decorator, ast.Call) or len(decorator.args) < 2:
continue

# Get parameter names (first arg)
param_names = decorator.args[0]
if isinstance(param_names, ast.Constant):
param_str = str(param_names.value).lower()
else:
continue

# Check if "opcode" is in parameter names
if "opcode" not in param_str:
continue

# Extract opcode values from second arg (the list)
param_values = decorator.args[1]
opcodes.extend(self._parse_opcode_values(param_values))

return opcodes

def _is_parametrize_decorator(self, decorator: ast.expr) -> bool:
"""Check if decorator is @pytest.mark.parametrize."""
if isinstance(decorator, ast.Call):
if isinstance(decorator.func, ast.Attribute):
if (
isinstance(decorator.func.value, ast.Attribute)
and decorator.func.value.attr == "mark"
and decorator.func.attr == "parametrize"
):
return True
return False

def _parse_opcode_values(self, values_node: ast.expr) -> list[str]:
"""Parse opcode values from the parametrize list."""
opcodes: list[str] = []

if not isinstance(values_node, (ast.List, ast.Tuple)):
return opcodes

for element in values_node.elts:
opcode_name = self._extract_opcode_name(element)
if opcode_name:
opcodes.append(opcode_name)

return opcodes

def _extract_opcode_name(self, node: ast.expr) -> str | None:
"""
        Extract the opcode name from various AST node types.

        Supported patterns (the opcode must be the first element):

        Case 1 - Direct opcode reference:
            @pytest.mark.parametrize("opcode", [Op.ADD, Op.MUL])
            Result: ["ADD", "MUL"]

        Case 2a - pytest.param with direct opcode:
            @pytest.mark.parametrize("opcode", [pytest.param(Op.ADD, id="add")])
            Result: ["ADD"]

        Case 2b - pytest.param with tuple (opcode first):
            @pytest.mark.parametrize("opcode,arg", [pytest.param((Op.ADD, 123))])
            Result: ["ADD"]

        Case 3 - Plain tuple (opcode first):
            @pytest.mark.parametrize("opcode,arg", [(Op.ADD, 123), (Op.MUL, 456)])
            Result: ["ADD", "MUL"]
"""
# Case 1: Direct opcode - Op.ADD
if isinstance(node, ast.Attribute):
return node.attr

# Case 2: pytest.param(Op.ADD, ...) or pytest.param((Op.ADD, x), ...)
if isinstance(node, ast.Call):
if len(node.args) > 0:
first_arg = node.args[0]
# Case 2a: pytest.param(Op.ADD, ...)
if isinstance(first_arg, ast.Attribute):
return first_arg.attr
# Case 2b: pytest.param((Op.ADD, x), ...)
elif isinstance(first_arg, ast.Tuple) and first_arg.elts:
first_elem = first_arg.elts[0]
if isinstance(first_elem, ast.Attribute):
return first_elem.attr

# Case 3: Plain tuple - (Op.ADD, args)
if isinstance(node, ast.Tuple) and node.elts:
first_elem = node.elts[0]
if isinstance(first_elem, ast.Attribute):
return first_elem.attr

return None
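
    # A minimal usage sketch of ``OpcodeExtractor`` (the test source below is
    # hypothetical and shown for illustration only; parsing is purely
    # syntactic, so the names need not resolve):
    #
    #     import textwrap
    #     source = textwrap.dedent('''
    #         import pytest
    #
    #         @pytest.mark.parametrize("opcode", [Op.ADD, Op.MUL])
    #         def test_math(benchmark_test, code_generator):
    #             benchmark_test(code_generator=code_generator)
    #     ''')
    #     extractor = OpcodeExtractor(source)
    #     extractor.visit(ast.parse(source))
    #     assert extractor.patterns == ["test_math.*ADD.*", "test_math.*MUL.*"]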


def scan_benchmark_tests(
base_path: Path,
) -> tuple[dict[str, list[int]], dict[str, Path]]:
"""
    Scan benchmark test files and extract opcode patterns.

    Returns:
Tuple of (config, pattern_sources) where:
- config: mapping of pattern -> opcode counts
- pattern_sources: mapping of pattern -> source file path
"""
config: dict[str, list[int]] = {}
pattern_sources: dict[str, Path] = {}
default_counts = [1]

test_files = [
f
for f in base_path.rglob("test_*.py")
if "configs" not in str(f) and "stateful" not in str(f)
]

for test_file in test_files:
try:
source = test_file.read_text()
tree = ast.parse(source)

extractor = OpcodeExtractor(source)
extractor.visit(tree)

for pattern in extractor.patterns:
if pattern not in config:
config[pattern] = default_counts
pattern_sources[pattern] = test_file
except Exception as e:
print(f"Warning: Failed to parse {test_file}: {e}")
continue

return config, pattern_sources


def load_existing_config(config_file: Path) -> dict[str, list[int]]:
"""Load existing config from .fixed_opcode_counts.json."""
if not config_file.exists():
return {}

    try:
        data = json.loads(config_file.read_text())
    except json.JSONDecodeError:
        return {}
    return data.get("scenario_configs", {})


def categorize_patterns(
config: dict[str, list[int]], pattern_sources: dict[str, Path]
) -> dict[str, list[str]]:
"""
    Categorize patterns by deriving the category from the source file name.

    Example: test_arithmetic.py -> ARITHMETIC
"""
categories: dict[str, list[str]] = {}

for pattern in config.keys():
if pattern in pattern_sources:
source_file = pattern_sources[pattern]
file_name = source_file.stem
if file_name.startswith("test_"):
category = file_name[5:].upper() # Remove "test_" prefix
else:
category = "OTHER"
else:
category = "OTHER"

if category not in categories:
categories[category] = []
categories[category].append(pattern)

return {k: sorted(v) for k, v in sorted(categories.items())}


def generate_config_json(
config: dict[str, list[int]],
pattern_sources: dict[str, Path],
) -> str:
"""Generate the JSON config file content."""
categories = categorize_patterns(config, pattern_sources)

scenario_configs: dict[str, list[int]] = {}
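    # Flatten the categorized patterns back into a single mapping; iteration
    # order (sorted category, then sorted pattern) fixes the key order in the
    # emitted JSON.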
    for patterns in categories.values():
for pattern in patterns:
scenario_configs[pattern] = config[pattern]

output = {"scenario_configs": scenario_configs}

return json.dumps(output, indent=2) + "\n"


def main() -> int:
"""Main entry point."""
parser = argparse.ArgumentParser(
description="Analyze benchmark tests and maintain opcode count mapping"
)
parser.add_argument(
"--check",
action="store_true",
help="Check for new/missing entries (CI mode, exits 1 if out of sync)",
)
args = parser.parse_args()

try:
benchmark_dir = get_benchmark_dir()
config_file = get_config_file()
except FileNotFoundError as e:
print(f"Error: {e}", file=sys.stderr)
return 1

print(f"Scanning benchmark tests in {benchmark_dir}...")
detected, pattern_sources = scan_benchmark_tests(benchmark_dir)
print(f"Detected {len(detected)} opcode patterns")

existing = load_existing_config(config_file)
print(f"Loaded {len(existing)} existing entries")

detected_keys = set(detected.keys())
existing_keys = set(existing.keys())
new_patterns = sorted(detected_keys - existing_keys)
obsolete_patterns = sorted(existing_keys - detected_keys)

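    # Start from the detected patterns, preserving hand-tuned counts for any
    # pattern that already exists in the config file.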
merged = detected.copy()
for pattern, counts in existing.items():
if pattern in detected_keys:
merged[pattern] = counts

print("\n" + "=" * 60)
print(f"Detected {len(detected)} patterns in tests")
print(f"Existing entries: {len(existing)}")

if new_patterns:
print(f"\n+ Found {len(new_patterns)} NEW patterns:")
for p in new_patterns[:15]:
print(f" {p}")
if len(new_patterns) > 15:
print(f" ... and {len(new_patterns) - 15} more")

if obsolete_patterns:
print(f"\n- Found {len(obsolete_patterns)} OBSOLETE patterns:")
for p in obsolete_patterns[:15]:
print(f" {p}")
if len(obsolete_patterns) > 15:
print(f" ... and {len(obsolete_patterns) - 15} more")

if not new_patterns and not obsolete_patterns:
print("\nConfiguration is up to date!")

print("=" * 60)

if args.check:
if new_patterns or obsolete_patterns:
print("\nRun 'uv run benchmark_parser' (without --check) to sync.")
return 1
return 0

    # `merged` was built from detected patterns only, so obsolete entries are
    # never present in it; just report what is being dropped.
    for pattern in obsolete_patterns:
        print(f"Removing obsolete: {pattern}")

content = generate_config_json(merged, pattern_sources)
config_file.write_text(content)
print(f"\nUpdated {config_file}")
return 0


if __name__ == "__main__":
raise SystemExit(main())