Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
313 changes: 249 additions & 64 deletions codeflash/code_utils/instrument_existing_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import ast
import platform
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING

Expand All @@ -20,24 +21,39 @@
from codeflash.models.models import CodePosition


@dataclass(frozen=True)
class FunctionCallNodeArguments:
    """Immutable pairing of the two argument lists carried by an ``ast.Call``.

    The instance is frozen, so its attributes cannot be rebound; the lists it
    references are stored as given (not copied).
    """

    # Positional argument expressions of the call.
    args: list[ast.expr]
    # Keyword argument nodes of the call.
    keywords: list[ast.keyword]


def get_call_arguments(call_node: ast.Call) -> FunctionCallNodeArguments:
    """Bundle the positional and keyword arguments of ``call_node``.

    The returned object references the call node's own ``args`` and
    ``keywords`` lists; nothing is copied.
    """
    return FunctionCallNodeArguments(args=call_node.args, keywords=call_node.keywords)


def node_in_call_position(node: ast.AST, call_positions: list[CodePosition]) -> bool:
    """Return True when ``node`` is a call that covers one of ``call_positions``.

    Only ``ast.Call`` nodes that carry ``lineno``, ``col_offset`` and
    ``end_lineno`` can match. A position matches when its ``line_no`` lies
    within the call's line span and one of the following holds:

    * it is on the call's first line at or after the call's start column;
    * it is on the call's last line at or before the call's end column;
    * it is strictly between the first and last lines.

    Positions whose ``line_no`` is None are ignored.
    """
    if not isinstance(node, ast.Call):
        return False

    # Read the location attributes once; nodes built by hand may lack them.
    start_line = getattr(node, "lineno", None)
    start_col = getattr(node, "col_offset", None)
    end_line = getattr(node, "end_lineno", None)
    end_col = getattr(node, "end_col_offset", None)
    if start_line is None or start_col is None or end_line is None:
        return False

    for pos in call_positions:
        pos_line = pos.line_no
        if pos_line is None or not (start_line <= pos_line <= end_line):
            continue
        if pos_line == start_line and start_col <= pos.col_no:
            return True
        if pos_line == end_line and end_col is not None and end_col >= pos.col_no:
            return True
        if start_line < pos_line < end_line:
            return True
    return False


Expand Down Expand Up @@ -73,66 +89,231 @@ def __init__(
def find_and_update_line_node(
    self, test_node: ast.stmt, node_name: str, index: str, test_class_name: str | None = None
) -> Iterable[ast.stmt] | None:
    """Rewrite the tracked call inside ``test_node`` to go through ``codeflash_wrap``.

    The calls inside ``test_node`` are scanned for one that
    ``node_in_call_position`` accepts for ``self.call_positions``. The first
    such call whose callee is a bare name, or an attribute named
    ``self.function_object.function_name``, is rewritten in place.

    Returns:
        * ``None`` when no call in ``test_node`` is at a tracked position.
        * ``[test_node]`` when a call matched but nothing extra has to be
          emitted: the target is async (left untouched), the callee did not
          qualify, or the mode is not ``TestingMode.BEHAVIOR``.
        * ``[bind_stmt, apply_defaults_stmt, test_node]`` in BEHAVIOR mode,
          where the two leading statements bind the original arguments with
          ``inspect.signature(...).bind(...)`` into ``_call__bound__arguments``
          and apply defaults, so the wrapped call receives normalized args.
    """
    fn_obj = self.function_object
    is_behavior = self.mode == TestingMode.BEHAVIOR
    bound_name = "_call__bound__arguments"

    def iter_calls(root):
        # Depth-first walk with an explicit stack that yields every ast.Call.
        # The push order (fields in order, list items reversed) decides which
        # matching call is seen first, so keep it stable.
        pending = [root]
        while pending:
            current = pending.pop()
            if isinstance(current, ast.Call):
                yield current
            for _field, value in ast.iter_fields(current):
                if isinstance(value, list):
                    pending.extend(item for item in reversed(value) if isinstance(item, ast.AST))
                elif isinstance(value, ast.AST):
                    pending.append(value)

    def bound_attribute(attr: str) -> ast.Attribute:
        # Expression for ``_call__bound__arguments.<attr>``.
        return ast.Attribute(value=ast.Name(id=bound_name, ctx=ast.Load()), attr=attr, ctx=ast.Load())

    def wrap_call(call: ast.Call, target: ast.expr) -> list[ast.stmt]:
        # Turn ``call`` into ``codeflash_wrap(target, <metadata>, <arguments>)``
        # and return the statements that must replace ``test_node``.
        original = get_call_arguments(call)  # captured before call.args is replaced
        wrapper_args: list[ast.expr] = [
            target,
            ast.Constant(value=self.module_path),
            ast.Constant(value=test_class_name or None),
            ast.Constant(value=node_name),
            ast.Constant(value=fn_obj.qualified_name),
            ast.Constant(value=index),
            ast.Name(id="codeflash_loop_index", ctx=ast.Load()),
        ]
        call.func = ast.Name(id="codeflash_wrap", ctx=ast.Load())

        if not is_behavior:
            # No binding statements are emitted outside BEHAVIOR mode, so the
            # original arguments are forwarded as written. Branching on
            # BEHAVIOR everywhere keeps the argument form and the emitted
            # statements in agreement for every mode.
            call.args = [*wrapper_args, *original.args]
            call.keywords = original.keywords
            return [test_node]

        # _call__bound__arguments = inspect.signature(<target>).bind(<original arguments>)
        bind_call = ast.Assign(
            targets=[ast.Name(id=bound_name, ctx=ast.Store())],
            value=ast.Call(
                func=ast.Attribute(
                    value=ast.Call(
                        func=ast.Attribute(
                            value=ast.Name(id="inspect", ctx=ast.Load()), attr="signature", ctx=ast.Load()
                        ),
                        args=[target],
                        keywords=[],
                    ),
                    attr="bind",
                    ctx=ast.Load(),
                ),
                args=original.args,
                keywords=original.keywords,
            ),
            lineno=test_node.lineno,
            col_offset=test_node.col_offset,
        )
        # _call__bound__arguments.apply_defaults()
        apply_defaults = ast.Expr(
            value=ast.Call(func=bound_attribute("apply_defaults"), args=[], keywords=[]),
            lineno=test_node.lineno + 1,
            col_offset=test_node.col_offset,
        )

        call.args = [
            *wrapper_args,
            ast.Name(id="codeflash_cur", ctx=ast.Load()),
            ast.Name(id="codeflash_con", ctx=ast.Load()),
            ast.Starred(value=bound_attribute("args"), ctx=ast.Load()),
        ]
        # A keyword with arg=None is the ``**mapping`` form.
        call.keywords = [ast.keyword(arg=None, value=bound_attribute("kwargs"))]
        return [bind_call, apply_defaults, test_node]

    matched_call = False
    for call in iter_calls(test_node):
        if not node_in_call_position(call, self.call_positions):
            continue
        matched_call = True
        callee = call.func

        if isinstance(callee, ast.Name):
            if fn_obj.is_async:
                return [test_node]
            return wrap_call(call, ast.Name(id=callee.id, ctx=ast.Load()))

        if isinstance(callee, ast.Attribute) and callee.attr == fn_obj.function_name:
            if fn_obj.is_async:
                return [test_node]
            # Round-trip through source to get a fresh copy of the callee expression.
            return wrap_call(call, ast.parse(ast.unparse(callee), mode="eval").body)

    return [test_node] if matched_call else None

def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
# TODO: Ensure that this class inherits from unittest.TestCase. Don't modify non unittest.TestCase classes.
Expand Down Expand Up @@ -593,7 +774,11 @@ def inject_profiling_into_existing_test(
]
if mode == TestingMode.BEHAVIOR:
new_imports.extend(
[ast.Import(names=[ast.alias(name="sqlite3")]), ast.Import(names=[ast.alias(name="dill", asname="pickle")])]
[
ast.Import(names=[ast.alias(name="inspect")]),
ast.Import(names=[ast.alias(name="sqlite3")]),
ast.Import(names=[ast.alias(name="dill", asname="pickle")]),
]
)
if test_framework == "unittest" and platform.system() != "Windows":
new_imports.append(ast.Import(names=[ast.alias(name="timeout_decorator")]))
Expand Down
Loading
Loading