Skip to content
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 55 additions & 98 deletions codeflash/code_utils/edit_generated_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import re
from pathlib import Path
from textwrap import dedent
from typing import TYPE_CHECKING, Union
from typing import TYPE_CHECKING

import libcst as cst

Expand Down Expand Up @@ -50,89 +50,36 @@ class CfoVisitor(ast.NodeVisitor):
and reports their location relative to the function they're in.
"""

def __init__(self, source_code: str) -> None:
def __init__(self, qualifed_name: str, source_code: str) -> None:
self.source_lines = source_code.splitlines()
self.name = qualifed_name.split(".")[-1]
self.results: list[int] = [] # map actual line number to line number in ast

def _is_codeflash_output_target(self, target: Union[ast.expr, list]) -> bool: # type: ignore[type-arg]
"""Check if the assignment target is the variable 'codeflash_output'."""
if isinstance(target, ast.Name):
return target.id == "codeflash_output"
if isinstance(target, (ast.Tuple, ast.List)):
# Handle tuple/list unpacking: a, codeflash_output, b = values
return any(self._is_codeflash_output_target(elt) for elt in target.elts)
if isinstance(target, (ast.Subscript, ast.Attribute)):
# Not a simple variable assignment
return False
return False

def _record_assignment(self, node: ast.AST) -> None:
"""Record an assignment to codeflash_output."""
relative_line = node.lineno - 1 # type: ignore[attr-defined]
self.results.append(relative_line)

def visit_Assign(self, node: ast.Assign) -> None:
"""Visit assignment statements: codeflash_output = value."""
for target in node.targets:
if self._is_codeflash_output_target(target):
self._record_assignment(node)
break
def visit_Call(self, node: ast.Call) -> None:
    """Record the line of each call to the target function.

    A call matches when ``_get_called_func_name`` returns ``self.name`` for
    its callee, which covers both ``name(...)`` and ``obj.name(...)``.
    The zero-based line index (``node.lineno - 1``) is stored in
    ``self.results`` at most once, so several matching calls on one line
    (``assert f(1) == f(2)``, ``f(f(x))``) produce a single entry.

    Args:
        node: The call expression being visited.
    """
    if self._get_called_func_name(node.func) == self.name:
        line_index = node.lineno - 1
        # NOTE(review): RuntimeCommentTransformer advances its position in the
        # sorted results once per statement line that contains a call, so a
        # duplicate entry for the same line would shift every later lookup.
        # Confirm against the indexing code in leave_SimpleStatementLine.
        if line_index not in self.results:
            self.results.append(line_index)
    # Keep descending: the target may also be called inside the arguments.
    self.generic_visit(node)

def visit_AnnAssign(self, node: ast.AnnAssign) -> None:
"""Visit annotated assignments: codeflash_output: int = value."""
if self._is_codeflash_output_target(node.target):
self._record_assignment(node)
self.generic_visit(node)

def visit_AugAssign(self, node: ast.AugAssign) -> None:
"""Visit augmented assignments: codeflash_output += value."""
if self._is_codeflash_output_target(node.target):
self._record_assignment(node)
self.generic_visit(node)

def visit_NamedExpr(self, node: ast.NamedExpr) -> None:
"""Visit walrus operator: (codeflash_output := value)."""
if isinstance(node.target, ast.Name) and node.target.id == "codeflash_output":
self._record_assignment(node)
self.generic_visit(node)
def _get_called_func_name(self, node): # noqa: ANN001, ANN202
"""Return name of called fn."""
if isinstance(node, ast.Name):
return node.id
if isinstance(node, ast.Attribute):
return node.attr
return None

def visit_For(self, node: ast.For) -> None:
"""Visit for loops: for codeflash_output in iterable."""
if self._is_codeflash_output_target(node.target):
self._record_assignment(node)
self.generic_visit(node)

def visit_comprehension(self, node: ast.comprehension) -> None:
"""Visit comprehensions: [x for codeflash_output in iterable]."""
if self._is_codeflash_output_target(node.target):
# Comprehensions don't have line numbers, so we skip recording
pass
self.generic_visit(node)

def visit_With(self, node: ast.With) -> None:
"""Visit with statements: with expr as codeflash_output."""
for item in node.items:
if item.optional_vars and self._is_codeflash_output_target(item.optional_vars):
self._record_assignment(node)
break
self.generic_visit(node)

def visit_ExceptHandler(self, node: ast.ExceptHandler) -> None:
"""Visit except handlers: except Exception as codeflash_output."""
if node.name == "codeflash_output":
self._record_assignment(node)
self.generic_visit(node)


def find_codeflash_output_assignments(source_code: str) -> list[int]:
def find_codeflash_output_assignments(qualifed_name: str, source_code: str) -> list[int]:
tree = ast.parse(source_code)
visitor = CfoVisitor(source_code)
visitor = CfoVisitor(qualifed_name, source_code)
visitor.visit(tree)
return visitor.results


def add_runtime_comments_to_generated_tests(
qualifed_name: str,
test_cfg: TestConfig,
generated_tests: GeneratedTestsList,
original_runtimes: dict[InvocationId, list[int]],
Expand All @@ -145,7 +92,9 @@ def add_runtime_comments_to_generated_tests(

# TODO: reduce for loops to one
class RuntimeCommentTransformer(cst.CSTTransformer):
def __init__(self, module: cst.Module, test: GeneratedTests, tests_root: Path, rel_tests_root: Path) -> None:
def __init__(
self, qualified_name: str, module: cst.Module, test: GeneratedTests, tests_root: Path, rel_tests_root: Path
) -> None:
super().__init__()
self.test = test
self.context_stack: list[str] = []
Expand All @@ -154,6 +103,7 @@ def __init__(self, module: cst.Module, test: GeneratedTests, tests_root: Path, r
self.module = module
self.cfo_locs: list[int] = []
self.cfo_idx_loc_to_look_at: int = -1
self.name = qualified_name.split(".")[-1]

def visit_ClassDef(self, node: cst.ClassDef) -> None:
# Track when we enter a class
Expand All @@ -169,7 +119,7 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
body_code = dedent(self.module.code_for_node(node.body))
normalized_body_code = ast.unparse(ast.parse(body_code))
self.cfo_locs = sorted(
find_codeflash_output_assignments(normalized_body_code)
find_codeflash_output_assignments(qualifed_name, normalized_body_code)
) # sorted in order we will encounter them
self.cfo_idx_loc_to_look_at = -1
self.context_stack.append(node.name.value)
Expand All @@ -180,23 +130,10 @@ def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.Fu
return updated_node

def leave_SimpleStatementLine(
self,
original_node: cst.SimpleStatementLine, # noqa: ARG002
updated_node: cst.SimpleStatementLine,
self, original_node: cst.SimpleStatementLine, updated_node: cst.SimpleStatementLine
) -> cst.SimpleStatementLine:
# Look for assignment statements that assign to codeflash_output
# Handle both single statements and multiple statements on one line
codeflash_assignment_found = False
for stmt in updated_node.body:
if isinstance(stmt, cst.Assign) and (
len(stmt.targets) == 1
and isinstance(stmt.targets[0].target, cst.Name)
and stmt.targets[0].target.value == "codeflash_output"
):
codeflash_assignment_found = True
break

if codeflash_assignment_found:
# Check if this statement line contains a call to self.name
if self._contains_myfunc_call(updated_node):
# Find matching test cases by looking for this test function name in the test results
self.cfo_idx_loc_to_look_at += 1
matching_original_times = []
Expand Down Expand Up @@ -263,17 +200,36 @@ def leave_SimpleStatementLine(
comment_text = (
f"# {format_time(original_time)} -> {format_time(optimized_time)} ({perf_gain}% {status})"
)

# Add comment to the trailing whitespace
new_trailing_whitespace = cst.TrailingWhitespace(
whitespace=cst.SimpleWhitespace(" "),
comment=cst.Comment(comment_text),
newline=updated_node.trailing_whitespace.newline,
return updated_node.with_changes(
trailing_whitespace=cst.TrailingWhitespace(
whitespace=cst.SimpleWhitespace(" "),
comment=cst.Comment(comment_text),
newline=updated_node.trailing_whitespace.newline,
)
)
return updated_node

return updated_node.with_changes(trailing_whitespace=new_trailing_whitespace)
def _contains_myfunc_call(self, node):
"""Recursively search for any Call node in the statement whose function is named self.name (including obj.myfunc)."""

return updated_node
class Finder(cst.CSTVisitor):
def __init__(self, name: str):
super().__init__()
self.found = False
self.name = name

def visit_Call(self, call_node):
func_expr = call_node.func
if isinstance(func_expr, cst.Name):
if func_expr.value == self.name:
self.found = True
elif isinstance(func_expr, cst.Attribute):
if func_expr.attr.value == self.name:
self.found = True

finder = Finder(self.name)
node.visit(finder)
return finder.found

# Process each generated test
modified_tests = []
Expand All @@ -282,7 +238,8 @@ def leave_SimpleStatementLine(
# Parse the test source code
tree = cst.parse_module(test.generated_original_test_source)
# Transform the tree to add runtime comments
transformer = RuntimeCommentTransformer(tree, test, tests_root, rel_tests_root)
# qualified_name: str, module: cst.Module, test: GeneratedTests, tests_root: Path, rel_tests_root: Path
transformer = RuntimeCommentTransformer(qualifed_name, tree, test, tests_root, rel_tests_root)
modified_tree = tree.visit(transformer)

# Convert back to source code
Expand Down
9 changes: 7 additions & 2 deletions codeflash/optimization/function_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1012,15 +1012,20 @@ def find_and_process_best_optimization(
optimized_runtime_by_test = (
best_optimization.winning_benchmarking_test_results.usable_runtime_data_by_test_case()
)
qualifed_name = self.function_to_optimize.qualified_name_with_modules_from_root(self.project_root)
# Add runtime comments to generated tests before creating the PR
generated_tests = add_runtime_comments_to_generated_tests(
self.test_cfg, generated_tests, original_runtime_by_test, optimized_runtime_by_test
qualifed_name,
self.test_cfg,
generated_tests,
original_runtime_by_test,
optimized_runtime_by_test,
)
generated_tests_str = "\n\n".join(
[test.generated_original_test_source for test in generated_tests.generated_tests]
)
existing_tests = existing_tests_source_for(
self.function_to_optimize.qualified_name_with_modules_from_root(self.project_root),
qualifed_name,
function_to_all_tests,
test_cfg=self.test_cfg,
original_runtimes_all=original_runtime_by_test,
Expand Down
Loading
Loading