Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 40 additions & 11 deletions codeflash/code_utils/instrument_existing_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,25 +351,34 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
def _process_test_function(
self, node: ast.AsyncFunctionDef | ast.FunctionDef
) -> ast.AsyncFunctionDef | ast.FunctionDef:
if self.test_framework == "unittest" and not any(
isinstance(d, ast.Call) and isinstance(d.func, ast.Name) and d.func.id == "timeout_decorator.timeout"
for d in node.decorator_list
):
timeout_decorator = ast.Call(
func=ast.Name(id="timeout_decorator.timeout", ctx=ast.Load()),
args=[ast.Constant(value=15)],
keywords=[],
)
node.decorator_list.append(timeout_decorator)
# Optimize the search for decorator presence
if self.test_framework == "unittest":
found_timeout = False
for d in node.decorator_list:
# Avoid isinstance(d.func, ast.Name) if d is not ast.Call
if isinstance(d, ast.Call):
f = d.func
# Avoid attribute lookup if f is not ast.Name
if isinstance(f, ast.Name) and f.id == "timeout_decorator.timeout":
found_timeout = True
break
if not found_timeout:
timeout_decorator = ast.Call(
func=ast.Name(id="timeout_decorator.timeout", ctx=ast.Load()),
args=[ast.Constant(value=15)],
keywords=[],
)
node.decorator_list.append(timeout_decorator)

# Initialize counter for this test function
if node.name not in self.async_call_counter:
self.async_call_counter[node.name] = 0

new_body = []

# Optimize ast.walk calls inside _instrument_statement, by scanning only relevant nodes
for _i, stmt in enumerate(node.body):
transformed_stmt, added_env_assignment = self._instrument_statement(stmt, node.name)
transformed_stmt, added_env_assignment = self._optimized_instrument_statement(stmt)

if added_env_assignment:
current_call_index = self.async_call_counter[node.name]
Expand Down Expand Up @@ -423,6 +432,26 @@ def _call_in_positions(self, call_node: ast.Call) -> bool:

return node_in_call_position(call_node, self.call_positions)

# Optimized version: only walk child nodes for Await
def _optimized_instrument_statement(self, stmt: ast.stmt) -> tuple[ast.stmt, bool]:
# Stack-based DFS, manual for relevant Await nodes
stack = [stmt]
while stack:
node = stack.pop()
# Favor direct ast.Await detection
if isinstance(node, ast.Await):
val = node.value
if isinstance(val, ast.Call) and self._is_target_call(val) and self._call_in_positions(val):
return stmt, True
# Use _fields instead of ast.walk for less allocations
for fname in getattr(node, "_fields", ()):
child = getattr(node, fname, None)
if isinstance(child, list):
stack.extend(child)
elif isinstance(child, ast.AST):
stack.append(child)
return stmt, False


class FunctionImportedAsVisitor(ast.NodeVisitor):
"""Checks if a function has been imported as an alias. We only care about the alias then.
Expand Down
Loading