Skip to content

Commit be2ffea

Browse files
authored
Merge pull request #780 from codeflash-ai/codeflash/optimize-pr769-2025-09-27T02.50.03
⚡️ Speed up method `AsyncCallInstrumenter.visit_AsyncFunctionDef` by 123% in PR #769 (`clean-async-branch`)
2 parents e27c133 + 585249f commit be2ffea

File tree

1 file changed

+40
-11
lines changed

1 file changed

+40
-11
lines changed

codeflash/code_utils/instrument_existing_tests.py

Lines changed: 40 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -351,25 +351,34 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
351351
def _process_test_function(
352352
self, node: ast.AsyncFunctionDef | ast.FunctionDef
353353
) -> ast.AsyncFunctionDef | ast.FunctionDef:
354-
if self.test_framework == "unittest" and not any(
355-
isinstance(d, ast.Call) and isinstance(d.func, ast.Name) and d.func.id == "timeout_decorator.timeout"
356-
for d in node.decorator_list
357-
):
358-
timeout_decorator = ast.Call(
359-
func=ast.Name(id="timeout_decorator.timeout", ctx=ast.Load()),
360-
args=[ast.Constant(value=15)],
361-
keywords=[],
362-
)
363-
node.decorator_list.append(timeout_decorator)
354+
# Optimize the search for decorator presence
355+
if self.test_framework == "unittest":
356+
found_timeout = False
357+
for d in node.decorator_list:
358+
# Avoid isinstance(d.func, ast.Name) if d is not ast.Call
359+
if isinstance(d, ast.Call):
360+
f = d.func
361+
# Avoid attribute lookup if f is not ast.Name
362+
if isinstance(f, ast.Name) and f.id == "timeout_decorator.timeout":
363+
found_timeout = True
364+
break
365+
if not found_timeout:
366+
timeout_decorator = ast.Call(
367+
func=ast.Name(id="timeout_decorator.timeout", ctx=ast.Load()),
368+
args=[ast.Constant(value=15)],
369+
keywords=[],
370+
)
371+
node.decorator_list.append(timeout_decorator)
364372

365373
# Initialize counter for this test function
366374
if node.name not in self.async_call_counter:
367375
self.async_call_counter[node.name] = 0
368376

369377
new_body = []
370378

379+
# Optimize ast.walk calls inside _instrument_statement, by scanning only relevant nodes
371380
for _i, stmt in enumerate(node.body):
372-
transformed_stmt, added_env_assignment = self._instrument_statement(stmt, node.name)
381+
transformed_stmt, added_env_assignment = self._optimized_instrument_statement(stmt)
373382

374383
if added_env_assignment:
375384
current_call_index = self.async_call_counter[node.name]
@@ -423,6 +432,26 @@ def _call_in_positions(self, call_node: ast.Call) -> bool:
423432

424433
return node_in_call_position(call_node, self.call_positions)
425434

435+
# Optimized version: only walk child nodes for Await
436+
def _optimized_instrument_statement(self, stmt: ast.stmt) -> tuple[ast.stmt, bool]:
437+
# Stack-based DFS, manual for relevant Await nodes
438+
stack = [stmt]
439+
while stack:
440+
node = stack.pop()
441+
# Favor direct ast.Await detection
442+
if isinstance(node, ast.Await):
443+
val = node.value
444+
if isinstance(val, ast.Call) and self._is_target_call(val) and self._call_in_positions(val):
445+
return stmt, True
446+
# Use _fields instead of ast.walk for less allocations
447+
for fname in getattr(node, "_fields", ()):
448+
child = getattr(node, fname, None)
449+
if isinstance(child, list):
450+
stack.extend(child)
451+
elif isinstance(child, ast.AST):
452+
stack.append(child)
453+
return stmt, False
454+
426455

427456
class FunctionImportedAsVisitor(ast.NodeVisitor):
428457
"""Checks if a function has been imported as an alias. We only care about the alias then.

0 commit comments

Comments
 (0)