diff --git a/codeflash/code_utils/instrument_existing_tests.py b/codeflash/code_utils/instrument_existing_tests.py index db08f8afc..a0c44804f 100644 --- a/codeflash/code_utils/instrument_existing_tests.py +++ b/codeflash/code_utils/instrument_existing_tests.py @@ -312,6 +312,8 @@ def __init__( self.async_call_counter: dict[str, int] = {} if len(function.parents) == 1 and function.parents[0].type == "ClassDef": self.class_name = function.top_level_parent_name + # Cache for _is_target_call type decisions + self._function_name = function.function_name def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef: # Add timeout decorator for unittest test classes if needed @@ -397,30 +399,37 @@ def _process_test_function( return node def _instrument_statement(self, stmt: ast.stmt, _node_name: str) -> tuple[ast.stmt, bool]: - for node in ast.walk(stmt): + # Optimize ast.walk: Stop traversal immediately after matching Await/Call/target + stack = [stmt] + while stack: + node = stack.pop() + # The following 'isinstance' is fast, so keep order if ( isinstance(node, ast.Await) and isinstance(node.value, ast.Call) and self._is_target_call(node.value) and self._call_in_positions(node.value) ): - # Check if this call is in one of our target positions return stmt, True # Return original statement but signal we added env var - + # Only traverse children if this isn't Await with target (otherwise redundant) + for child in ast.iter_child_nodes(node): + stack.append(child) return stmt, False def _is_target_call(self, call_node: ast.Call) -> bool: """Check if this call node is calling our target async function.""" - if isinstance(call_node.func, ast.Name): - return call_node.func.id == self.function_object.function_name - if isinstance(call_node.func, ast.Attribute): - return call_node.func.attr == self.function_object.function_name + func = call_node.func + if isinstance(func, ast.Name): + # Early-exit style, as attribute hits are rare + return func.id == self._function_name + if isinstance(func, ast.Attribute): + return func.attr == self._function_name return False def _call_in_positions(self, call_node: ast.Call) -> bool: + # Inline hasattr test and use direct call to node_in_call_position (already optimal) if not hasattr(call_node, "lineno") or not hasattr(call_node, "col_offset"): return False - return node_in_call_position(call_node, self.call_positions)