diff --git a/codeflash/code_utils/instrument_existing_tests.py b/codeflash/code_utils/instrument_existing_tests.py index db08f8afc..7e7940ddd 100644 --- a/codeflash/code_utils/instrument_existing_tests.py +++ b/codeflash/code_utils/instrument_existing_tests.py @@ -351,16 +351,24 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef: def _process_test_function( self, node: ast.AsyncFunctionDef | ast.FunctionDef ) -> ast.AsyncFunctionDef | ast.FunctionDef: - if self.test_framework == "unittest" and not any( - isinstance(d, ast.Call) and isinstance(d.func, ast.Name) and d.func.id == "timeout_decorator.timeout" - for d in node.decorator_list - ): - timeout_decorator = ast.Call( - func=ast.Name(id="timeout_decorator.timeout", ctx=ast.Load()), - args=[ast.Constant(value=15)], - keywords=[], - ) - node.decorator_list.append(timeout_decorator) + # Optimize the search for decorator presence + if self.test_framework == "unittest": + found_timeout = False + for d in node.decorator_list: + # Avoid isinstance(d.func, ast.Name) if d is not ast.Call + if isinstance(d, ast.Call): + f = d.func + # Avoid attribute lookup if f is not ast.Name + if isinstance(f, ast.Name) and f.id == "timeout_decorator.timeout": + found_timeout = True + break + if not found_timeout: + timeout_decorator = ast.Call( + func=ast.Name(id="timeout_decorator.timeout", ctx=ast.Load()), + args=[ast.Constant(value=15)], + keywords=[], + ) + node.decorator_list.append(timeout_decorator) # Initialize counter for this test function if node.name not in self.async_call_counter: @@ -368,8 +376,9 @@ def _process_test_function( new_body = [] + # Optimize ast.walk calls inside _instrument_statement, by scanning only relevant nodes for _i, stmt in enumerate(node.body): - transformed_stmt, added_env_assignment = self._instrument_statement(stmt, node.name) + transformed_stmt, added_env_assignment = self._optimized_instrument_statement(stmt) if added_env_assignment: current_call_index = self.async_call_counter[node.name] @@ -423,6 +432,26 @@ def _call_in_positions(self, call_node: ast.Call) -> bool: return node_in_call_position(call_node, self.call_positions) + # Optimized version: only walk child nodes for Await + def _optimized_instrument_statement(self, stmt: ast.stmt) -> tuple[ast.stmt, bool]: + # Stack-based DFS, manual for relevant Await nodes + stack = [stmt] + while stack: + node = stack.pop() + # Favor direct ast.Await detection + if isinstance(node, ast.Await): + val = node.value + if isinstance(val, ast.Call) and self._is_target_call(val) and self._call_in_positions(val): + return stmt, True + # Use _fields instead of ast.walk for less allocations + for fname in getattr(node, "_fields", ()): + child = getattr(node, fname, None) + if isinstance(child, list): + stack.extend(child) + elif isinstance(child, ast.AST): + stack.append(child) + return stmt, False + class FunctionImportedAsVisitor(ast.NodeVisitor): """Checks if a function has been imported as an alias. We only care about the alias then.