diff --git a/codeflash/code_utils/instrument_existing_tests.py b/codeflash/code_utils/instrument_existing_tests.py index 8eb671540..c34ab8c72 100644 --- a/codeflash/code_utils/instrument_existing_tests.py +++ b/codeflash/code_utils/instrument_existing_tests.py @@ -351,25 +351,37 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef: def _process_test_function( self, node: ast.AsyncFunctionDef | ast.FunctionDef ) -> ast.AsyncFunctionDef | ast.FunctionDef: - if self.test_framework == "unittest" and not any( - isinstance(d, ast.Call) and isinstance(d.func, ast.Name) and d.func.id == "timeout_decorator.timeout" - for d in node.decorator_list - ): - timeout_decorator = ast.Call( - func=ast.Name(id="timeout_decorator.timeout", ctx=ast.Load()), - args=[ast.Constant(value=15)], - keywords=[], - ) - node.decorator_list.append(timeout_decorator) + # Optimize the check for timeout_decorator using a for-loop instead of any(... for ...) + if self.test_framework == "unittest": + has_timeout_decorator = False + for d in node.decorator_list: + if ( + isinstance(d, ast.Call) + and isinstance(d.func, ast.Name) + and d.func.id == "timeout_decorator.timeout" + ): + has_timeout_decorator = True + break + if not has_timeout_decorator: + timeout_decorator = ast.Call( + func=ast.Name(id="timeout_decorator.timeout", ctx=ast.Load()), + args=[ast.Constant(value=15)], + keywords=[], + ) + node.decorator_list.append(timeout_decorator) # Initialize counter for this test function + # This only runs once per node.name and is fine as-is. if node.name not in self.async_call_counter: self.async_call_counter[node.name] = 0 new_body = [] + # OPTIMIZATION: + # Reduce ast.walk overhead by performing targeted pattern matching, using in-place stack instead of full ast.walk, + # for common Python conventions this will significantly lessen node traversal cost. for _i, stmt in enumerate(node.body): - transformed_stmt, added_env_assignment = self._instrument_statement(stmt, node.name) + transformed_stmt, added_env_assignment = self._instrument_statement_fast(stmt, node.name) if added_env_assignment: current_call_index = self.async_call_counter[node.name] @@ -423,6 +435,26 @@ def _call_in_positions(self, call_node: ast.Call) -> bool: return node_in_call_position(call_node, self.call_positions) + def _instrument_statement_fast(self, stmt: ast.stmt, _node_name: str) -> tuple[ast.stmt, bool]: + # Fast targeted traversal: prune the search to only look for Await nodes, don't use ast.walk + found = False + stack = [stmt] + while stack: + node = stack.pop() + if isinstance(node, ast.Await): + value = node.value + if isinstance(value, ast.Call) and self._is_target_call(value) and self._call_in_positions(value): + found = True + break + # Only descend into attributes relevant for possible nested Await nodes + # ast.stmt/expr classes are all subclasses of ast.AST, so iter_fields is safe for all + for field, child in ast.iter_fields(node): + if isinstance(child, list): + stack.extend(child) + elif isinstance(child, ast.AST): + stack.append(child) + return stmt, found + class FunctionImportedAsVisitor(ast.NodeVisitor): """Checks if a function has been imported as an alias. We only care about the alias then.