Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 43 additions & 11 deletions codeflash/code_utils/instrument_existing_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,25 +351,37 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
def _process_test_function(
self, node: ast.AsyncFunctionDef | ast.FunctionDef
) -> ast.AsyncFunctionDef | ast.FunctionDef:
if self.test_framework == "unittest" and not any(
isinstance(d, ast.Call) and isinstance(d.func, ast.Name) and d.func.id == "timeout_decorator.timeout"
for d in node.decorator_list
):
timeout_decorator = ast.Call(
func=ast.Name(id="timeout_decorator.timeout", ctx=ast.Load()),
args=[ast.Constant(value=15)],
keywords=[],
)
node.decorator_list.append(timeout_decorator)
# Optimize the check for timeout_decorator using a for-loop instead of any(... for ...)
if self.test_framework == "unittest":
has_timeout_decorator = False
for d in node.decorator_list:
if (
isinstance(d, ast.Call)
and isinstance(d.func, ast.Name)
and d.func.id == "timeout_decorator.timeout"
):
has_timeout_decorator = True
break
if not has_timeout_decorator:
timeout_decorator = ast.Call(
func=ast.Name(id="timeout_decorator.timeout", ctx=ast.Load()),
args=[ast.Constant(value=15)],
keywords=[],
)
node.decorator_list.append(timeout_decorator)

# Initialize counter for this test function
# This only runs once per node.name and is fine as-is.
if node.name not in self.async_call_counter:
self.async_call_counter[node.name] = 0

new_body = []

# OPTIMIZATION:
# Reduce ast.walk overhead by performing targeted pattern matching, using in-place stack instead of full ast.walk,
# for common Python conventions this will significantly lessen node traversal cost.
for _i, stmt in enumerate(node.body):
transformed_stmt, added_env_assignment = self._instrument_statement(stmt, node.name)
transformed_stmt, added_env_assignment = self._instrument_statement_fast(stmt, node.name)

if added_env_assignment:
current_call_index = self.async_call_counter[node.name]
Expand Down Expand Up @@ -423,6 +435,26 @@ def _call_in_positions(self, call_node: ast.Call) -> bool:

return node_in_call_position(call_node, self.call_positions)

def _instrument_statement_fast(self, stmt: ast.stmt, _node_name: str) -> tuple[ast.stmt, bool]:
# Fast targeted traversal: prune the search to only look for Await nodes, don't use ast.walk
found = False
stack = [stmt]
while stack:
node = stack.pop()
if isinstance(node, ast.Await):
value = node.value
if isinstance(value, ast.Call) and self._is_target_call(value) and self._call_in_positions(value):
found = True
break
# Only descend into attributes relevant for possible nested Await nodes
# ast.stmt/expr classes are all subclasses of ast.AST, so iter_fields is safe for all
for field, child in ast.iter_fields(node):
if isinstance(child, list):
stack.extend(child)
elif isinstance(child, ast.AST):
stack.append(child)
return stmt, found


class FunctionImportedAsVisitor(ast.NodeVisitor):
"""Checks if a function has been imported as an alias. We only care about the alias then.
Expand Down
Loading