Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 32 additions & 17 deletions codeflash/code_utils/instrument_existing_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,29 +351,40 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
def _process_test_function(
self, node: ast.AsyncFunctionDef | ast.FunctionDef
) -> ast.AsyncFunctionDef | ast.FunctionDef:
if self.test_framework == "unittest" and not any(
isinstance(d, ast.Call) and isinstance(d.func, ast.Name) and d.func.id == "timeout_decorator.timeout"
for d in node.decorator_list
):
timeout_decorator = ast.Call(
func=ast.Name(id="timeout_decorator.timeout", ctx=ast.Load()),
args=[ast.Constant(value=15)],
keywords=[],
)
node.decorator_list.append(timeout_decorator)
# Cache the test_framework check and decorator_list reference for better performance
if self.test_framework == "unittest":
decorlist = node.decorator_list
# Use a generator expression with a break sentinel for efficiency
timeout_found = False
for d in decorlist:
if (
isinstance(d, ast.Call)
and isinstance(d.func, ast.Name)
and d.func.id == "timeout_decorator.timeout"
):
timeout_found = True
break
if not timeout_found:
timeout_decorator = ast.Call(
func=ast.Name(id="timeout_decorator.timeout", ctx=ast.Load()),
args=[ast.Constant(value=15)],
keywords=[],
)
decorlist.append(timeout_decorator)

# Initialize counter for this test function
if node.name not in self.async_call_counter:
self.async_call_counter[node.name] = 0

new_body = []
async_call_counter = self.async_call_counter # Local binding for minor speedup

for _i, stmt in enumerate(node.body):
transformed_stmt, added_env_assignment = self._instrument_statement(stmt, node.name)

if added_env_assignment:
current_call_index = self.async_call_counter[node.name]
self.async_call_counter[node.name] += 1
current_call_index = async_call_counter[node.name]
async_call_counter[node.name] += 1

env_assignment = ast.Assign(
targets=[
Expand All @@ -386,7 +397,7 @@ def _process_test_function(
)
],
value=ast.Constant(value=f"{current_call_index}"),
lineno=stmt.lineno if hasattr(stmt, "lineno") else 1,
lineno=getattr(stmt, "lineno", 1), # More efficient than hasattr
)
new_body.append(env_assignment)
self.did_instrument = True
Expand All @@ -397,16 +408,20 @@ def _process_test_function(
return node

def _instrument_statement(self, stmt: ast.stmt, _node_name: str) -> tuple[ast.stmt, bool]:
for node in ast.walk(stmt):
# Instead of ast.walk, use a manual stack to avoid unnecessary generator overhead
stack = [stmt]
while stack:
node = stack.pop()
if (
isinstance(node, ast.Await)
and isinstance(node.value, ast.Call)
and self._is_target_call(node.value)
and self._call_in_positions(node.value)
):
# Check if this call is in one of our target positions
return stmt, True # Return original statement but signal we added env var

return stmt, True
# Instead of ast.walk's recursive generator, use direct child iteration
for child in ast.iter_child_nodes(node):
stack.append(child)
return stmt, False

def _is_target_call(self, call_node: ast.Call) -> bool:
Expand Down
Loading