diff --git a/codeflash/code_utils/edit_generated_tests.py b/codeflash/code_utils/edit_generated_tests.py
index 547dbc92b..35150e0da 100644
--- a/codeflash/code_utils/edit_generated_tests.py
+++ b/codeflash/code_utils/edit_generated_tests.py
@@ -1,14 +1,22 @@
+from __future__ import annotations
+
+import ast
 import os
 import re
 from pathlib import Path
+from textwrap import dedent
+from typing import TYPE_CHECKING, Union

 import libcst as cst

 from codeflash.cli_cmds.console import logger
 from codeflash.code_utils.time_utils import format_perf, format_time
-from codeflash.models.models import GeneratedTests, GeneratedTestsList, InvocationId
+from codeflash.models.models import GeneratedTests, GeneratedTestsList
 from codeflash.result.critic import performance_gain
-from codeflash.verification.verification_utils import TestConfig
+
+if TYPE_CHECKING:
+    from codeflash.models.models import InvocationId
+    from codeflash.verification.verification_utils import TestConfig


 def remove_functions_from_generated_tests(
@@ -36,6 +44,94 @@ def remove_functions_from_generated_tests(
     return GeneratedTestsList(generated_tests=new_generated_tests)


+class CfoVisitor(ast.NodeVisitor):
+    """AST visitor that finds all assignments to a variable named 'codeflash_output'.
+
+    and reports their location relative to the function they're in.
+    """
+
+    def __init__(self, source_code: str) -> None:
+        self.source_lines = source_code.splitlines()
+        self.results: list[int] = []  # map actual line number to line number in ast
+
+    def _is_codeflash_output_target(self, target: Union[ast.expr, list]) -> bool:  # type: ignore[type-arg]
+        """Check if the assignment target is the variable 'codeflash_output'."""
+        if isinstance(target, ast.Name):
+            return target.id == "codeflash_output"
+        if isinstance(target, (ast.Tuple, ast.List)):
+            # Handle tuple/list unpacking: a, codeflash_output, b = values
+            return any(self._is_codeflash_output_target(elt) for elt in target.elts)
+        if isinstance(target, (ast.Subscript, ast.Attribute)):
+            # Not a simple variable assignment
+            return False
+        return False
+
+    def _record_assignment(self, node: ast.AST) -> None:
+        """Record an assignment to codeflash_output."""
+        relative_line = node.lineno - 1  # type: ignore[attr-defined]
+        self.results.append(relative_line)
+
+    def visit_Assign(self, node: ast.Assign) -> None:
+        """Visit assignment statements: codeflash_output = value."""
+        for target in node.targets:
+            if self._is_codeflash_output_target(target):
+                self._record_assignment(node)
+                break
+        self.generic_visit(node)
+
+    def visit_AnnAssign(self, node: ast.AnnAssign) -> None:
+        """Visit annotated assignments: codeflash_output: int = value."""
+        if self._is_codeflash_output_target(node.target):
+            self._record_assignment(node)
+        self.generic_visit(node)
+
+    def visit_AugAssign(self, node: ast.AugAssign) -> None:
+        """Visit augmented assignments: codeflash_output += value."""
+        if self._is_codeflash_output_target(node.target):
+            self._record_assignment(node)
+        self.generic_visit(node)
+
+    def visit_NamedExpr(self, node: ast.NamedExpr) -> None:
+        """Visit walrus operator: (codeflash_output := value)."""
+        if isinstance(node.target, ast.Name) and node.target.id == "codeflash_output":
+            self._record_assignment(node)
+        self.generic_visit(node)
+
+    def visit_For(self, node: ast.For) -> None:
+        """Visit for loops: for codeflash_output in iterable."""
+        if self._is_codeflash_output_target(node.target):
+            self._record_assignment(node)
+        self.generic_visit(node)
+
+    def visit_comprehension(self, node: ast.comprehension) -> None:
+        """Visit comprehensions: [x for codeflash_output in iterable]."""
+        if self._is_codeflash_output_target(node.target):
+            # Comprehensions don't have line numbers, so we skip recording
+            pass
+        self.generic_visit(node)
+
+    def visit_With(self, node: ast.With) -> None:
+        """Visit with statements: with expr as codeflash_output."""
+        for item in node.items:
+            if item.optional_vars and self._is_codeflash_output_target(item.optional_vars):
+                self._record_assignment(node)
+                break
+        self.generic_visit(node)
+
+    def visit_ExceptHandler(self, node: ast.ExceptHandler) -> None:
+        """Visit except handlers: except Exception as codeflash_output."""
+        if node.name == "codeflash_output":
+            self._record_assignment(node)
+        self.generic_visit(node)
+
+
+def find_codeflash_output_assignments(source_code: str) -> list[int]:
+    tree = ast.parse(source_code)
+    visitor = CfoVisitor(source_code)
+    visitor.visit(tree)
+    return visitor.results
+
+
 def add_runtime_comments_to_generated_tests(
     test_cfg: TestConfig,
     generated_tests: GeneratedTestsList,
@@ -49,11 +145,15 @@ def add_runtime_comments_to_generated_tests(

     # TODO: reduce for loops to one
     class RuntimeCommentTransformer(cst.CSTTransformer):
-        def __init__(self, test: GeneratedTests, tests_root: Path, rel_tests_root: Path) -> None:
+        def __init__(self, module: cst.Module, test: GeneratedTests, tests_root: Path, rel_tests_root: Path) -> None:
+            super().__init__()
             self.test = test
             self.context_stack: list[str] = []
             self.tests_root = tests_root
             self.rel_tests_root = rel_tests_root
+            self.module = module
+            self.cfo_locs: list[int] = []
+            self.cfo_idx_loc_to_look_at: int = -1

         def visit_ClassDef(self, node: cst.ClassDef) -> None:
             # Track when we enter a class
@@ -65,6 +165,13 @@ def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef
             return updated_node

         def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
+            # convert function body to ast normalized string and find occurrences of codeflash_output
+            body_code = dedent(self.module.code_for_node(node.body))
+            normalized_body_code = ast.unparse(ast.parse(body_code))
+            self.cfo_locs = sorted(
+                find_codeflash_output_assignments(normalized_body_code)
+            )  # sorted in order we will encounter them
+            self.cfo_idx_loc_to_look_at = -1
             self.context_stack.append(node.name.value)

         def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:  # noqa: ARG002
@@ -91,10 +198,12 @@ def leave_SimpleStatementLine(

             if codeflash_assignment_found:
                 # Find matching test cases by looking for this test function name in the test results
+                self.cfo_idx_loc_to_look_at += 1
                 matching_original_times = []
                 matching_optimized_times = []
-                # TODO : will not work if there are multiple test cases with the same name, match filename + test class + test function name
+                # TODO : will not work if there are multiple test cases with the same name, match filename + test class + test function name + invocationid
                 for invocation_id, runtimes in original_runtimes.items():
+                    # get position here and match in if condition
                     qualified_name = (
                         invocation_id.test_class_name + "." + invocation_id.test_function_name  # type: ignore[operator]
                         if invocation_id.test_class_name
@@ -105,13 +214,19 @@
                         .with_suffix(".py")
                         .relative_to(self.rel_tests_root)
                     )
-                    if qualified_name == ".".join(self.context_stack) and rel_path in [
-                        self.test.behavior_file_path.relative_to(self.tests_root),
-                        self.test.perf_file_path.relative_to(self.tests_root),
-                    ]:
+                    if (
+                        qualified_name == ".".join(self.context_stack)
+                        and rel_path
+                        in [
+                            self.test.behavior_file_path.relative_to(self.tests_root),
+                            self.test.perf_file_path.relative_to(self.tests_root),
+                        ]
+                        and int(invocation_id.iteration_id.split("_")[0]) == self.cfo_locs[self.cfo_idx_loc_to_look_at]  # type:ignore[union-attr]
+                    ):
                         matching_original_times.extend(runtimes)

                 for invocation_id, runtimes in optimized_runtimes.items():
+                    # get position here and match in if condition
                     qualified_name = (
                         invocation_id.test_class_name + "." + invocation_id.test_function_name  # type: ignore[operator]
                         if invocation_id.test_class_name
@@ -122,10 +237,15 @@
                         .with_suffix(".py")
                         .relative_to(self.rel_tests_root)
                     )
-                    if qualified_name == ".".join(self.context_stack) and rel_path in [
-                        self.test.behavior_file_path.relative_to(self.tests_root),
-                        self.test.perf_file_path.relative_to(self.tests_root),
-                    ]:
+                    if (
+                        qualified_name == ".".join(self.context_stack)
+                        and rel_path
+                        in [
+                            self.test.behavior_file_path.relative_to(self.tests_root),
+                            self.test.perf_file_path.relative_to(self.tests_root),
+                        ]
+                        and int(invocation_id.iteration_id.split("_")[0]) == self.cfo_locs[self.cfo_idx_loc_to_look_at]  # type:ignore[union-attr]
+                    ):
                         matching_optimized_times.extend(runtimes)

                 if matching_original_times and matching_optimized_times:
@@ -161,9 +281,8 @@
         try:
             # Parse the test source code
             tree = cst.parse_module(test.generated_original_test_source)

-            # Transform the tree to add runtime comments
-            transformer = RuntimeCommentTransformer(test, tests_root, rel_tests_root)
+            transformer = RuntimeCommentTransformer(tree, test, tests_root, rel_tests_root)
             modified_tree = tree.visit(transformer)

             # Convert back to source code
diff --git a/tests/test_add_runtime_comments.py b/tests/test_add_runtime_comments.py
index 66a77b0d0..71f1d7566 100644
--- a/tests/test_add_runtime_comments.py
+++ b/tests/test_add_runtime_comments.py
@@ -60,9 +60,6 @@ def test_basic_runtime_comment_addition(self, test_config):
             behavior_file_path=Path("/project/tests/test_module.py"),
             perf_file_path=Path("/project/tests/test_module_perf.py"),
         )
-        """add_runtime_comments_to_generated_tests(
-            test_config, generated_tests, original_runtimes, optimized_runtimes
-        )"""
         generated_tests = GeneratedTestsList(generated_tests=[generated_test])

         # Create test results
@@ -70,8 +67,8 @@ def test_basic_runtime_comment_addition(self, test_config):
         optimized_test_results = TestResults()

         # Add test invocations with different runtimes
-        original_invocation = self.create_test_invocation("test_bubble_sort", 500_000)  # 500μs
-        optimized_invocation = self.create_test_invocation("test_bubble_sort", 300_000)  # 300μs
+        original_invocation = self.create_test_invocation("test_bubble_sort", 500_000, iteration_id='0')  # 500μs
+        optimized_invocation = self.create_test_invocation("test_bubble_sort", 300_000, iteration_id='0')  # 300μs

         original_test_results.add(original_invocation)
         optimized_test_results.add(optimized_invocation)
@@ -114,11 +111,11 @@ def helper_function():
         optimized_test_results = TestResults()

         # Add test invocations for both test functions
-        original_test_results.add(self.create_test_invocation("test_bubble_sort", 500_000))
-        original_test_results.add(self.create_test_invocation("test_quick_sort", 800_000))
+        original_test_results.add(self.create_test_invocation("test_bubble_sort", 500_000, iteration_id='0'))
+        original_test_results.add(self.create_test_invocation("test_quick_sort", 800_000, iteration_id='0'))

-        optimized_test_results.add(self.create_test_invocation("test_bubble_sort", 300_000))
-        optimized_test_results.add(self.create_test_invocation("test_quick_sort", 600_000))
+        optimized_test_results.add(self.create_test_invocation("test_bubble_sort", 300_000, iteration_id='0'))
+        optimized_test_results.add(self.create_test_invocation("test_quick_sort", 600_000, iteration_id='0'))

         original_runtimes = original_test_results.usable_runtime_data_by_test_case()
         optimized_runtimes = optimized_test_results.usable_runtime_data_by_test_case()
@@ -150,6 +147,7 @@ def test_different_time_formats(self, test_config):

         for original_time, optimized_time, expected_comment in test_cases:
             test_source = """def test_function():
+    #this comment will be removed in ast form
     codeflash_output = some_function()
     assert codeflash_output is not None
 """
@@ -168,8 +166,8 @@ def test_different_time_formats(self, test_config):
             original_test_results = TestResults()
             optimized_test_results = TestResults()

-            original_test_results.add(self.create_test_invocation("test_function", original_time))
-            optimized_test_results.add(self.create_test_invocation("test_function", optimized_time))
+            original_test_results.add(self.create_test_invocation("test_function", original_time, iteration_id='0'))
+            optimized_test_results.add(self.create_test_invocation("test_function", optimized_time, iteration_id='0'))

             original_runtimes = original_test_results.usable_runtime_data_by_test_case()
             optimized_runtimes = optimized_test_results.usable_runtime_data_by_test_case()
@@ -233,7 +231,7 @@ def test_partial_test_results(self, test_config):
         original_test_results = TestResults()
         optimized_test_results = TestResults()

-        original_test_results.add(self.create_test_invocation("test_bubble_sort", 500_000))
+        original_test_results.add(self.create_test_invocation("test_bubble_sort", 500_000, iteration_id='0'))
         # No optimized results
         original_runtimes = original_test_results.usable_runtime_data_by_test_case()
         optimized_runtimes = optimized_test_results.usable_runtime_data_by_test_case()
@@ -266,13 +264,13 @@ def test_multiple_runtimes_uses_minimum(self, test_config):
         optimized_test_results = TestResults()

         # Add multiple runs with different runtimes
-        original_test_results.add(self.create_test_invocation("test_bubble_sort", 600_000, loop_index=1))
-        original_test_results.add(self.create_test_invocation("test_bubble_sort", 500_000, loop_index=2))
-        original_test_results.add(self.create_test_invocation("test_bubble_sort", 550_000, loop_index=3))
+        original_test_results.add(self.create_test_invocation("test_bubble_sort", 600_000, loop_index=1,iteration_id='0'))
+        original_test_results.add(self.create_test_invocation("test_bubble_sort", 500_000, loop_index=2,iteration_id='0'))
+        original_test_results.add(self.create_test_invocation("test_bubble_sort", 550_000, loop_index=3,iteration_id='0'))

-        optimized_test_results.add(self.create_test_invocation("test_bubble_sort", 350_000, loop_index=1))
-        optimized_test_results.add(self.create_test_invocation("test_bubble_sort", 300_000, loop_index=2))
-        optimized_test_results.add(self.create_test_invocation("test_bubble_sort", 320_000, loop_index=3))
+        optimized_test_results.add(self.create_test_invocation("test_bubble_sort", 350_000, loop_index=1,iteration_id='0'))
+        optimized_test_results.add(self.create_test_invocation("test_bubble_sort", 300_000, loop_index=2,iteration_id='0'))
+        optimized_test_results.add(self.create_test_invocation("test_bubble_sort", 320_000, loop_index=3,iteration_id='0'))

         original_runtimes = original_test_results.usable_runtime_data_by_test_case()
         optimized_runtimes = optimized_test_results.usable_runtime_data_by_test_case()
@@ -304,8 +302,8 @@ def test_no_codeflash_output_assignment(self, test_config):
         original_test_results = TestResults()
         optimized_test_results = TestResults()

-        original_test_results.add(self.create_test_invocation("test_bubble_sort", 500_000))
-        optimized_test_results.add(self.create_test_invocation("test_bubble_sort", 300_000))
+        original_test_results.add(self.create_test_invocation("test_bubble_sort", 500_000,iteration_id='-1'))
+        optimized_test_results.add(self.create_test_invocation("test_bubble_sort", 300_000,iteration_id='-1'))

         original_runtimes = original_test_results.usable_runtime_data_by_test_case()
         optimized_runtimes = optimized_test_results.usable_runtime_data_by_test_case()
@@ -320,9 +318,9 @@ def test_no_codeflash_output_assignment(self, test_config):
     def test_invalid_python_code_handling(self, test_config):
         """Test behavior when test source code is invalid Python."""
         test_source = """def test_bubble_sort(:
-    codeflash_output = bubble_sort([3, 1, 2])
+        codeflash_output = bubble_sort([3, 1, 2])
     assert codeflash_output == [1, 2, 3]
-"""  # Invalid syntax: extra colon
+"""  # Invalid syntax: extra indentation

         generated_test = GeneratedTests(
             generated_original_test_source=test_source,
@@ -338,8 +336,8 @@ def test_invalid_python_code_handling(self, test_config):
         original_test_results = TestResults()
         optimized_test_results = TestResults()

-        original_test_results.add(self.create_test_invocation("test_bubble_sort", 500_000))
-        optimized_test_results.add(self.create_test_invocation("test_bubble_sort", 300_000))
+        original_test_results.add(self.create_test_invocation("test_bubble_sort", 500_000,iteration_id='0'))
+        optimized_test_results.add(self.create_test_invocation("test_bubble_sort", 300_000,iteration_id='0'))

         original_runtimes = original_test_results.usable_runtime_data_by_test_case()
         optimized_runtimes = optimized_test_results.usable_runtime_data_by_test_case()
@@ -359,6 +357,9 @@ def test_multiple_generated_tests(self, test_config):
 """

         test_source_2 = """def test_quick_sort():
+    a=1
+    b=2
+    c=3
     codeflash_output = quick_sort([5, 2, 8])
     assert codeflash_output == [2, 5, 8]
 """
@@ -385,11 +386,11 @@ def test_multiple_generated_tests(self, test_config):
         original_test_results = TestResults()
         optimized_test_results = TestResults()

-        original_test_results.add(self.create_test_invocation("test_bubble_sort", 500_000))
-        original_test_results.add(self.create_test_invocation("test_quick_sort", 800_000))
+        original_test_results.add(self.create_test_invocation("test_bubble_sort", 500_000,iteration_id='0'))
+        original_test_results.add(self.create_test_invocation("test_quick_sort", 800_000,iteration_id='3'))

-        optimized_test_results.add(self.create_test_invocation("test_bubble_sort", 300_000))
-        optimized_test_results.add(self.create_test_invocation("test_quick_sort", 600_000))
+        optimized_test_results.add(self.create_test_invocation("test_bubble_sort", 300_000,iteration_id='0'))
+        optimized_test_results.add(self.create_test_invocation("test_quick_sort", 600_000,iteration_id='3'))

         original_runtimes = original_test_results.usable_runtime_data_by_test_case()
         optimized_runtimes = optimized_test_results.usable_runtime_data_by_test_case()
@@ -430,8 +431,8 @@ def test_preserved_test_attributes(self, test_config):
         original_test_results = TestResults()
         optimized_test_results = TestResults()

-        original_test_results.add(self.create_test_invocation("test_bubble_sort", 500_000))
-        optimized_test_results.add(self.create_test_invocation("test_bubble_sort", 300_000))
+        original_test_results.add(self.create_test_invocation("test_bubble_sort", 500_000,iteration_id='0'))
+        optimized_test_results.add(self.create_test_invocation("test_bubble_sort", 300_000,iteration_id='0'))

         original_runtimes = original_test_results.usable_runtime_data_by_test_case()
         optimized_runtimes = optimized_test_results.usable_runtime_data_by_test_case()
@@ -472,8 +473,8 @@ def test_multistatement_line_handling(self, test_config):
         original_test_results = TestResults()
         optimized_test_results = TestResults()

-        original_test_results.add(self.create_test_invocation("test_mutation_of_input", 19_000))  # 19μs
-        optimized_test_results.add(self.create_test_invocation("test_mutation_of_input", 14_000))  # 14μs
+        original_test_results.add(self.create_test_invocation("test_mutation_of_input", 19_000,iteration_id='1'))  # 19μs
+        optimized_test_results.add(self.create_test_invocation("test_mutation_of_input", 14_000,iteration_id='1'))  # 14μs

         original_runtimes = original_test_results.usable_runtime_data_by_test_case()
         optimized_runtimes = optimized_test_results.usable_runtime_data_by_test_case()
@@ -600,16 +601,23 @@ def test_add_runtime_comments_multiple_assignments(self, test_config):

         generated_tests = GeneratedTestsList(generated_tests=[generated_test])

-        invocation_id = InvocationId(
+        invocation_id1 = InvocationId(
             test_module_path="tests.test_module",
             test_class_name=None,
             test_function_name="test_function",
             function_getting_tested="some_function",
-            iteration_id="0",
+            iteration_id="1",
+        )
+        invocation_id2 = InvocationId(
+            test_module_path="tests.test_module",
+            test_class_name=None,
+            test_function_name="test_function",
+            function_getting_tested="another_function",
+            iteration_id="3",
         )

-        original_runtimes = {invocation_id: [1500000000]}  # 1.5s in nanoseconds
-        optimized_runtimes = {invocation_id: [750000000]}  # 0.75s in nanoseconds
+        original_runtimes = {invocation_id1: [1500000000], invocation_id2: [10]}  # 1.5s in nanoseconds
+        optimized_runtimes = {invocation_id1: [750000000], invocation_id2: [5]}  # 0.75s in nanoseconds

         result = add_runtime_comments_to_generated_tests(
             test_config, generated_tests, original_runtimes, optimized_runtimes
@@ -619,7 +627,7 @@ def test_add_runtime_comments_multiple_assignments(self, test_config):
     setup_data = prepare_test()
     codeflash_output = some_function()  # 1.50s -> 750ms (100% faster)
     assert codeflash_output == expected
-    codeflash_output = another_function()  # 1.50s -> 750ms (100% faster)
+    codeflash_output = another_function()  # 10ns -> 5ns (100% faster)
     assert codeflash_output == expected2
 '''

@@ -777,6 +785,8 @@ def test_add_runtime_comments_performance_regression(self, test_config):
         test_source = '''def test_function():
     codeflash_output = some_function()
     assert codeflash_output == expected
+    codeflash_output = some_function()
+    assert codeflash_output == expected
 '''

         generated_test = GeneratedTests(
@@ -789,7 +799,7 @@ def test_add_runtime_comments_performance_regression(self, test_config):

         generated_tests = GeneratedTestsList(generated_tests=[generated_test])

-        invocation_id = InvocationId(
+        invocation_id1 = InvocationId(
             test_module_path="tests.test_module",
             test_class_name=None,
             test_function_name="test_function",
@@ -797,8 +807,16 @@ def test_add_runtime_comments_performance_regression(self, test_config):
             iteration_id="0",
         )

-        original_runtimes = {invocation_id: [1000000000]}  # 1s
-        optimized_runtimes = {invocation_id: [1500000000]}  # 1.5s (slower!)
+        invocation_id2 = InvocationId(
+            test_module_path="tests.test_module",
+            test_class_name=None,
+            test_function_name="test_function",
+            function_getting_tested="some_function",
+            iteration_id="2",
+        )
+
+        original_runtimes = {invocation_id1: [1000000000], invocation_id2: [2]}  # 1s
+        optimized_runtimes = {invocation_id1: [1500000000], invocation_id2: [1]}  # 1.5s (slower!)

         result = add_runtime_comments_to_generated_tests(
             test_config, generated_tests, original_runtimes, optimized_runtimes
@@ -807,6 +825,8 @@ def test_add_runtime_comments_performance_regression(self, test_config):
         expected_source = '''def test_function():
     codeflash_output = some_function()  # 1.00s -> 1.50s (33.3% slower)
     assert codeflash_output == expected
+    codeflash_output = some_function()  # 2ns -> 1ns (100% faster)
+    assert codeflash_output == expected
 '''

         assert len(result.generated_tests) == 1
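Reviewer note (not part of the patch): a minimal sketch of how the new find_codeflash_output_assignments helper and the iteration_id check added above are expected to fit together, assuming a codeflash install that includes this patch. The quick_sort name exists only inside the parsed string, and the "3_0" iteration id is a hypothetical value; the updated tests use plain "0"/"3".

# Sketch only -- mirrors what RuntimeCommentTransformer.visit_FunctionDef does, not shipped code.
import ast
from textwrap import dedent

from codeflash.code_utils.edit_generated_tests import find_codeflash_output_assignments

body = dedent("""
    a = 1
    b = 2
    c = 3
    codeflash_output = quick_sort([5, 2, 8])  # this comment disappears after unparse
    assert codeflash_output == [2, 5, 8]
""").strip()

# parse + unparse normalizes the body: comments and blank lines are dropped,
# so the recorded offsets count zero-based positions among AST statements.
normalized = ast.unparse(ast.parse(body))
cfo_locs = sorted(find_codeflash_output_assignments(normalized))
print(cfo_locs)  # [3] -- the codeflash_output assignment is the 4th statement

# leave_SimpleStatementLine then keeps a runtime entry only when the first
# "_"-separated component of its iteration_id equals the current location.
iteration_id = "3_0"  # hypothetical; assumed format per the split("_")[0] in the patch
assert int(iteration_id.split("_")[0]) == cfo_locs[0]

This is why the updated tests insert filler statements (a=1, b=2, c=3) and a throwaway comment: the comment vanishes under ast.unparse, while the filler statements shift the expected iteration_id to "3".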