diff --git a/codeflash/code_utils/instrument_existing_tests.py b/codeflash/code_utils/instrument_existing_tests.py index db08f8afc..f34a72f81 100644 --- a/codeflash/code_utils/instrument_existing_tests.py +++ b/codeflash/code_utils/instrument_existing_tests.py @@ -351,29 +351,40 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef: def _process_test_function( self, node: ast.AsyncFunctionDef | ast.FunctionDef ) -> ast.AsyncFunctionDef | ast.FunctionDef: - if self.test_framework == "unittest" and not any( - isinstance(d, ast.Call) and isinstance(d.func, ast.Name) and d.func.id == "timeout_decorator.timeout" - for d in node.decorator_list - ): - timeout_decorator = ast.Call( - func=ast.Name(id="timeout_decorator.timeout", ctx=ast.Load()), - args=[ast.Constant(value=15)], - keywords=[], - ) - node.decorator_list.append(timeout_decorator) + # Cache the test_framework check and decorator_list reference for better performance + if self.test_framework == "unittest": + decorlist = node.decorator_list + # Use a generator expression with a break sentinel for efficiency + timeout_found = False + for d in decorlist: + if ( + isinstance(d, ast.Call) + and isinstance(d.func, ast.Name) + and d.func.id == "timeout_decorator.timeout" + ): + timeout_found = True + break + if not timeout_found: + timeout_decorator = ast.Call( + func=ast.Name(id="timeout_decorator.timeout", ctx=ast.Load()), + args=[ast.Constant(value=15)], + keywords=[], + ) + decorlist.append(timeout_decorator) # Initialize counter for this test function if node.name not in self.async_call_counter: self.async_call_counter[node.name] = 0 new_body = [] + async_call_counter = self.async_call_counter # Local binding for minor speedup for _i, stmt in enumerate(node.body): transformed_stmt, added_env_assignment = self._instrument_statement(stmt, node.name) if added_env_assignment: - current_call_index = self.async_call_counter[node.name] - self.async_call_counter[node.name] += 1 + current_call_index = async_call_counter[node.name] + async_call_counter[node.name] += 1 env_assignment = ast.Assign( targets=[ @@ -386,7 +397,7 @@ def _process_test_function( ) ], value=ast.Constant(value=f"{current_call_index}"), - lineno=stmt.lineno if hasattr(stmt, "lineno") else 1, + lineno=getattr(stmt, "lineno", 1), # More efficient than hasattr ) new_body.append(env_assignment) self.did_instrument = True @@ -397,16 +408,20 @@ def _process_test_function( return node def _instrument_statement(self, stmt: ast.stmt, _node_name: str) -> tuple[ast.stmt, bool]: - for node in ast.walk(stmt): + # Instead of ast.walk, use a manual stack to avoid unnecessary generator overhead + stack = [stmt] + while stack: + node = stack.pop() if ( isinstance(node, ast.Await) and isinstance(node.value, ast.Call) and self._is_target_call(node.value) and self._call_in_positions(node.value) ): - # Check if this call is in one of our target positions - return stmt, True # Return original statement but signal we added env var - + return stmt, True + # Instead of ast.walk's recursive generator, use direct child iteration + for child in ast.iter_child_nodes(node): + stack.append(child) return stmt, False def _is_target_call(self, call_node: ast.Call) -> bool: