diff --git a/codeflash/code_utils/edit_generated_tests.py b/codeflash/code_utils/edit_generated_tests.py
index 35150e0da..e5978f0fe 100644
--- a/codeflash/code_utils/edit_generated_tests.py
+++ b/codeflash/code_utils/edit_generated_tests.py
@@ -5,7 +5,7 @@
 import re
 from pathlib import Path
 from textwrap import dedent
-from typing import TYPE_CHECKING, Union
+from typing import TYPE_CHECKING
 
 import libcst as cst
 
@@ -50,231 +50,202 @@ class CfoVisitor(ast.NodeVisitor):
     and reports their location relative to the function they're in.
     """
 
-    def __init__(self, source_code: str) -> None:
+    def __init__(self, function_name: str, source_code: str) -> None:
         self.source_lines = source_code.splitlines()
+        self.name = function_name
         self.results: list[int] = []  # map actual line number to line number in ast
 
-    def _is_codeflash_output_target(self, target: Union[ast.expr, list]) -> bool:  # type: ignore[type-arg]
-        """Check if the assignment target is the variable 'codeflash_output'."""
-        if isinstance(target, ast.Name):
-            return target.id == "codeflash_output"
-        if isinstance(target, (ast.Tuple, ast.List)):
-            # Handle tuple/list unpacking: a, codeflash_output, b = values
-            return any(self._is_codeflash_output_target(elt) for elt in target.elts)
-        if isinstance(target, (ast.Subscript, ast.Attribute)):
-            # Not a simple variable assignment
-            return False
-        return False
-
-    def _record_assignment(self, node: ast.AST) -> None:
-        """Record an assignment to codeflash_output."""
-        relative_line = node.lineno - 1  # type: ignore[attr-defined]
-        self.results.append(relative_line)
-
-    def visit_Assign(self, node: ast.Assign) -> None:
-        """Visit assignment statements: codeflash_output = value."""
-        for target in node.targets:
-            if self._is_codeflash_output_target(target):
-                self._record_assignment(node)
-                break
+    def visit_Call(self, node):  # type: ignore[no-untyped-def]  # noqa: ANN201, ANN001
+        """Record calls to the function being optimized."""
+        func_name = self._get_called_func_name(node.func)  # type: ignore[no-untyped-call]
+        if func_name == self.name:
+            self.results.append(node.lineno - 1)
         self.generic_visit(node)
 
-    def visit_AnnAssign(self, node: ast.AnnAssign) -> None:
-        """Visit annotated assignments: codeflash_output: int = value."""
-        if self._is_codeflash_output_target(node.target):
-            self._record_assignment(node)
-        self.generic_visit(node)
-
-    def visit_AugAssign(self, node: ast.AugAssign) -> None:
-        """Visit augmented assignments: codeflash_output += value."""
-        if self._is_codeflash_output_target(node.target):
-            self._record_assignment(node)
-        self.generic_visit(node)
-
-    def visit_NamedExpr(self, node: ast.NamedExpr) -> None:
-        """Visit walrus operator: (codeflash_output := value)."""
-        if isinstance(node.target, ast.Name) and node.target.id == "codeflash_output":
-            self._record_assignment(node)
-        self.generic_visit(node)
-
-    def visit_For(self, node: ast.For) -> None:
-        """Visit for loops: for codeflash_output in iterable."""
-        if self._is_codeflash_output_target(node.target):
-            self._record_assignment(node)
-        self.generic_visit(node)
-
-    def visit_comprehension(self, node: ast.comprehension) -> None:
-        """Visit comprehensions: [x for codeflash_output in iterable]."""
-        if self._is_codeflash_output_target(node.target):
-            # Comprehensions don't have line numbers, so we skip recording
-            pass
-        self.generic_visit(node)
-
-    def visit_With(self, node: ast.With) -> None:
-        """Visit with statements: with expr as codeflash_output."""
-        for item in node.items:
-            if item.optional_vars and self._is_codeflash_output_target(item.optional_vars):
-                self._record_assignment(node)
-                break
-        self.generic_visit(node)
-
-    def visit_ExceptHandler(self, node: ast.ExceptHandler) -> None:
-        """Visit except handlers: except Exception as codeflash_output."""
-        if node.name == "codeflash_output":
-            self._record_assignment(node)
-        self.generic_visit(node)
+    def _get_called_func_name(self, node):  # type: ignore[no-untyped-def]  # noqa: ANN001, ANN202
+        """Return the name of the called function."""
+        if isinstance(node, ast.Name):
+            return node.id
+        if isinstance(node, ast.Attribute):
+            return node.attr
+        return None
 
 
-def find_codeflash_output_assignments(source_code: str) -> list[int]:
+def find_codeflash_output_assignments(function_name: str, source_code: str) -> list[int]:
     tree = ast.parse(source_code)
-    visitor = CfoVisitor(source_code)
+    visitor = CfoVisitor(function_name, source_code)
     visitor.visit(tree)
     return visitor.results
 
 
-def add_runtime_comments_to_generated_tests(
-    test_cfg: TestConfig,
-    generated_tests: GeneratedTestsList,
-    original_runtimes: dict[InvocationId, list[int]],
-    optimized_runtimes: dict[InvocationId, list[int]],
-) -> GeneratedTestsList:
-    """Add runtime performance comments to function calls in generated tests."""
-    tests_root = test_cfg.tests_root
-    module_root = test_cfg.project_root_path
-    rel_tests_root = tests_root.relative_to(module_root)
-
-    # TODO: reduce for loops to one
-    class RuntimeCommentTransformer(cst.CSTTransformer):
-        def __init__(self, module: cst.Module, test: GeneratedTests, tests_root: Path, rel_tests_root: Path) -> None:
-            super().__init__()
-            self.test = test
-            self.context_stack: list[str] = []
-            self.tests_root = tests_root
-            self.rel_tests_root = rel_tests_root
-            self.module = module
-            self.cfo_locs: list[int] = []
-            self.cfo_idx_loc_to_look_at: int = -1
-
-        def visit_ClassDef(self, node: cst.ClassDef) -> None:
-            # Track when we enter a class
-            self.context_stack.append(node.name.value)
-
-        def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef:  # noqa: ARG002
-            # Pop the context when we leave a class
-            self.context_stack.pop()
-            return updated_node
-
-        def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
-            # convert function body to ast normalized string and find occurrences of codeflash_output
-            body_code = dedent(self.module.code_for_node(node.body))
-            normalized_body_code = ast.unparse(ast.parse(body_code))
-            self.cfo_locs = sorted(
-                find_codeflash_output_assignments(normalized_body_code)
-            )  # sorted in order we will encounter them
-            self.cfo_idx_loc_to_look_at = -1
-            self.context_stack.append(node.name.value)
-
-        def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:  # noqa: ARG002
-            # Pop the context when we leave a function
-            self.context_stack.pop()
-            return updated_node
-
-        def leave_SimpleStatementLine(
-            self,
-            original_node: cst.SimpleStatementLine,  # noqa: ARG002
-            updated_node: cst.SimpleStatementLine,
-        ) -> cst.SimpleStatementLine:
-            # Look for assignment statements that assign to codeflash_output
-            # Handle both single statements and multiple statements on one line
-            codeflash_assignment_found = False
-            for stmt in updated_node.body:
-                if isinstance(stmt, cst.Assign) and (
-                    len(stmt.targets) == 1
-                    and isinstance(stmt.targets[0].target, cst.Name)
-                    and stmt.targets[0].target.value == "codeflash_output"
-                ):
-                    codeflash_assignment_found = True
-                    break
-
-            if codeflash_assignment_found:
-                # Find matching test cases by looking for this test function name in the test results
-                self.cfo_idx_loc_to_look_at += 1
-                matching_original_times = []
-                matching_optimized_times = []
-                # TODO : will not work if there are multiple test cases with the same name, match filename + test class + test function name + invocationid
-                for invocation_id, runtimes in original_runtimes.items():
-                    # get position here and match in if condition
-                    qualified_name = (
-                        invocation_id.test_class_name + "." + invocation_id.test_function_name  # type: ignore[operator]
-                        if invocation_id.test_class_name
-                        else invocation_id.test_function_name
-                    )
-                    rel_path = (
-                        Path(invocation_id.test_module_path.replace(".", os.sep))
-                        .with_suffix(".py")
-                        .relative_to(self.rel_tests_root)
-                    )
-                    if (
-                        qualified_name == ".".join(self.context_stack)
-                        and rel_path
-                        in [
-                            self.test.behavior_file_path.relative_to(self.tests_root),
-                            self.test.perf_file_path.relative_to(self.tests_root),
-                        ]
-                        and int(invocation_id.iteration_id.split("_")[0]) == self.cfo_locs[self.cfo_idx_loc_to_look_at]  # type:ignore[union-attr]
-                    ):
-                        matching_original_times.extend(runtimes)
-
-                for invocation_id, runtimes in optimized_runtimes.items():
-                    # get position here and match in if condition
-                    qualified_name = (
-                        invocation_id.test_class_name + "." + invocation_id.test_function_name  # type: ignore[operator]
-                        if invocation_id.test_class_name
-                        else invocation_id.test_function_name
-                    )
-                    rel_path = (
-                        Path(invocation_id.test_module_path.replace(".", os.sep))
-                        .with_suffix(".py")
-                        .relative_to(self.rel_tests_root)
-                    )
-                    if (
-                        qualified_name == ".".join(self.context_stack)
-                        and rel_path
-                        in [
-                            self.test.behavior_file_path.relative_to(self.tests_root),
-                            self.test.perf_file_path.relative_to(self.tests_root),
-                        ]
-                        and int(invocation_id.iteration_id.split("_")[0]) == self.cfo_locs[self.cfo_idx_loc_to_look_at]  # type:ignore[union-attr]
-                    ):
-                        matching_optimized_times.extend(runtimes)
-
-                if matching_original_times and matching_optimized_times:
-                    original_time = min(matching_original_times)
-                    optimized_time = min(matching_optimized_times)
-                    if original_time != 0 and optimized_time != 0:
-                        perf_gain = format_perf(
-                            abs(
-                                performance_gain(original_runtime_ns=original_time, optimized_runtime_ns=optimized_time)
-                                * 100
-                            )
-                        )
-                        status = "slower" if optimized_time > original_time else "faster"
-                        # Create the runtime comment
-                        comment_text = (
-                            f"# {format_time(original_time)} -> {format_time(optimized_time)} ({perf_gain}% {status})"
-                        )
-
-                        # Add comment to the trailing whitespace
-                        new_trailing_whitespace = cst.TrailingWhitespace(
-                            whitespace=cst.SimpleWhitespace("  "),
-                            comment=cst.Comment(comment_text),
-                            newline=updated_node.trailing_whitespace.newline,
-                        )
-
-                        return updated_node.with_changes(trailing_whitespace=new_trailing_whitespace)
-
-            return updated_node
+class Finder(cst.CSTVisitor):
+    def __init__(self, name: str) -> None:
+        super().__init__()
+        self.found = False
+        self.name = name
+
+    def visit_Call(self, call_node) -> None:  # type: ignore[no-untyped-def]  # noqa: ANN001
+        func_expr = call_node.func
+        if isinstance(func_expr, cst.Name):
+            if func_expr.value == self.name:
+                self.found = True
+        elif isinstance(func_expr, cst.Attribute):  # noqa: SIM102
+            if func_expr.attr.value == self.name:
+                self.found = True
+
+
+# TODO: reduce for loops to one
+class RuntimeCommentTransformer(cst.CSTTransformer):
+    def __init__(
+        self,
+        qualified_name: str,
+        module: cst.Module,
+        test: GeneratedTests,
+        tests_root: Path,
+        rel_tests_root: Path,
+        original_runtimes: dict[InvocationId, list[int]],
+        optimized_runtimes: dict[InvocationId, list[int]],
+    ) -> None:
+        super().__init__()
+        self.test = test
+        self.context_stack: list[str] = []
+        self.tests_root = tests_root
+        self.rel_tests_root = rel_tests_root
+        self.module = module
+        self.cfo_locs: list[int] = []
+        self.cfo_idx_loc_to_look_at: int = -1
+        self.name = qualified_name.split(".")[-1]
+        self.original_runtimes = original_runtimes
+        self.optimized_runtimes = optimized_runtimes
+
+    def visit_ClassDef(self, node: cst.ClassDef) -> None:
+        # Track when we enter a class
+        self.context_stack.append(node.name.value)
+
+    def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef:  # noqa: ARG002
+        # Pop the context when we leave a class
+        self.context_stack.pop()
+        return updated_node
+
+    def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
+        # convert function body to ast normalized string and find occurrences of codeflash_output
+        body_code = dedent(self.module.code_for_node(node.body))
+        normalized_body_code = ast.unparse(ast.parse(body_code))
+        self.cfo_locs = sorted(
+            find_codeflash_output_assignments(self.name, normalized_body_code)
+        )  # sorted in order we will encounter them
+        self.cfo_idx_loc_to_look_at = -1
+        self.context_stack.append(node.name.value)
+
+    def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:  # noqa: ARG002
+        # Pop the context when we leave a function
+        self.context_stack.pop()
+        return updated_node
+
+    def leave_SimpleStatementLine(
+        self,
+        original_node: cst.SimpleStatementLine,  # noqa: ARG002
+        updated_node: cst.SimpleStatementLine,
+    ) -> cst.SimpleStatementLine:
+        # Check if this statement line contains a call to self.name
+        if self._contains_myfunc_call(updated_node):  # type: ignore[no-untyped-call]
+            # Find matching test cases by looking for this test function name in the test results
+            self.cfo_idx_loc_to_look_at += 1
+            matching_original_times = []
+            matching_optimized_times = []
+            # TODO: will not work if there are multiple test cases with the same name, match filename + test class + test function name + invocationid
+            for invocation_id, runtimes in self.original_runtimes.items():
+                # get position here and match in if condition
+                qualified_name = (
+                    invocation_id.test_class_name + "." + invocation_id.test_function_name  # type: ignore[operator]
+                    if invocation_id.test_class_name
+                    else invocation_id.test_function_name
+                )
+                rel_path = (
+                    Path(invocation_id.test_module_path.replace(".", os.sep))
+                    .with_suffix(".py")
+                    .relative_to(self.rel_tests_root)
+                )
+                if (
+                    qualified_name == ".".join(self.context_stack)
+                    and rel_path
+                    in [
+                        self.test.behavior_file_path.relative_to(self.tests_root),
+                        self.test.perf_file_path.relative_to(self.tests_root),
+                    ]
+                    and int(invocation_id.iteration_id.split("_")[0]) == self.cfo_locs[self.cfo_idx_loc_to_look_at]  # type:ignore[union-attr]
+                ):
+                    matching_original_times.extend(runtimes)
+
+            for invocation_id, runtimes in self.optimized_runtimes.items():
+                # get position here and match in if condition
+                qualified_name = (
+                    invocation_id.test_class_name + "." + invocation_id.test_function_name  # type: ignore[operator]
+                    if invocation_id.test_class_name
+                    else invocation_id.test_function_name
+                )
+                rel_path = (
+                    Path(invocation_id.test_module_path.replace(".", os.sep))
+                    .with_suffix(".py")
+                    .relative_to(self.rel_tests_root)
+                )
+                if (
+                    qualified_name == ".".join(self.context_stack)
+                    and rel_path
+                    in [
+                        self.test.behavior_file_path.relative_to(self.tests_root),
+                        self.test.perf_file_path.relative_to(self.tests_root),
+                    ]
+                    and int(invocation_id.iteration_id.split("_")[0]) == self.cfo_locs[self.cfo_idx_loc_to_look_at]  # type:ignore[union-attr]
+                ):
+                    matching_optimized_times.extend(runtimes)
+
+            if matching_original_times and matching_optimized_times:
+                original_time = min(matching_original_times)
+                optimized_time = min(matching_optimized_times)
+                if original_time != 0 and optimized_time != 0:
+                    perf_gain = format_perf(
+                        abs(
+                            performance_gain(original_runtime_ns=original_time, optimized_runtime_ns=optimized_time)
+                            * 100
+                        )
+                    )
+                    status = "slower" if optimized_time > original_time else "faster"
+                    # Create the runtime comment
+                    comment_text = (
+                        f"# {format_time(original_time)} -> {format_time(optimized_time)} ({perf_gain}% {status})"
+                    )
+
+                    return updated_node.with_changes(
+                        trailing_whitespace=cst.TrailingWhitespace(
+                            whitespace=cst.SimpleWhitespace("  "),
+                            comment=cst.Comment(comment_text),
+                            newline=updated_node.trailing_whitespace.newline,
+                        )
+                    )
+        return updated_node
+
+    def _contains_myfunc_call(self, node):  # type: ignore[no-untyped-def]  # noqa: ANN202, ANN001
+        """Recursively search for any Call node in the statement whose function is named self.name (including obj.myfunc)."""
+        finder = Finder(self.name)
+        node.visit(finder)
+        return finder.found
+
+
+def add_runtime_comments_to_generated_tests(
+    qualified_name: str,
+    test_cfg: TestConfig,
+    generated_tests: GeneratedTestsList,
+    original_runtimes: dict[InvocationId, list[int]],
+    optimized_runtimes: dict[InvocationId, list[int]],
+) -> GeneratedTestsList:
+    """Add runtime performance comments to function calls in generated tests."""
+    tests_root = test_cfg.tests_root
+    module_root = test_cfg.project_root_path
+    try:
+        rel_tests_root = tests_root.relative_to(module_root)
+    except Exception as e:
+        logger.debug(e)
+        return generated_tests
 
     # Process each generated test
     modified_tests = []
     for test in generated_tests.generated_tests:
@@ -282,7 +253,10 @@
             # Parse the test source code
             tree = cst.parse_module(test.generated_original_test_source)
             # Transform the tree to add runtime comments
-            transformer = RuntimeCommentTransformer(tree, test, tests_root, rel_tests_root)
+            # Pass the qualified name of the function under optimization along with the measured runtimes
+            transformer = RuntimeCommentTransformer(
+                qualified_name, tree, test, tests_root, rel_tests_root, original_runtimes, optimized_runtimes
+            )
             modified_tree = tree.visit(transformer)
 
             # Convert back to source code
diff --git a/codeflash/lsp/server.py b/codeflash/lsp/server.py
index 1a303c13c..a018786a8 100644
--- a/codeflash/lsp/server.py
+++ b/codeflash/lsp/server.py
@@ -3,7 +3,7 @@
 from pathlib import Path
 from typing import TYPE_CHECKING, Any
 
-from lsprotocol.types import INITIALIZE, MessageType, LogMessageParams
+from lsprotocol.types import INITIALIZE, LogMessageParams, MessageType
 from pygls import uris
 from pygls.protocol import LanguageServerProtocol, lsp_method
 from pygls.server import LanguageServer
@@ -58,21 +58,22 @@ def initialize_optimizer(self, config_file: Path) -> None:
 
     def show_message_log(self, message: str, message_type: str) -> None:
         """Send a log message to the client's output channel.
-        
+
         Args:
             message: The message to log
             message_type: String type - "Info", "Warning", "Error", or "Log"
+
         """
         # Convert string message type to LSP MessageType enum
         type_mapping = {
             "Info": MessageType.Info,
-            "Warning": MessageType.Warning, 
+            "Warning": MessageType.Warning,
             "Error": MessageType.Error,
-            "Log": MessageType.Log
+            "Log": MessageType.Log,
         }
-        
+
         lsp_message_type = type_mapping.get(message_type, MessageType.Info)
-        
+
         # Send log message to client (appears in output channel)
         log_params = LogMessageParams(type=lsp_message_type, message=message)
         self.lsp.notify("window/logMessage", log_params)
diff --git a/codeflash/lsp/server_entry.py b/codeflash/lsp/server_entry.py
index 841d18f84..48b13fb6c 100644
--- a/codeflash/lsp/server_entry.py
+++ b/codeflash/lsp/server_entry.py
@@ -1,4 +1,5 @@
-"""This script is the dedicated entry point for the Codeflash Language Server.
+"""Dedicated entry point for the Codeflash Language Server.
+
 It initializes the server and redirects its logs to stderr
 so that the VS Code client can display them in the output channel.
@@ -13,7 +14,7 @@
 
 
 # Configure logging to stderr for VS Code output channel
-def setup_logging():
+def setup_logging():  # noqa: ANN201
     # Clear any existing handlers to prevent conflicts
     root_logger = logging.getLogger()
     root_logger.handlers.clear()
diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py
index 82cf4bc57..8772bb9e5 100644
--- a/codeflash/optimization/function_optimizer.py
+++ b/codeflash/optimization/function_optimizer.py
@@ -1012,15 +1012,20 @@ def find_and_process_best_optimization(
             optimized_runtime_by_test = (
                 best_optimization.winning_benchmarking_test_results.usable_runtime_data_by_test_case()
             )
+            qualified_name = self.function_to_optimize.qualified_name_with_modules_from_root(self.project_root)
             # Add runtime comments to generated tests before creating the PR
             generated_tests = add_runtime_comments_to_generated_tests(
-                self.test_cfg, generated_tests, original_runtime_by_test, optimized_runtime_by_test
+                qualified_name,
+                self.test_cfg,
+                generated_tests,
+                original_runtime_by_test,
+                optimized_runtime_by_test,
             )
             generated_tests_str = "\n\n".join(
                 [test.generated_original_test_source for test in generated_tests.generated_tests]
             )
             existing_tests = existing_tests_source_for(
-                self.function_to_optimize.qualified_name_with_modules_from_root(self.project_root),
+                qualified_name,
                 function_to_all_tests,
                 test_cfg=self.test_cfg,
                 original_runtimes_all=original_runtime_by_test,
diff --git a/codeflash/result/create_pr.py b/codeflash/result/create_pr.py
index a08875a4f..d96a34f86 100644
--- a/codeflash/result/create_pr.py
+++ b/codeflash/result/create_pr.py
@@ -52,9 +52,13 @@ def existing_tests_source_for(
     # TODO confirm that original and optimized have the same keys
     all_invocation_ids = original_runtimes_all.keys() | optimized_runtimes_all.keys()
     for invocation_id in all_invocation_ids:
-        rel_path = (
-            Path(invocation_id.test_module_path.replace(".", os.sep)).with_suffix(".py").relative_to(rel_tests_root)
-        )
+        try:
+            rel_path = (
+                Path(invocation_id.test_module_path.replace(".", os.sep)).with_suffix(".py").relative_to(rel_tests_root)
+            )
+        except Exception as e:
+            logger.debug(e)
+            continue
         if rel_path not in non_generated_tests:
             continue
         if rel_path not in original_tests_to_runtimes:
diff --git a/tests/test_add_runtime_comments.py b/tests/test_add_runtime_comments.py
index 71f1d7566..1afc803b1 100644
--- a/tests/test_add_runtime_comments.py
+++ b/tests/test_add_runtime_comments.py
@@ -53,6 +53,7 @@ def test_basic_runtime_comment_addition(self, test_config):
     assert codeflash_output == [1, 2, 3]
 """
 
+        qualified_name = "bubble_sort"
         generated_test = GeneratedTests(
             generated_original_test_source=test_source,
             instrumented_behavior_test_source="",
@@ -75,7 +76,7 @@ def test_basic_runtime_comment_addition(self, test_config):
         original_runtimes = original_test_results.usable_runtime_data_by_test_case()
         optimized_runtimes = optimized_test_results.usable_runtime_data_by_test_case()
         # Test the functionality
-        result = add_runtime_comments_to_generated_tests(test_config, generated_tests, original_runtimes, optimized_runtimes)
+        result = add_runtime_comments_to_generated_tests(qualified_name, test_config, generated_tests, original_runtimes, optimized_runtimes)
 
         # Check that comments were added
         modified_source = result.generated_tests[0].generated_original_test_source
@@ -85,7 +86,7 @@
     def test_multiple_test_functions(self, test_config):
same file.""" test_source = """def test_bubble_sort(): - codeflash_output = bubble_sort([3, 1, 2]) + codeflash_output = quick_sort([3, 1, 2]) assert codeflash_output == [1, 2, 3] def test_quick_sort(): @@ -95,7 +96,7 @@ def test_quick_sort(): def helper_function(): return "not a test" """ - + qualified_name = "quick_sort" generated_test = GeneratedTests( generated_original_test_source=test_source, instrumented_behavior_test_source="", @@ -121,7 +122,7 @@ def helper_function(): optimized_runtimes = optimized_test_results.usable_runtime_data_by_test_case() # Test the functionality - result = add_runtime_comments_to_generated_tests(test_config, generated_tests, original_runtimes, optimized_runtimes) + result = add_runtime_comments_to_generated_tests(qualified_name, test_config, generated_tests, original_runtimes, optimized_runtimes) modified_source = result.generated_tests[0].generated_original_test_source @@ -151,7 +152,7 @@ def test_different_time_formats(self, test_config): codeflash_output = some_function() assert codeflash_output is not None """ - + qualified_name = "some_function" generated_test = GeneratedTests( generated_original_test_source=test_source, instrumented_behavior_test_source="", @@ -173,7 +174,7 @@ def test_different_time_formats(self, test_config): optimized_runtimes = optimized_test_results.usable_runtime_data_by_test_case() # Test the functionality result = add_runtime_comments_to_generated_tests( - test_config, generated_tests, original_runtimes, optimized_runtimes + qualified_name, test_config, generated_tests, original_runtimes, optimized_runtimes ) modified_source = result.generated_tests[0].generated_original_test_source @@ -186,6 +187,7 @@ def test_missing_test_results(self, test_config): assert codeflash_output == [1, 2, 3] """ + qualified_name = "bubble_sort" generated_test = GeneratedTests( generated_original_test_source=test_source, instrumented_behavior_test_source="", @@ -204,7 +206,7 @@ def test_missing_test_results(self, test_config): optimized_runtimes = optimized_test_results.usable_runtime_data_by_test_case() # Test the functionality - result = add_runtime_comments_to_generated_tests(test_config, generated_tests, original_runtimes, optimized_runtimes) + result = add_runtime_comments_to_generated_tests(qualified_name, test_config, generated_tests, original_runtimes, optimized_runtimes) # Check that no comments were added modified_source = result.generated_tests[0].generated_original_test_source @@ -217,6 +219,7 @@ def test_partial_test_results(self, test_config): assert codeflash_output == [1, 2, 3] """ + qualified_name = "bubble_sort" generated_test = GeneratedTests( generated_original_test_source=test_source, instrumented_behavior_test_source="", @@ -236,7 +239,7 @@ def test_partial_test_results(self, test_config): original_runtimes = original_test_results.usable_runtime_data_by_test_case() optimized_runtimes = optimized_test_results.usable_runtime_data_by_test_case() # Test the functionality - result = add_runtime_comments_to_generated_tests(test_config, generated_tests, original_runtimes, optimized_runtimes) + result = add_runtime_comments_to_generated_tests(qualified_name, test_config, generated_tests, original_runtimes, optimized_runtimes) # Check that no comments were added modified_source = result.generated_tests[0].generated_original_test_source @@ -248,7 +251,7 @@ def test_multiple_runtimes_uses_minimum(self, test_config): codeflash_output = bubble_sort([3, 1, 2]) assert codeflash_output == [1, 2, 3] """ - + qualified_name = "bubble_sort" 
generated_test = GeneratedTests( generated_original_test_source=test_source, instrumented_behavior_test_source="", @@ -275,7 +278,7 @@ def test_multiple_runtimes_uses_minimum(self, test_config): original_runtimes = original_test_results.usable_runtime_data_by_test_case() optimized_runtimes = optimized_test_results.usable_runtime_data_by_test_case() # Test the functionality - result = add_runtime_comments_to_generated_tests(test_config, generated_tests, original_runtimes, optimized_runtimes) + result = add_runtime_comments_to_generated_tests(qualified_name, test_config, generated_tests, original_runtimes, optimized_runtimes) # Check that minimum times were used (500μs -> 300μs) modified_source = result.generated_tests[0].generated_original_test_source @@ -287,7 +290,7 @@ def test_no_codeflash_output_assignment(self, test_config): result = bubble_sort([3, 1, 2]) assert result == [1, 2, 3] """ - + qualified_name = "bubble_sort" generated_test = GeneratedTests( generated_original_test_source=test_source, instrumented_behavior_test_source="", @@ -309,7 +312,7 @@ def test_no_codeflash_output_assignment(self, test_config): optimized_runtimes = optimized_test_results.usable_runtime_data_by_test_case() # Test the functionality - result = add_runtime_comments_to_generated_tests(test_config, generated_tests, original_runtimes, optimized_runtimes) + result = add_runtime_comments_to_generated_tests(qualified_name, test_config, generated_tests, original_runtimes, optimized_runtimes) # Check that no comments were added (no codeflash_output assignment) modified_source = result.generated_tests[0].generated_original_test_source @@ -329,7 +332,7 @@ def test_invalid_python_code_handling(self, test_config): behavior_file_path=Path("/project/tests/test_module.py"), perf_file_path=Path("/project/tests/test_module_perf.py") ) - + qualified_name = "bubble_sort" generated_tests = GeneratedTestsList(generated_tests=[generated_test]) # Create test results @@ -343,7 +346,7 @@ def test_invalid_python_code_handling(self, test_config): optimized_runtimes = optimized_test_results.usable_runtime_data_by_test_case() # Test the functionality - should handle parse error gracefully - result = add_runtime_comments_to_generated_tests(test_config, generated_tests, original_runtimes, optimized_runtimes) + result = add_runtime_comments_to_generated_tests(qualified_name, test_config, generated_tests, original_runtimes, optimized_runtimes) # Check that original test is preserved when parsing fails modified_source = result.generated_tests[0].generated_original_test_source @@ -352,7 +355,7 @@ def test_invalid_python_code_handling(self, test_config): def test_multiple_generated_tests(self, test_config): """Test handling multiple generated test objects.""" test_source_1 = """def test_bubble_sort(): - codeflash_output = bubble_sort([3, 1, 2]) + codeflash_output = quick_sort([3, 1, 2]) assert codeflash_output == [1, 2, 3] """ @@ -363,7 +366,7 @@ def test_multiple_generated_tests(self, test_config): codeflash_output = quick_sort([5, 2, 8]) assert codeflash_output == [2, 5, 8] """ - + qualified_name = "quick_sort" generated_test_1 = GeneratedTests( generated_original_test_source=test_source_1, instrumented_behavior_test_source="", @@ -396,7 +399,7 @@ def test_multiple_generated_tests(self, test_config): optimized_runtimes = optimized_test_results.usable_runtime_data_by_test_case() # Test the functionality - result = add_runtime_comments_to_generated_tests(test_config, generated_tests, original_runtimes, optimized_runtimes) + result = 
 
         # Check that comments were added to both test files
         modified_source_1 = result.generated_tests[0].generated_original_test_source
@@ -411,7 +414,7 @@ def test_preserved_test_attributes(self, test_config):
     codeflash_output = bubble_sort([3, 1, 2])
     assert codeflash_output == [1, 2, 3]
 """
-
+        qualified_name = "bubble_sort"
         original_behavior_source = "behavior test source"
         original_perf_source = "perf test source"
         original_behavior_path = Path("/project/tests/test_module.py")
@@ -437,7 +440,7 @@ def test_preserved_test_attributes(self, test_config):
         original_runtimes = original_test_results.usable_runtime_data_by_test_case()
         optimized_runtimes = optimized_test_results.usable_runtime_data_by_test_case()
         # Test the functionality
-        result = add_runtime_comments_to_generated_tests(test_config, generated_tests, original_runtimes, optimized_runtimes)
+        result = add_runtime_comments_to_generated_tests(qualified_name, test_config, generated_tests, original_runtimes, optimized_runtimes)
 
         # Check that other attributes are preserved
         modified_test = result.generated_tests[0]
@@ -458,7 +461,7 @@ def test_multistatement_line_handling(self, test_config):
     assert result == [1, 2, 3]
     assert arr == [1, 2, 3]  # Input should be mutated
 """
-
+        qualified_name = "sorter"
         generated_test = GeneratedTests(
             generated_original_test_source=test_source,
             instrumented_behavior_test_source="",
@@ -480,7 +483,7 @@ def test_multistatement_line_handling(self, test_config):
         optimized_runtimes = optimized_test_results.usable_runtime_data_by_test_case()
 
         # Test the functionality
-        result = add_runtime_comments_to_generated_tests(test_config, generated_tests, original_runtimes, optimized_runtimes)
+        result = add_runtime_comments_to_generated_tests(qualified_name, test_config, generated_tests, original_runtimes, optimized_runtimes)
 
         # Check that comments were added to the correct line
         modified_source = result.generated_tests[0].generated_original_test_source
@@ -504,7 +507,7 @@ def test_add_runtime_comments_simple_function(self, test_config):
     codeflash_output = some_function()
     assert codeflash_output == expected
 '''
-
+        qualified_name = "some_function"
         generated_test = GeneratedTests(
             generated_original_test_source=test_source,
             instrumented_behavior_test_source="",
@@ -527,7 +530,7 @@ def test_add_runtime_comments_simple_function(self, test_config):
         optimized_runtimes = {invocation_id: [500000000, 600000000]}  # 0.5s, 0.6s in nanoseconds
 
         result = add_runtime_comments_to_generated_tests(
-            test_config, generated_tests, original_runtimes, optimized_runtimes
+            qualified_name, test_config, generated_tests, original_runtimes, optimized_runtimes
         )
 
         expected_source = '''def test_function():
@@ -545,7 +548,7 @@ def test_function(self):
     codeflash_output = some_function()
     assert codeflash_output == expected
 '''
-
+        qualified_name = "some_function"
         generated_test = GeneratedTests(
             generated_original_test_source=test_source,
             instrumented_behavior_test_source="",
@@ -569,7 +572,7 @@ def test_function(self):
         optimized_runtimes = {invocation_id: [1000000000]}  # 1s in nanoseconds
 
         result = add_runtime_comments_to_generated_tests(
-            test_config, generated_tests, original_runtimes, optimized_runtimes
+            qualified_name, test_config, generated_tests, original_runtimes, optimized_runtimes
         )
 
         expected_source = '''class TestClass:
@@ -589,8 +592,10 @@ def test_add_runtime_comments_multiple_assignments(self, test_config):
     assert codeflash_output == expected
     codeflash_output = another_function()
     assert codeflash_output == expected2
+    codeflash_output = some_function()
+    assert codeflash_output == expected2
 '''
-
+        qualified_name = "some_function"
         generated_test = GeneratedTests(
             generated_original_test_source=test_source,
             instrumented_behavior_test_source="",
@@ -612,22 +617,24 @@ def test_add_runtime_comments_multiple_assignments(self, test_config):
             test_module_path="tests.test_module",
             test_class_name=None,
             test_function_name="test_function",
-            function_getting_tested="another_function",
-            iteration_id="3",
+            function_getting_tested="some_function",
+            iteration_id="5",
         )
 
         original_runtimes = {invocation_id1: [1500000000], invocation_id2: [10]}  # 1.5s in nanoseconds
         optimized_runtimes = {invocation_id1: [750000000], invocation_id2: [5]}  # 0.75s in nanoseconds
 
         result = add_runtime_comments_to_generated_tests(
-            test_config, generated_tests, original_runtimes, optimized_runtimes
+            qualified_name, test_config, generated_tests, original_runtimes, optimized_runtimes
         )
 
         expected_source = '''def test_function():
     setup_data = prepare_test()
     codeflash_output = some_function()  # 1.50s -> 750ms (100% faster)
     assert codeflash_output == expected
-    codeflash_output = another_function()  # 10ns -> 5ns (100% faster)
+    codeflash_output = another_function()
+    assert codeflash_output == expected2
+    codeflash_output = some_function()  # 10ns -> 5ns (100% faster)
     assert codeflash_output == expected2
 '''
 
@@ -640,7 +647,7 @@ def test_add_runtime_comments_no_matching_runtimes(self, test_config):
     codeflash_output = some_function()
     assert codeflash_output == expected
 '''
-
+        qualified_name = "some_function"
         generated_test = GeneratedTests(
             generated_original_test_source=test_source,
             instrumented_behavior_test_source="",
@@ -664,7 +671,7 @@ def test_add_runtime_comments_no_matching_runtimes(self, test_config):
         optimized_runtimes = {invocation_id: [500000000]}
 
         result = add_runtime_comments_to_generated_tests(
-            test_config, generated_tests, original_runtimes, optimized_runtimes
+            qualified_name, test_config, generated_tests, original_runtimes, optimized_runtimes
         )
 
         # Source should remain unchanged
@@ -672,12 +679,16 @@ def test_add_runtime_comments_no_matching_runtimes(self, test_config):
         assert result.generated_tests[0].generated_original_test_source == test_source
 
     def test_add_runtime_comments_no_codeflash_output(self, test_config):
-        """Test that source remains unchanged when there's no codeflash_output assignment."""
+        """Test that comments are still added when there is no codeflash_output assignment."""
         test_source = '''def test_function():
     result = some_function()
     assert result == expected
 '''
-
+        qualified_name = "some_function"
+        expected = '''def test_function():
+    result = some_function()  # 1.00s -> 500ms (100% faster)
+    assert result == expected
+'''
         generated_test = GeneratedTests(
             generated_original_test_source=test_source,
             instrumented_behavior_test_source="",
@@ -700,12 +711,12 @@ def test_add_runtime_comments_no_codeflash_output(self, test_config):
         optimized_runtimes = {invocation_id: [500000000]}
 
         result = add_runtime_comments_to_generated_tests(
-            test_config, generated_tests, original_runtimes, optimized_runtimes
+            qualified_name, test_config, generated_tests, original_runtimes, optimized_runtimes
         )
 
-        # Source should remain unchanged
+        # The call line should now carry a runtime comment even without codeflash_output
         assert len(result.generated_tests) == 1
-        assert result.generated_tests[0].generated_original_test_source == test_source
+        assert result.generated_tests[0].generated_original_test_source == expected
 
     def test_add_runtime_comments_multiple_tests(self, test_config):
         """Test adding runtime comments to multiple generated tests."""
@@ -715,10 +726,10 @@ def test_add_runtime_comments_multiple_tests(self, test_config):
 '''
 
         test_source2 = '''def test_function2():
-    codeflash_output = another_function()
+    codeflash_output = some_function()
     assert codeflash_output == expected
 '''
-
+        qualified_name = "some_function"
         generated_test1 = GeneratedTests(
             generated_original_test_source=test_source1,
             instrumented_behavior_test_source="",
@@ -749,7 +760,7 @@ def test_add_runtime_comments_multiple_tests(self, test_config):
             test_module_path="tests.test_module2",
             test_class_name=None,
             test_function_name="test_function2",
-            function_getting_tested="another_function",
+            function_getting_tested="some_function",  # not used anywhere in this test file
             iteration_id = "0",
         )
 
@@ -763,7 +774,7 @@ def test_add_runtime_comments_multiple_tests(self, test_config):
         }
 
         result = add_runtime_comments_to_generated_tests(
-            test_config, generated_tests, original_runtimes, optimized_runtimes
+            qualified_name, test_config, generated_tests, original_runtimes, optimized_runtimes
         )
 
         expected_source1 = '''def test_function1():
@@ -772,7 +783,7 @@ def test_add_runtime_comments_multiple_tests(self, test_config):
 '''
 
         expected_source2 = '''def test_function2():
-    codeflash_output = another_function()  # 2.00s -> 800ms (150% faster)
+    codeflash_output = some_function()  # 2.00s -> 800ms (150% faster)
     assert codeflash_output == expected
 '''
 
@@ -788,7 +799,7 @@ def test_add_runtime_comments_performance_regression(self, test_config):
     codeflash_output = some_function()
     assert codeflash_output == expected
 '''
-
+        qualified_name = "some_function"
         generated_test = GeneratedTests(
             generated_original_test_source=test_source,
             instrumented_behavior_test_source="",
@@ -819,7 +830,7 @@ def test_add_runtime_comments_performance_regression(self, test_config):
         optimized_runtimes = {invocation_id1: [1500000000], invocation_id2: [1]}  # 1.5s (slower!)
 
         result = add_runtime_comments_to_generated_tests(
-            test_config, generated_tests, original_runtimes, optimized_runtimes
+            qualified_name, test_config, generated_tests, original_runtimes, optimized_runtimes
        )
 
         expected_source = '''def test_function():
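
For reference, a minimal self-contained sketch (not part of the patch) of the detection strategy this diff introduces: instead of matching assignments to codeflash_output, the AST visitor records every call whose callee name matches the function under optimization, whether called as a bare name or as an attribute. The CallFinder class name and the sample source below are illustrative only.

import ast

class CallFinder(ast.NodeVisitor):
    # Hypothetical stand-in for the patched CfoVisitor
    def __init__(self, function_name: str) -> None:
        self.name = function_name
        self.results: list[int] = []  # 0-based line numbers of matching calls

    def visit_Call(self, node: ast.Call) -> None:
        func = node.func
        # Matches both bubble_sort(...) and obj.bubble_sort(...)
        callee = func.id if isinstance(func, ast.Name) else getattr(func, "attr", None)
        if callee == self.name:
            self.results.append(node.lineno - 1)
        self.generic_visit(node)

source = "codeflash_output = bubble_sort([3, 1, 2])\nhelper.bubble_sort(data)\n"
finder = CallFinder("bubble_sort")
finder.visit(ast.parse(source))
print(finder.results)  # [0, 1]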