|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
3 | 3 | import ast |
| 4 | +from _ast import AST |
4 | 5 | from collections import defaultdict |
5 | 6 | from functools import lru_cache |
6 | 7 | from typing import TYPE_CHECKING, Optional, TypeVar, Union |
@@ -40,8 +41,8 @@ def _uses_benchmark_fixture(self, node: Union[ast.FunctionDef, ast.AsyncFunction |
40 | 41 | if self._is_benchmark_marker(decorator): |
41 | 42 | return True |
42 | 43 |
|
43 | | - # Check function body for benchmark usage |
44 | | - return any(isinstance(stmt, ast.Call) and self._is_benchmark_call(stmt) for stmt in ast.walk(node)) |
| 44 | + # Optimized: Use a fast body scan to detect use of benchmark in function body |
| 45 | + return self._body_uses_benchmark_call(node.body) |
45 | 46 |
|
46 | 47 | def _is_benchmark_marker(self, decorator: ast.expr) -> bool: |
47 | 48 | """Check if decorator is a benchmark-related pytest marker.""" |
@@ -113,6 +114,29 @@ def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef: |
113 | 114 | node.body = new_body |
114 | 115 | return node |
115 | 116 |
|
| 117 | + def _body_uses_benchmark_call(self, stmts): |
| 118 | + """Efficiently check if 'benchmark' is called anywhere in the body (recursive, shallow, single function only).""" |
| 119 | + stack = list(stmts) |
| 120 | + while stack: |
| 121 | + stmt = stack.pop() |
| 122 | + # Check for a benchmark call at this node (stmt may be an expr, an Assign, etc.) |
| 123 | + if isinstance(stmt, ast.Expr) and isinstance(stmt.value, ast.Call): |
| 124 | + if self._is_benchmark_call(stmt.value): |
| 125 | + return True |
| 126 | + elif isinstance(stmt, ast.Call): |
| 127 | + if self._is_benchmark_call(stmt): |
| 128 | + return True |
| 129 | + # Recursively check relevant AST containers for body calls |
| 130 | + for attr in ("body", "orelse", "finalbody"): |
| 131 | + if hasattr(stmt, attr): |
| 132 | + stack.extend(getattr(stmt, attr)) |
| 133 | + # Check except blocks for 'body' |
| 134 | + if hasattr(stmt, "handlers"): |
| 135 | + for handler in stmt.handlers: |
| 136 | + if hasattr(handler, "body"): |
| 137 | + stack.extend(handler.body) |
| 138 | + return False |
| 139 | + |
116 | 140 |
|
117 | 141 | def remove_benchmark_functions(tree: AST) -> AST: |
118 | 142 | """Remove benchmark functions from Python source code. |
|
0 commit comments