Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
183 changes: 84 additions & 99 deletions codeflash/code_utils/edit_generated_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@

from codeflash.cli_cmds.console import logger
from codeflash.code_utils.time_utils import format_perf, format_time
from codeflash.models.models import GeneratedTests, GeneratedTestsList
from codeflash.models.models import (GeneratedTests, GeneratedTestsList,
InvocationId)
from codeflash.result.critic import performance_gain
from codeflash.verification.verification_utils import TestConfig

if TYPE_CHECKING:
from codeflash.models.models import InvocationId
Expand Down Expand Up @@ -90,7 +92,35 @@ def add_runtime_comments_to_generated_tests(
module_root = test_cfg.project_root_path
rel_tests_root = tests_root.relative_to(module_root)

# TODO: reduce for loops to one
# ---- Preindex invocation results for O(1) matching -------
# (rel_path, qualified_name, cfo_loc) -> list[runtimes]
def _make_index(invocations):
    """Group runtimes by ``(relative test path, qualified test name, call position)``.

    Building this index once lets later lookups be a single dict access
    instead of a scan over every invocation.

    Args:
        invocations: Mapping of invocation id -> iterable of runtimes. Each id
            is read for ``test_class_name``, ``test_function_name``,
            ``test_module_path`` and ``iteration_id``.

    Returns:
        A dict mapping ``(str(rel_path), qualified_name, cfo_loc)`` to a list
        with every runtime recorded under that key. ``cfo_loc`` is ``None``
        when the first ``_``-separated field of ``iteration_id`` cannot be
        read as an integer.
    """
    index = {}
    for invocation_id, runtimes in invocations.items():
        test_class = invocation_id.test_class_name
        test_func = invocation_id.test_function_name
        q_name = f"{test_class}.{test_func}" if test_class else test_func

        # Dotted module path -> file path, e.g. "pkg.test_mod" -> "pkg/test_mod.py".
        rel_path = Path(invocation_id.test_module_path.replace(".", os.sep)).with_suffix(".py")
        try:
            rel_path = rel_path.relative_to(rel_tests_root)
        except ValueError:
            # The module does not live under the tests root: keep the
            # un-relativized path so the entry is still indexed.
            pass

        try:
            cfo_loc = int(invocation_id.iteration_id.split("_")[0])
        except (AttributeError, ValueError):
            # iteration_id is not a string, or has a non-numeric first field.
            cfo_loc = None

        index.setdefault((str(rel_path), q_name, cfo_loc), []).extend(runtimes)
    return index

orig_index = _make_index(original_runtimes)
opt_index = _make_index(optimized_runtimes)

# Optimized fast CST visitor base
class RuntimeCommentTransformer(cst.CSTTransformer):
def __init__(
self, qualified_name: str, module: cst.Module, test: GeneratedTests, tests_root: Path, rel_tests_root: Path
Expand All @@ -104,104 +134,66 @@ def __init__(
self.cfo_locs: list[int] = []
self.cfo_idx_loc_to_look_at: int = -1
self.name = qualified_name.split(".")[-1]
# Precompute test-local file relative paths for efficiency
self.test_rel_behavior = str(test.behavior_file_path.relative_to(tests_root))
self.test_rel_perf = str(test.perf_file_path.relative_to(tests_root))

def visit_ClassDef(self, node: cst.ClassDef) -> None:
    """Push the name of the class being entered onto ``context_stack``."""
    entered_class = node.name.value
    self.context_stack.append(entered_class)

def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef: # noqa: ARG002
# Pop the context when we leave a class
def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef:
self.context_stack.pop()
return updated_node

def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
# convert function body to ast normalized string and find occurrences of codeflash_output
# This could be optimized further if you access CFO assignments via CST
body_code = dedent(self.module.code_for_node(node.body))
normalized_body_code = ast.unparse(ast.parse(body_code))
self.cfo_locs = sorted(
find_codeflash_output_assignments(qualified_name, normalized_body_code)
) # sorted in order we will encounter them
self.cfo_locs = sorted(find_codeflash_output_assignments(qualified_name, normalized_body_code))
self.cfo_idx_loc_to_look_at = -1
self.context_stack.append(node.name.value)

def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef: # noqa: ARG002
# Pop the context when we leave a function
def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:
self.context_stack.pop()
return updated_node

def leave_SimpleStatementLine(
self,
original_node: cst.SimpleStatementLine, # noqa: ARG002
updated_node: cst.SimpleStatementLine,
self, original_node: cst.SimpleStatementLine, updated_node: cst.SimpleStatementLine
) -> cst.SimpleStatementLine:
# Check if this statement line contains a call to self.name
if self._contains_myfunc_call(updated_node): # type: ignore[no-untyped-call]
# Find matching test cases by looking for this test function name in the test results
# Fast skip before deep call tree walk by screening for Name nodes
if self._contains_myfunc_call(updated_node):
self.cfo_idx_loc_to_look_at += 1
matching_original_times = []
matching_optimized_times = []
# TODO : will not work if there are multiple test cases with the same name, match filename + test class + test function name + invocationid
for invocation_id, runtimes in original_runtimes.items():
# get position here and match in if condition
qualified_name = (
invocation_id.test_class_name + "." + invocation_id.test_function_name # type: ignore[operator]
if invocation_id.test_class_name
else invocation_id.test_function_name
)
rel_path = (
Path(invocation_id.test_module_path.replace(".", os.sep))
.with_suffix(".py")
.relative_to(self.rel_tests_root)
)
if (
qualified_name == ".".join(self.context_stack)
and rel_path
in [
self.test.behavior_file_path.relative_to(self.tests_root),
self.test.perf_file_path.relative_to(self.tests_root),
]
and int(invocation_id.iteration_id.split("_")[0]) == self.cfo_locs[self.cfo_idx_loc_to_look_at] # type:ignore[union-attr]
):
matching_original_times.extend(runtimes)

for invocation_id, runtimes in optimized_runtimes.items():
# get position here and match in if condition
qualified_name = (
invocation_id.test_class_name + "." + invocation_id.test_function_name # type: ignore[operator]
if invocation_id.test_class_name
else invocation_id.test_function_name
)
rel_path = (
Path(invocation_id.test_module_path.replace(".", os.sep))
.with_suffix(".py")
.relative_to(self.rel_tests_root)
)
if (
qualified_name == ".".join(self.context_stack)
and rel_path
in [
self.test.behavior_file_path.relative_to(self.tests_root),
self.test.perf_file_path.relative_to(self.tests_root),
]
and int(invocation_id.iteration_id.split("_")[0]) == self.cfo_locs[self.cfo_idx_loc_to_look_at] # type:ignore[union-attr]
):
matching_optimized_times.extend(runtimes)

if matching_original_times and matching_optimized_times:
original_time = min(matching_original_times)
optimized_time = min(matching_optimized_times)
if self.cfo_idx_loc_to_look_at >= len(self.cfo_locs):
return updated_node # Defensive, should never happen

cfo_loc = self.cfo_locs[self.cfo_idx_loc_to_look_at]

qualified_name_chain = ".".join(self.context_stack)
# Try both behavior and perf as possible locations; both are strings
possible_paths = {self.test_rel_behavior, self.test_rel_perf}

# Form index key(s)
matching_original = []
matching_optimized = []

for rel_path_str in possible_paths:
key = (rel_path_str, qualified_name_chain, cfo_loc)
if key in orig_index:
matching_original.extend(orig_index[key])
if key in opt_index:
matching_optimized.extend(opt_index[key])
if matching_original and matching_optimized:
original_time = min(matching_original)
optimized_time = min(matching_optimized)
if original_time != 0 and optimized_time != 0:
perf_gain = format_perf(
perf_gain_str = format_perf(
abs(
performance_gain(original_runtime_ns=original_time, optimized_runtime_ns=optimized_time)
* 100
)
)
status = "slower" if optimized_time > original_time else "faster"
# Create the runtime comment
comment_text = (
f"# {format_time(original_time)} -> {format_time(optimized_time)} ({perf_gain}% {status})"
)
comment_text = f"# {format_time(original_time)} -> {format_time(optimized_time)} ({perf_gain_str}% {status})"
return updated_node.with_changes(
trailing_whitespace=cst.TrailingWhitespace(
whitespace=cst.SimpleWhitespace(" "),
Expand All @@ -211,43 +203,37 @@ def leave_SimpleStatementLine(
)
return updated_node

def _contains_myfunc_call(self, node): # type: ignore[no-untyped-def] # noqa : ANN202, ANN001
def _contains_myfunc_call(self, node):
"""Recursively search for any Call node in the statement whose function is named self.name (including obj.myfunc)."""

# IMPORTANT micro-optimization: early abort using an exception
class Found(Exception):
pass

class Finder(cst.CSTVisitor):
def __init__(self, name: str) -> None:
super().__init__()
self.found = False
def __init__(self, name):
self.name = name

def visit_Call(self, call_node) -> None: # type: ignore[no-untyped-def] # noqa : ANN001
def visit_Call(self, call_node):
func_expr = call_node.func
if isinstance(func_expr, cst.Name):
if func_expr.value == self.name:
self.found = True
elif isinstance(func_expr, cst.Attribute): # noqa : SIM102
if func_expr.attr.value == self.name:
self.found = True

finder = Finder(self.name)
node.visit(finder)
return finder.found

# Process each generated test
if (isinstance(func_expr, cst.Name) and func_expr.value == self.name) or (
isinstance(func_expr, cst.Attribute) and func_expr.attr.value == self.name
):
raise Found

try:
node.visit(Finder(self.name))
except Found:
return True
return False

modified_tests = []
for test in generated_tests.generated_tests:
try:
# Parse the test source code
tree = cst.parse_module(test.generated_original_test_source)
# Transform the tree to add runtime comments
# qualified_name: str, module: cst.Module, test: GeneratedTests, tests_root: Path, rel_tests_root: Path
transformer = RuntimeCommentTransformer(qualified_name, tree, test, tests_root, rel_tests_root)
modified_tree = tree.visit(transformer)

# Convert back to source code
modified_source = modified_tree.code

# Create a new GeneratedTests object with the modified source
modified_test = GeneratedTests(
generated_original_test_source=modified_source,
instrumented_behavior_test_source=test.instrumented_behavior_test_source,
Expand All @@ -257,7 +243,6 @@ def visit_Call(self, call_node) -> None: # type: ignore[no-untyped-def] # noqa
)
modified_tests.append(modified_test)
except Exception as e:
# If parsing fails, keep the original test
logger.debug(f"Failed to add runtime comments to test: {e}")
modified_tests.append(test)

Expand Down
Loading