Skip to content
341 changes: 152 additions & 189 deletions codeflash/code_utils/edit_generated_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
import os
import re
from pathlib import Path
from textwrap import dedent
from typing import TYPE_CHECKING

import libcst as cst
from libcst import MetadataWrapper
from libcst.metadata import PositionProvider

from codeflash.cli_cmds.console import logger
from codeflash.code_utils.time_utils import format_perf, format_time
Expand All @@ -16,230 +17,167 @@

if TYPE_CHECKING:
from codeflash.models.models import InvocationId
from codeflash.verification.verification_utils import TestConfig


def remove_functions_from_generated_tests(
generated_tests: GeneratedTestsList, test_functions_to_remove: list[str]
) -> GeneratedTestsList:
new_generated_tests = []
for generated_test in generated_tests.generated_tests:
for test_function in test_functions_to_remove:
function_pattern = re.compile(
rf"(@pytest\.mark\.parametrize\(.*?\)\s*)?def\s+{re.escape(test_function)}\(.*?\):.*?(?=\ndef\s|$)",
re.DOTALL,
)

match = function_pattern.search(generated_test.generated_original_test_source)

if match is None or "@pytest.mark.parametrize" in match.group(0):
continue

generated_test.generated_original_test_source = function_pattern.sub(
"", generated_test.generated_original_test_source
)
class CommentMapper(ast.NodeVisitor):
def __init__(
    self, test: GeneratedTests, original_runtimes: dict[str, int], optimized_runtimes: dict[str, int]
) -> None:
    """Store the test and runtime tables and reset the traversal state.

    Args:
        test: Generated test whose source will be walked; its
            ``behavior_file_path`` (minus the suffix) is used in lookup keys.
        original_runtimes: Runtime per lookup key before optimization.
        optimized_runtimes: Runtime per lookup key after optimization.

    """
    self.test: GeneratedTests = test
    # Suffix-less behavior test path; it forms the middle part of every lookup key.
    self.abs_path = test.behavior_file_path.with_suffix("")
    self.original_runtimes = original_runtimes
    self.optimized_runtimes = optimized_runtimes
    # Names of the class/function definitions currently being visited, outermost first.
    self.context_stack: list[str] = []
    # Source line number -> runtime comment to attach to that line.
    self.results: dict[int, str] = {}

new_generated_tests.append(generated_test)
def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
self.context_stack.append(node.name)
for inner_node in ast.walk(node):
if isinstance(inner_node, ast.FunctionDef):
self.visit_FunctionDef(inner_node)
self.context_stack.pop()
return node

def get_comment(self, match_key: str) -> str:
    """Build the runtime comment text for one lookup key.

    Args:
        match_key: Key that must be present in both ``original_runtimes``
            and ``optimized_runtimes``.

    Returns:
        A ``# <original> -> <optimized> (<gain>% faster|slower)`` comment.

    """
    before_ns = self.original_runtimes[match_key]
    after_ns = self.optimized_runtimes[match_key]
    # The sign is dropped here; the direction is spelled out by the status word.
    gain_percent = performance_gain(original_runtime_ns=before_ns, optimized_runtime_ns=after_ns) * 100
    perf_gain = format_perf(abs(gain_percent))
    if after_ns > before_ns:
        status = "slower"
    else:
        # Equal runtimes are reported as "faster" as well.
        status = "faster"
    return f"# {format_time(before_ns)} -> {format_time(after_ns)} ({perf_gain}% {status})"

def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
self.context_stack.append(node.name)
i = len(node.body) - 1
test_qualified_name = ".".join(self.context_stack)
key = test_qualified_name + "#" + str(self.abs_path)
while i >= 0:
line_node = node.body[i]
if isinstance(line_node, (ast.With, ast.For, ast.While, ast.If)):
j = len(line_node.body) - 1
while j >= 0:
compound_line_node: ast.stmt = line_node.body[j]
internal_node: ast.AST
for internal_node in ast.walk(compound_line_node):
if isinstance(internal_node, (ast.stmt, ast.Assign)):
inv_id = str(i) + "_" + str(j)
match_key = key + "#" + inv_id
if match_key in self.original_runtimes and match_key in self.optimized_runtimes:
self.results[internal_node.lineno] = self.get_comment(match_key)
j -= 1
else:
inv_id = str(i)
match_key = key + "#" + inv_id
if match_key in self.original_runtimes and match_key in self.optimized_runtimes:
self.results[line_node.lineno] = self.get_comment(match_key)
i -= 1
self.context_stack.pop()
return node

return GeneratedTestsList(generated_tests=new_generated_tests)

def get_fn_call_linenos(
    test: GeneratedTests, original_runtimes: dict[str, int], optimized_runtimes: dict[str, int]
) -> dict[int, str]:
    """Map line numbers of a generated test's source to runtime comments.

    Parses ``test.generated_original_test_source`` with ``ast``, runs a
    ``CommentMapper`` over the tree and returns the mapper's ``results``.

    Args:
        test: Generated test whose original source is analysed.
        original_runtimes: Runtime per lookup key before optimization.
        optimized_runtimes: Runtime per lookup key after optimization.

    Returns:
        Mapping of source line number to the comment text for that line.

    """
    mapper = CommentMapper(test, original_runtimes, optimized_runtimes)
    mapper.visit(ast.parse(test.generated_original_test_source))
    return mapper.results

class CfoVisitor(ast.NodeVisitor):
"""AST visitor that finds all assignments to a variable named 'codeflash_output'.

and reports their location relative to the function they're in.
"""
class CommentAdder(cst.CSTTransformer):
"""Transformer that adds comments to specified lines."""

def __init__(self, function_name: str, source_code: str) -> None:
self.source_lines = source_code.splitlines()
self.name = function_name
self.results: list[int] = [] # map actual line number to line number in ast
# Declare metadata dependencies
METADATA_DEPENDENCIES = (PositionProvider,)

def visit_Call(self, node): # type: ignore[no-untyped-def] # noqa: ANN201, ANN001
"""Detect fn calls."""
func_name = self._get_called_func_name(node.func) # type: ignore[no-untyped-call]
if func_name == self.name:
self.results.append(node.lineno - 1)
self.generic_visit(node)
def __init__(self, line_to_comments: dict[int, str]) -> None:
"""Initialize the transformer with target line numbers.

def _get_called_func_name(self, node): # type: ignore[no-untyped-def] # noqa: ANN001, ANN202
"""Return name of called fn."""
if isinstance(node, ast.Name):
return node.id
if isinstance(node, ast.Attribute):
return node.attr
return None
Args:
line_to_comments: Mapping of line numbers (1-indexed) to comments

"""
self.line_to_comments = line_to_comments
super().__init__()

def find_codeflash_output_assignments(function_name: str, source_code: str) -> list[int]:
tree = ast.parse(source_code)
visitor = CfoVisitor(function_name, source_code)
visitor.visit(tree)
return visitor.results
def leave_SimpleStatementLine(
self, original_node: cst.SimpleStatementLine, updated_node: cst.SimpleStatementLine
) -> cst.SimpleStatementLine:
"""Add comment to simple statement lines."""
pos = self.get_metadata(PositionProvider, original_node)

if pos and pos.start.line in self.line_to_comments:
# Create a comment with trailing whitespace
comment = cst.TrailingWhitespace(
whitespace=cst.SimpleWhitespace(" "), comment=cst.Comment(self.line_to_comments[pos.start.line])
)

class Finder(cst.CSTVisitor):
def __init__(self, name: str) -> None:
super().__init__()
self.found = False
self.name = name
# Update the trailing whitespace of the line itself
return updated_node.with_changes(trailing_whitespace=comment)

def visit_Call(self, call_node) -> None: # type: ignore[no-untyped-def] # noqa : ANN001
func_expr = call_node.func
if isinstance(func_expr, cst.Name):
if func_expr.value == self.name:
self.found = True
elif isinstance(func_expr, cst.Attribute): # noqa : SIM102
if func_expr.attr.value == self.name:
self.found = True
return updated_node

def leave_SimpleStatementSuite(
self, original_node: cst.SimpleStatementSuite, updated_node: cst.SimpleStatementSuite
) -> cst.SimpleStatementSuite:
"""Add comment to simple statement suites (e.g., after if/for/while)."""
pos = self.get_metadata(PositionProvider, original_node)

# TODO: reduce for loops to one
class RuntimeCommentTransformer(cst.CSTTransformer):
def __init__(
self,
qualified_name: str,
module: cst.Module,
test: GeneratedTests,
tests_root: Path,
original_runtimes: dict[InvocationId, list[int]],
optimized_runtimes: dict[InvocationId, list[int]],
) -> None:
super().__init__()
self.test = test
self.context_stack: list[str] = []
self.tests_root = tests_root
self.module = module
self.cfo_locs: list[int] = []
self.cfo_idx_loc_to_look_at: int = -1
self.name = qualified_name.split(".")[-1]
self.original_runtimes = original_runtimes
self.optimized_runtimes = optimized_runtimes

def visit_ClassDef(self, node: cst.ClassDef) -> None:
# Track when we enter a class
self.context_stack.append(node.name.value)
if pos and pos.start.line in self.line_to_comments:
# Create a comment with trailing whitespace
comment = cst.TrailingWhitespace(
whitespace=cst.SimpleWhitespace(" "), comment=cst.Comment(self.line_to_comments[pos.start.line])
)

def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef: # noqa: ARG002
# Pop the context when we leave a class
self.context_stack.pop()
return updated_node
# Update the trailing whitespace of the suite
return updated_node.with_changes(trailing_whitespace=comment)

def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
# convert function body to ast normalized string and find occurrences of codeflash_output
body_code = dedent(self.module.code_for_node(node.body))
normalized_body_code = ast.unparse(ast.parse(body_code))
self.cfo_locs = sorted(
find_codeflash_output_assignments(self.name, normalized_body_code)
) # sorted in order we will encounter them
self.cfo_idx_loc_to_look_at = -1
self.context_stack.append(node.name.value)

def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef: # noqa: ARG002
# Pop the context when we leave a function
self.context_stack.pop()
return updated_node

def leave_SimpleStatementLine(
self,
original_node: cst.SimpleStatementLine, # noqa: ARG002
updated_node: cst.SimpleStatementLine,
) -> cst.SimpleStatementLine:
# Check if this statement line contains a call to self.name
if self._contains_myfunc_call(updated_node): # type: ignore[no-untyped-call]
# Find matching test cases by looking for this test function name in the test results
self.cfo_idx_loc_to_look_at += 1
matching_original_times = []
matching_optimized_times = []
# TODO : will not work if there are multiple test cases with the same name, match filename + test class + test function name + invocationid
for invocation_id, runtimes in self.original_runtimes.items():
# get position here and match in if condition
qualified_name = (
invocation_id.test_class_name + "." + invocation_id.test_function_name # type: ignore[operator]
if invocation_id.test_class_name
else invocation_id.test_function_name
)
abs_path = Path(invocation_id.test_module_path.replace(".", os.sep)).with_suffix(".py").resolve()
if (
qualified_name == ".".join(self.context_stack)
and abs_path in [self.test.behavior_file_path, self.test.perf_file_path]
and int(invocation_id.iteration_id.split("_")[0]) == self.cfo_locs[self.cfo_idx_loc_to_look_at] # type:ignore[union-attr]
):
matching_original_times.extend(runtimes)

for invocation_id, runtimes in self.optimized_runtimes.items():
# get position here and match in if condition
qualified_name = (
invocation_id.test_class_name + "." + invocation_id.test_function_name # type: ignore[operator]
if invocation_id.test_class_name
else invocation_id.test_function_name
)
abs_path = Path(invocation_id.test_module_path.replace(".", os.sep)).with_suffix(".py").resolve()
if (
qualified_name == ".".join(self.context_stack)
and abs_path in [self.test.behavior_file_path, self.test.perf_file_path]
and int(invocation_id.iteration_id.split("_")[0]) == self.cfo_locs[self.cfo_idx_loc_to_look_at] # type:ignore[union-attr]
):
matching_optimized_times.extend(runtimes)

if matching_original_times and matching_optimized_times:
original_time = min(matching_original_times)
optimized_time = min(matching_optimized_times)
if original_time != 0 and optimized_time != 0:
perf_gain = format_perf(
abs(
performance_gain(original_runtime_ns=original_time, optimized_runtime_ns=optimized_time)
* 100
)
)
status = "slower" if optimized_time > original_time else "faster"
# Create the runtime comment
comment_text = (
f"# {format_time(original_time)} -> {format_time(optimized_time)} ({perf_gain}% {status})"
)
return updated_node.with_changes(
trailing_whitespace=cst.TrailingWhitespace(
whitespace=cst.SimpleWhitespace(" "),
comment=cst.Comment(comment_text),
newline=updated_node.trailing_whitespace.newline,
)
)
return updated_node

def _contains_myfunc_call(self, node): # type: ignore[no-untyped-def] # noqa : ANN202, ANN001
"""Recursively search for any Call node in the statement whose function is named self.name (including obj.myfunc)."""
finder = Finder(self.name)
node.visit(finder)
return finder.found
def unique_inv_id(inv_id_runtimes: dict[InvocationId, list[int]]) -> dict[str, int]:
    """Collapse per-invocation runtimes into one total per lookup key.

    Each key has the form ``<qualified test name>#<absolute module path
    without suffix>#<truncated iteration id>``. The value is the sum, over
    all invocations sharing that key, of the minimum runtime observed for
    each invocation.

    Invocations are skipped when their resolved module path does not contain
    ``"__unit_test_"`` (presumably the marker of generated unit-test files —
    confirm against the code that names them), or when they have no recorded
    runtimes, since no minimum exists for an empty list.

    Args:
        inv_id_runtimes: Runtimes recorded for each invocation id.

    Returns:
        Mapping of lookup key to the summed best-case runtime.

    """
    unique_inv_ids: dict[str, int] = {}
    for inv_id, runtimes in inv_id_runtimes.items():
        if not runtimes:
            # min() of an empty list would raise ValueError; nothing to add.
            continue
        test_qualified_name = (
            inv_id.test_class_name + "." + inv_id.test_function_name  # type: ignore[operator]
            if inv_id.test_class_name
            else inv_id.test_function_name
        )
        module_file = Path(inv_id.test_module_path.replace(".", os.sep)).with_suffix(".py")
        abs_path = str(module_file.resolve().with_suffix(""))
        if "__unit_test_" not in abs_path:
            continue
        # Ids with fewer than three "_"-separated parts keep only the first
        # part; longer ids drop just the last part.
        id_parts = inv_id.iteration_id.split("_")  # type: ignore[union-attr]
        cur_invid = id_parts[0] if len(id_parts) < 3 else "_".join(id_parts[:-1])
        match_key = f"{test_qualified_name}#{abs_path}#{cur_invid}"
        unique_inv_ids[match_key] = unique_inv_ids.get(match_key, 0) + min(runtimes)
    return unique_inv_ids


def add_runtime_comments_to_generated_tests(
qualified_name: str,
test_cfg: TestConfig,
generated_tests: GeneratedTestsList,
original_runtimes: dict[InvocationId, list[int]],
optimized_runtimes: dict[InvocationId, list[int]],
) -> GeneratedTestsList:
"""Add runtime performance comments to function calls in generated tests."""
tests_root = test_cfg.tests_root

original_runtimes_dict = unique_inv_id(original_runtimes)
optimized_runtimes_dict = unique_inv_id(optimized_runtimes)
# Process each generated test
modified_tests = []
for test in generated_tests.generated_tests:
try:
# Parse the test source code
tree = cst.parse_module(test.generated_original_test_source)
# Transform the tree to add runtime comments
# qualified_name: str, module: cst.Module, test: GeneratedTests, tests_root: Path
transformer = RuntimeCommentTransformer(
qualified_name, tree, test, tests_root, original_runtimes, optimized_runtimes
)
modified_tree = tree.visit(transformer)

# Convert back to source code
wrapper = MetadataWrapper(tree)
line_to_comments = get_fn_call_linenos(test, original_runtimes_dict, optimized_runtimes_dict)
comment_adder = CommentAdder(line_to_comments)
modified_tree = wrapper.visit(comment_adder)
modified_source = modified_tree.code

# Create a new GeneratedTests object with the modified source
modified_test = GeneratedTests(
generated_original_test_source=modified_source,
instrumented_behavior_test_source=test.instrumented_behavior_test_source,
Expand All @@ -254,3 +192,28 @@ def add_runtime_comments_to_generated_tests(
modified_tests.append(test)

return GeneratedTestsList(generated_tests=modified_tests)


def remove_functions_from_generated_tests(
    generated_tests: GeneratedTestsList, test_functions_to_remove: list[str]
) -> GeneratedTestsList:
    """Strip the named test functions from each generated test's source.

    A function is matched from its ``def`` line (plus an immediately
    preceding ``@pytest.mark.parametrize(...)`` decorator, if any) up to the
    next ``def`` that starts a line at column 0, or the end of the source.
    If the first match includes a parametrize decorator the function is left
    untouched; otherwise every match of the pattern is removed.

    Note that ``generated_original_test_source`` is rewritten in place on the
    objects held by ``generated_tests``.

    Args:
        generated_tests: Container whose tests are edited.
        test_functions_to_remove: Names of the test functions to delete.

    Returns:
        A new ``GeneratedTestsList`` holding the same (mutated) test objects.

    """
    # Compile one pattern per function name up front; the pattern does not
    # depend on the test being processed.
    removal_patterns = [
        re.compile(
            rf"(@pytest\.mark\.parametrize\(.*?\)\s*)?def\s+{re.escape(test_function)}\(.*?\):.*?(?=\ndef\s|$)",
            re.DOTALL,
        )
        for test_function in test_functions_to_remove
    ]

    new_generated_tests = []
    for generated_test in generated_tests.generated_tests:
        for function_pattern in removal_patterns:
            match = function_pattern.search(generated_test.generated_original_test_source)

            if match is None or "@pytest.mark.parametrize" in match.group(0):
                continue

            generated_test.generated_original_test_source = function_pattern.sub(
                "", generated_test.generated_original_test_source
            )

        new_generated_tests.append(generated_test)

    return GeneratedTestsList(generated_tests=new_generated_tests)
Loading
Loading