|
| 1 | +""" |
| 2 | +Parser to analyze benchmark tests and maintain the opcode counts mapping. |
| 3 | +
|
| 4 | +This script uses Python's AST to analyze benchmark tests and generate/update |
| 5 | +the scenario configs in `.fixed_opcode_counts.json`. |
| 6 | +
|
| 7 | +Usage: |
| 8 | + uv run benchmark_parser # Update `.fixed_opcode_counts.json` |
| 9 | + uv run benchmark_parser --check # Check for new/missing entries (CI) |
| 10 | +""" |
| 11 | + |
| 12 | +import argparse |
| 13 | +import ast |
| 14 | +import json |
| 15 | +import sys |
| 16 | +from pathlib import Path |
| 17 | + |
| 18 | + |
| 19 | +def get_repo_root() -> Path: |
| 20 | + """Get the repository root directory.""" |
| 21 | + current = Path.cwd() |
| 22 | + while current != current.parent: |
| 23 | + if (current / "tests" / "benchmark").exists(): |
| 24 | + return current |
| 25 | + current = current.parent |
| 26 | + raise FileNotFoundError("Could not find repository root") |
| 27 | + |
| 28 | + |
| 29 | +def get_benchmark_dir() -> Path: |
| 30 | + """Get the benchmark tests directory.""" |
| 31 | + return get_repo_root() / "tests" / "benchmark" |
| 32 | + |
| 33 | + |
| 34 | +def get_config_file() -> Path: |
| 35 | + """Get the .fixed_opcode_counts.json config file path.""" |
| 36 | + return get_repo_root() / ".fixed_opcode_counts.json" |
| 37 | + |
| 38 | + |
| 39 | +class OpcodeExtractor(ast.NodeVisitor): |
| 40 | + """Extract opcode parametrizations from benchmark test functions.""" |
| 41 | + |
| 42 | + def __init__(self, source_code: str): |
| 43 | + self.source_code = source_code |
| 44 | + self.patterns: list[str] = [] |
| 45 | + |
| 46 | + def visit_FunctionDef(self, node: ast.FunctionDef) -> None: |
| 47 | + """Visit function definitions and extract opcode patterns.""" |
| 48 | + if not node.name.startswith("test_"): |
| 49 | + return |
| 50 | + |
| 51 | + # Check if function has benchmark_test parameter |
| 52 | + if not self._has_benchmark_test_param(node): |
| 53 | + return |
| 54 | + |
| 55 | + # Filter for code generator usage (required for fixed-opcode-count mode) |
| 56 | + if not self._uses_code_generator(node): |
| 57 | + return |
| 58 | + |
| 59 | + # Extract opcode parametrizations |
| 60 | + test_name = node.name |
| 61 | + opcodes = self._extract_opcodes(node) |
| 62 | + |
| 63 | + if opcodes: |
| 64 | + # Test parametrizes on opcodes - create pattern for each |
| 65 | + for opcode in opcodes: |
| 66 | + pattern = f"{test_name}.*{opcode}.*" |
| 67 | + self.patterns.append(pattern) |
| 68 | + else: |
| 69 | + # Test doesn't parametrize on opcodes - use test name only |
| 70 | + pattern = f"{test_name}.*" |
| 71 | + self.patterns.append(pattern) |
| 72 | + |
| 73 | + self.generic_visit(node) |
| 74 | + |
| 75 | + def _has_benchmark_test_param(self, node: ast.FunctionDef) -> bool: |
| 76 | + """Check if function has benchmark_test parameter.""" |
| 77 | + return any(arg.arg == "benchmark_test" for arg in node.args.args) |
| 78 | + |
| 79 | + def _uses_code_generator(self, node: ast.FunctionDef) -> bool: |
| 80 | + """Check if function body uses code_generator parameter.""" |
| 81 | + func_start = node.lineno - 1 |
| 82 | + func_end = node.end_lineno |
| 83 | + if func_end is None: |
| 84 | + return False |
| 85 | + func_source = "\n".join( |
| 86 | + self.source_code.splitlines()[func_start:func_end] |
| 87 | + ) |
| 88 | + return "code_generator=" in func_source |
| 89 | + |
| 90 | + def _extract_opcodes(self, node: ast.FunctionDef) -> list[str]: |
| 91 | + """Extract opcode values from @pytest.mark.parametrize decorators.""" |
| 92 | + opcodes: list[str] = [] |
| 93 | + |
| 94 | + for decorator in node.decorator_list: |
| 95 | + if not self._is_parametrize_decorator(decorator): |
| 96 | + continue |
| 97 | + |
| 98 | + if not isinstance(decorator, ast.Call) or len(decorator.args) < 2: |
| 99 | + continue |
| 100 | + |
| 101 | + # Get parameter names (first arg) |
| 102 | + param_names = decorator.args[0] |
| 103 | + if isinstance(param_names, ast.Constant): |
| 104 | + param_str = str(param_names.value).lower() |
| 105 | + else: |
| 106 | + continue |
| 107 | + |
| 108 | + # Check if "opcode" is in parameter names |
| 109 | + if "opcode" not in param_str: |
| 110 | + continue |
| 111 | + |
| 112 | + # Extract opcode values from second arg (the list) |
| 113 | + param_values = decorator.args[1] |
| 114 | + opcodes.extend(self._parse_opcode_values(param_values)) |
| 115 | + |
| 116 | + return opcodes |
| 117 | + |
| 118 | + def _is_parametrize_decorator(self, decorator: ast.expr) -> bool: |
| 119 | + """Check if decorator is @pytest.mark.parametrize.""" |
| 120 | + if isinstance(decorator, ast.Call): |
| 121 | + if isinstance(decorator.func, ast.Attribute): |
| 122 | + if ( |
| 123 | + isinstance(decorator.func.value, ast.Attribute) |
| 124 | + and decorator.func.value.attr == "mark" |
| 125 | + and decorator.func.attr == "parametrize" |
| 126 | + ): |
| 127 | + return True |
| 128 | + return False |
| 129 | + |
| 130 | + def _parse_opcode_values(self, values_node: ast.expr) -> list[str]: |
| 131 | + """Parse opcode values from the parametrize list.""" |
| 132 | + opcodes: list[str] = [] |
| 133 | + |
| 134 | + if not isinstance(values_node, (ast.List, ast.Tuple)): |
| 135 | + return opcodes |
| 136 | + |
| 137 | + for element in values_node.elts: |
| 138 | + opcode_name = self._extract_opcode_name(element) |
| 139 | + if opcode_name: |
| 140 | + opcodes.append(opcode_name) |
| 141 | + |
| 142 | + return opcodes |
| 143 | + |
| 144 | + def _extract_opcode_name(self, node: ast.expr) -> str | None: |
| 145 | + """ |
| 146 | + Extract opcode name from various AST node types. |
| 147 | +
|
| 148 | + Supported patterns (opcode must be first element): |
| 149 | +
|
| 150 | + Case 1 - Direct opcode reference: |
| 151 | + @pytest.mark.parametrize("opcode", [Op.ADD, Op.MUL]) |
| 152 | + Result: ["ADD", "MUL"] |
| 153 | +
|
| 154 | + Case 2a - pytest.param with direct opcode: |
| 155 | + @pytest.mark.parametrize("opcode", [pytest.param(Op.ADD, id="add")]) |
| 156 | + Result: ["ADD"] |
| 157 | +
|
| 158 | + Case 2b - pytest.param with tuple (opcode first): |
| 159 | + @pytest.mark.parametrize("opcode,arg", [pytest.param((Op.ADD, 123))]) |
| 160 | + Result: ["ADD"] |
| 161 | +
|
| 162 | + Case 3 - Plain tuple (opcode first): |
| 163 | + @pytest.mark.parametrize("opcode,arg", [(Op.ADD, 123), (Op.MUL, 456)]) |
| 164 | + Result: ["ADD", "MUL"] |
| 165 | + """ |
| 166 | + # Case 1: Direct opcode - Op.ADD |
| 167 | + if isinstance(node, ast.Attribute): |
| 168 | + return node.attr |
| 169 | + |
| 170 | + # Case 2: pytest.param(Op.ADD, ...) or pytest.param((Op.ADD, x), ...) |
| 171 | + if isinstance(node, ast.Call): |
| 172 | + if len(node.args) > 0: |
| 173 | + first_arg = node.args[0] |
| 174 | + # Case 2a: pytest.param(Op.ADD, ...) |
| 175 | + if isinstance(first_arg, ast.Attribute): |
| 176 | + return first_arg.attr |
| 177 | + # Case 2b: pytest.param((Op.ADD, x), ...) |
| 178 | + elif isinstance(first_arg, ast.Tuple) and first_arg.elts: |
| 179 | + first_elem = first_arg.elts[0] |
| 180 | + if isinstance(first_elem, ast.Attribute): |
| 181 | + return first_elem.attr |
| 182 | + |
| 183 | + # Case 3: Plain tuple - (Op.ADD, args) |
| 184 | + if isinstance(node, ast.Tuple) and node.elts: |
| 185 | + first_elem = node.elts[0] |
| 186 | + if isinstance(first_elem, ast.Attribute): |
| 187 | + return first_elem.attr |
| 188 | + |
| 189 | + return None |
| 190 | + |
| 191 | + |
| 192 | +def scan_benchmark_tests( |
| 193 | + base_path: Path, |
| 194 | +) -> tuple[dict[str, list[int]], dict[str, Path]]: |
| 195 | + """ |
| 196 | + Scan benchmark test files and extract opcode patterns. |
| 197 | +
|
| 198 | + Returns: |
| 199 | + Tuple of (config, pattern_sources) where: |
| 200 | + - config: mapping of pattern -> opcode counts |
| 201 | + - pattern_sources: mapping of pattern -> source file path |
| 202 | + """ |
| 203 | + config: dict[str, list[int]] = {} |
| 204 | + pattern_sources: dict[str, Path] = {} |
| 205 | + default_counts = [1] |
| 206 | + |
| 207 | + test_files = [ |
| 208 | + f |
| 209 | + for f in base_path.rglob("test_*.py") |
| 210 | + if "configs" not in str(f) and "stateful" not in str(f) |
| 211 | + ] |
| 212 | + |
| 213 | + for test_file in test_files: |
| 214 | + try: |
| 215 | + source = test_file.read_text() |
| 216 | + tree = ast.parse(source) |
| 217 | + |
| 218 | + extractor = OpcodeExtractor(source) |
| 219 | + extractor.visit(tree) |
| 220 | + |
| 221 | + for pattern in extractor.patterns: |
| 222 | + if pattern not in config: |
| 223 | + config[pattern] = default_counts |
| 224 | + pattern_sources[pattern] = test_file |
| 225 | + except Exception as e: |
| 226 | + print(f"Warning: Failed to parse {test_file}: {e}") |
| 227 | + continue |
| 228 | + |
| 229 | + return config, pattern_sources |
| 230 | + |
| 231 | + |
| 232 | +def load_existing_config(config_file: Path) -> dict[str, list[int]]: |
| 233 | + """Load existing config from .fixed_opcode_counts.json.""" |
| 234 | + if not config_file.exists(): |
| 235 | + return {} |
| 236 | + |
| 237 | + try: |
| 238 | + data = json.loads(config_file.read_text()) |
| 239 | + return data.get("scenario_configs", {}) |
| 240 | + except (json.JSONDecodeError, KeyError): |
| 241 | + return {} |
| 242 | + |
| 243 | + |
| 244 | +def categorize_patterns( |
| 245 | + config: dict[str, list[int]], pattern_sources: dict[str, Path] |
| 246 | +) -> dict[str, list[str]]: |
| 247 | + """ |
| 248 | + Categorize patterns by deriving category from source file name. |
| 249 | +
|
| 250 | + Example: test_arithmetic.py -> ARITHMETIC |
| 251 | + """ |
| 252 | + categories: dict[str, list[str]] = {} |
| 253 | + |
| 254 | + for pattern in config.keys(): |
| 255 | + if pattern in pattern_sources: |
| 256 | + source_file = pattern_sources[pattern] |
| 257 | + file_name = source_file.stem |
| 258 | + if file_name.startswith("test_"): |
| 259 | + category = file_name[5:].upper() # Remove "test_" prefix |
| 260 | + else: |
| 261 | + category = "OTHER" |
| 262 | + else: |
| 263 | + category = "OTHER" |
| 264 | + |
| 265 | + if category not in categories: |
| 266 | + categories[category] = [] |
| 267 | + categories[category].append(pattern) |
| 268 | + |
| 269 | + return {k: sorted(v) for k, v in sorted(categories.items())} |
| 270 | + |
| 271 | + |
| 272 | +def generate_config_json( |
| 273 | + config: dict[str, list[int]], |
| 274 | + pattern_sources: dict[str, Path], |
| 275 | +) -> str: |
| 276 | + """Generate the JSON config file content.""" |
| 277 | + categories = categorize_patterns(config, pattern_sources) |
| 278 | + |
| 279 | + scenario_configs: dict[str, list[int]] = {} |
| 280 | + for _, patterns in categories.items(): |
| 281 | + for pattern in patterns: |
| 282 | + scenario_configs[pattern] = config[pattern] |
| 283 | + |
| 284 | + output = {"scenario_configs": scenario_configs} |
| 285 | + |
| 286 | + return json.dumps(output, indent=2) + "\n" |
| 287 | + |
| 288 | + |
| 289 | +def main() -> int: |
| 290 | + """Main entry point.""" |
| 291 | + parser = argparse.ArgumentParser( |
| 292 | + description="Analyze benchmark tests and maintain opcode count mapping" |
| 293 | + ) |
| 294 | + parser.add_argument( |
| 295 | + "--check", |
| 296 | + action="store_true", |
| 297 | + help="Check for new/missing entries (CI mode, exits 1 if out of sync)", |
| 298 | + ) |
| 299 | + args = parser.parse_args() |
| 300 | + |
| 301 | + try: |
| 302 | + benchmark_dir = get_benchmark_dir() |
| 303 | + config_file = get_config_file() |
| 304 | + except FileNotFoundError as e: |
| 305 | + print(f"Error: {e}", file=sys.stderr) |
| 306 | + return 1 |
| 307 | + |
| 308 | + print(f"Scanning benchmark tests in {benchmark_dir}...") |
| 309 | + detected, pattern_sources = scan_benchmark_tests(benchmark_dir) |
| 310 | + print(f"Detected {len(detected)} opcode patterns") |
| 311 | + |
| 312 | + existing = load_existing_config(config_file) |
| 313 | + print(f"Loaded {len(existing)} existing entries") |
| 314 | + |
| 315 | + detected_keys = set(detected.keys()) |
| 316 | + existing_keys = set(existing.keys()) |
| 317 | + new_patterns = sorted(detected_keys - existing_keys) |
| 318 | + obsolete_patterns = sorted(existing_keys - detected_keys) |
| 319 | + |
| 320 | + merged = detected.copy() |
| 321 | + for pattern, counts in existing.items(): |
| 322 | + if pattern in detected_keys: |
| 323 | + merged[pattern] = counts |
| 324 | + |
| 325 | + print("\n" + "=" * 60) |
| 326 | + print(f"Detected {len(detected)} patterns in tests") |
| 327 | + print(f"Existing entries: {len(existing)}") |
| 328 | + |
| 329 | + if new_patterns: |
| 330 | + print(f"\n+ Found {len(new_patterns)} NEW patterns:") |
| 331 | + for p in new_patterns[:15]: |
| 332 | + print(f" {p}") |
| 333 | + if len(new_patterns) > 15: |
| 334 | + print(f" ... and {len(new_patterns) - 15} more") |
| 335 | + |
| 336 | + if obsolete_patterns: |
| 337 | + print(f"\n- Found {len(obsolete_patterns)} OBSOLETE patterns:") |
| 338 | + for p in obsolete_patterns[:15]: |
| 339 | + print(f" {p}") |
| 340 | + if len(obsolete_patterns) > 15: |
| 341 | + print(f" ... and {len(obsolete_patterns) - 15} more") |
| 342 | + |
| 343 | + if not new_patterns and not obsolete_patterns: |
| 344 | + print("\nConfiguration is up to date!") |
| 345 | + |
| 346 | + print("=" * 60) |
| 347 | + |
| 348 | + if args.check: |
| 349 | + if new_patterns or obsolete_patterns: |
| 350 | + print("\nRun 'uv run benchmark_parser' (without --check) to sync.") |
| 351 | + return 1 |
| 352 | + return 0 |
| 353 | + |
| 354 | + for pattern in obsolete_patterns: |
| 355 | + print(f"Removing obsolete: {pattern}") |
| 356 | + if pattern in merged: |
| 357 | + del merged[pattern] |
| 358 | + |
| 359 | + content = generate_config_json(merged, pattern_sources) |
| 360 | + config_file.write_text(content) |
| 361 | + print(f"\nUpdated {config_file}") |
| 362 | + return 0 |
| 363 | + |
| 364 | + |
| 365 | +if __name__ == "__main__": |
| 366 | + raise SystemExit(main()) |
0 commit comments