Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 17 additions & 8 deletions codeflash/code_utils/instrument_existing_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,8 @@ def __init__(
self.async_call_counter: dict[str, int] = {}
if len(function.parents) == 1 and function.parents[0].type == "ClassDef":
self.class_name = function.top_level_parent_name
# Cache for _is_target_call type decisions
self._function_name = function.function_name

def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
# Add timeout decorator for unittest test classes if needed
Expand Down Expand Up @@ -397,30 +399,37 @@ def _process_test_function(
return node

def _instrument_statement(self, stmt: ast.stmt, _node_name: str) -> tuple[ast.stmt, bool]:
for node in ast.walk(stmt):
# Optimize ast.walk: Stop traversal immediately after matching Await/Call/target
stack = [stmt]
while stack:
node = stack.pop()
# The following 'isinstance' is fast, so keep order
if (
isinstance(node, ast.Await)
and isinstance(node.value, ast.Call)
and self._is_target_call(node.value)
and self._call_in_positions(node.value)
):
# Check if this call is in one of our target positions
return stmt, True # Return original statement but signal we added env var

# Only traverse children if this isn't Await with target (otherwise redundant)
for child in ast.iter_child_nodes(node):
stack.append(child)
return stmt, False

def _is_target_call(self, call_node: ast.Call) -> bool:
"""Check if this call node is calling our target async function."""
if isinstance(call_node.func, ast.Name):
return call_node.func.id == self.function_object.function_name
if isinstance(call_node.func, ast.Attribute):
return call_node.func.attr == self.function_object.function_name
func = call_node.func
if isinstance(func, ast.Name):
# Early-exit style, as attribute hits are rare
return func.id == self._function_name
if isinstance(func, ast.Attribute):
return func.attr == self._function_name
return False

def _call_in_positions(self, call_node: ast.Call) -> bool:
# Inline hasattr test and use direct call to node_in_call_position (already optimal)
if not hasattr(call_node, "lineno") or not hasattr(call_node, "col_offset"):
return False

return node_in_call_position(call_node, self.call_positions)


Expand Down
Loading