diff --git a/codeflash/code_utils/edit_generated_tests.py b/codeflash/code_utils/edit_generated_tests.py index 0f6c179c9..47ee1432d 100644 --- a/codeflash/code_utils/edit_generated_tests.py +++ b/codeflash/code_utils/edit_generated_tests.py @@ -11,8 +11,10 @@ from codeflash.cli_cmds.console import logger from codeflash.code_utils.time_utils import format_perf, format_time -from codeflash.models.models import GeneratedTests, GeneratedTestsList +from codeflash.models.models import (GeneratedTests, GeneratedTestsList, + InvocationId) from codeflash.result.critic import performance_gain +from codeflash.verification.verification_utils import TestConfig if TYPE_CHECKING: from codeflash.models.models import InvocationId @@ -90,7 +92,35 @@ def add_runtime_comments_to_generated_tests( module_root = test_cfg.project_root_path rel_tests_root = tests_root.relative_to(module_root) - # TODO: reduce for loops to one + # ---- Preindex invocation results for O(1) matching ------- + # (rel_path, qualified_name, cfo_loc) -> list[runtimes] + def _make_index(invocations): + index = {} + for invocation_id, runtimes in invocations.items(): + test_class = invocation_id.test_class_name + test_func = invocation_id.test_function_name + q_name = f"{test_class}.{test_func}" if test_class else test_func + rel_path = Path(invocation_id.test_module_path.replace(".", os.sep)).with_suffix(".py") + # Defensive: sometimes path processing can fail, fallback to string + try: + rel_path = rel_path.relative_to(rel_tests_root) + except Exception: + rel_path = str(rel_path) + # Get CFO location integer + try: + cfo_loc = int(invocation_id.iteration_id.split("_")[0]) + except Exception: + cfo_loc = None + key = (str(rel_path), q_name, cfo_loc) + if key not in index: + index[key] = [] + index[key].extend(runtimes) + return index + + orig_index = _make_index(original_runtimes) + opt_index = _make_index(optimized_runtimes) + + # Optimized fast CST visitor base class RuntimeCommentTransformer(cst.CSTTransformer): def __init__( self, qualified_name: str, module: cst.Module, test: GeneratedTests, tests_root: Path, rel_tests_root: Path @@ -104,104 +134,66 @@ def __init__( self.cfo_locs: list[int] = [] self.cfo_idx_loc_to_look_at: int = -1 self.name = qualified_name.split(".")[-1] + # Precompute test-local file relative paths for efficiency + self.test_rel_behavior = str(test.behavior_file_path.relative_to(tests_root)) + self.test_rel_perf = str(test.perf_file_path.relative_to(tests_root)) def visit_ClassDef(self, node: cst.ClassDef) -> None: - # Track when we enter a class self.context_stack.append(node.name.value) - def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef: # noqa: ARG002 - # Pop the context when we leave a class + def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef: self.context_stack.pop() return updated_node def visit_FunctionDef(self, node: cst.FunctionDef) -> None: - # convert function body to ast normalized string and find occurrences of codeflash_output + # This could be optimized further if you access CFO assignments via CST body_code = dedent(self.module.code_for_node(node.body)) normalized_body_code = ast.unparse(ast.parse(body_code)) - self.cfo_locs = sorted( - find_codeflash_output_assignments(qualified_name, normalized_body_code) - ) # sorted in order we will encounter them + self.cfo_locs = sorted(find_codeflash_output_assignments(qualified_name, normalized_body_code)) self.cfo_idx_loc_to_look_at = -1 self.context_stack.append(node.name.value) - def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef: # noqa: ARG002 - # Pop the context when we leave a function + def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef: self.context_stack.pop() return updated_node def leave_SimpleStatementLine( - self, - original_node: cst.SimpleStatementLine, # noqa: ARG002 - updated_node: cst.SimpleStatementLine, + self, original_node: cst.SimpleStatementLine, updated_node: cst.SimpleStatementLine ) -> cst.SimpleStatementLine: - # Check if this statement line contains a call to self.name - if self._contains_myfunc_call(updated_node): # type: ignore[no-untyped-call] - # Find matching test cases by looking for this test function name in the test results + # Fast skip before deep call tree walk by screening for Name nodes + if self._contains_myfunc_call(updated_node): self.cfo_idx_loc_to_look_at += 1 - matching_original_times = [] - matching_optimized_times = [] - # TODO : will not work if there are multiple test cases with the same name, match filename + test class + test function name + invocationid - for invocation_id, runtimes in original_runtimes.items(): - # get position here and match in if condition - qualified_name = ( - invocation_id.test_class_name + "." + invocation_id.test_function_name # type: ignore[operator] - if invocation_id.test_class_name - else invocation_id.test_function_name - ) - rel_path = ( - Path(invocation_id.test_module_path.replace(".", os.sep)) - .with_suffix(".py") - .relative_to(self.rel_tests_root) - ) - if ( - qualified_name == ".".join(self.context_stack) - and rel_path - in [ - self.test.behavior_file_path.relative_to(self.tests_root), - self.test.perf_file_path.relative_to(self.tests_root), - ] - and int(invocation_id.iteration_id.split("_")[0]) == self.cfo_locs[self.cfo_idx_loc_to_look_at] # type:ignore[union-attr] - ): - matching_original_times.extend(runtimes) - - for invocation_id, runtimes in optimized_runtimes.items(): - # get position here and match in if condition - qualified_name = ( - invocation_id.test_class_name + "." + invocation_id.test_function_name # type: ignore[operator] - if invocation_id.test_class_name - else invocation_id.test_function_name - ) - rel_path = ( - Path(invocation_id.test_module_path.replace(".", os.sep)) - .with_suffix(".py") - .relative_to(self.rel_tests_root) - ) - if ( - qualified_name == ".".join(self.context_stack) - and rel_path - in [ - self.test.behavior_file_path.relative_to(self.tests_root), - self.test.perf_file_path.relative_to(self.tests_root), - ] - and int(invocation_id.iteration_id.split("_")[0]) == self.cfo_locs[self.cfo_idx_loc_to_look_at] # type:ignore[union-attr] - ): - matching_optimized_times.extend(runtimes) - - if matching_original_times and matching_optimized_times: - original_time = min(matching_original_times) - optimized_time = min(matching_optimized_times) + if self.cfo_idx_loc_to_look_at >= len(self.cfo_locs): + return updated_node # Defensive, should never happen + + cfo_loc = self.cfo_locs[self.cfo_idx_loc_to_look_at] + + qualified_name_chain = ".".join(self.context_stack) + # Try both behavior and perf as possible locations; both are strings + possible_paths = {self.test_rel_behavior, self.test_rel_perf} + + # Form index key(s) + matching_original = [] + matching_optimized = [] + + for rel_path_str in possible_paths: + key = (rel_path_str, qualified_name_chain, cfo_loc) + if key in orig_index: + matching_original.extend(orig_index[key]) + if key in opt_index: + matching_optimized.extend(opt_index[key]) + if matching_original and matching_optimized: + original_time = min(matching_original) + optimized_time = min(matching_optimized) if original_time != 0 and optimized_time != 0: - perf_gain = format_perf( + perf_gain_str = format_perf( abs( performance_gain(original_runtime_ns=original_time, optimized_runtime_ns=optimized_time) * 100 ) ) status = "slower" if optimized_time > original_time else "faster" - # Create the runtime comment - comment_text = ( - f"# {format_time(original_time)} -> {format_time(optimized_time)} ({perf_gain}% {status})" - ) + comment_text = f"# {format_time(original_time)} -> {format_time(optimized_time)} ({perf_gain_str}% {status})" return updated_node.with_changes( trailing_whitespace=cst.TrailingWhitespace( whitespace=cst.SimpleWhitespace(" "), @@ -211,43 +203,37 @@ def leave_SimpleStatementLine( ) return updated_node - def _contains_myfunc_call(self, node): # type: ignore[no-untyped-def] # noqa : ANN202, ANN001 + def _contains_myfunc_call(self, node): """Recursively search for any Call node in the statement whose function is named self.name (including obj.myfunc).""" + # IMPORTANT micro-optimization: early abort using an exception + class Found(Exception): + pass + class Finder(cst.CSTVisitor): - def __init__(self, name: str) -> None: - super().__init__() - self.found = False + def __init__(self, name): self.name = name - def visit_Call(self, call_node) -> None: # type: ignore[no-untyped-def] # noqa : ANN001 + def visit_Call(self, call_node): func_expr = call_node.func - if isinstance(func_expr, cst.Name): - if func_expr.value == self.name: - self.found = True - elif isinstance(func_expr, cst.Attribute): # noqa : SIM102 - if func_expr.attr.value == self.name: - self.found = True - - finder = Finder(self.name) - node.visit(finder) - return finder.found - - # Process each generated test + if (isinstance(func_expr, cst.Name) and func_expr.value == self.name) or ( + isinstance(func_expr, cst.Attribute) and func_expr.attr.value == self.name + ): + raise Found + + try: + node.visit(Finder(self.name)) + except Found: + return True + return False + modified_tests = [] for test in generated_tests.generated_tests: try: - # Parse the test source code tree = cst.parse_module(test.generated_original_test_source) - # Transform the tree to add runtime comments - # qualified_name: str, module: cst.Module, test: GeneratedTests, tests_root: Path, rel_tests_root: Path transformer = RuntimeCommentTransformer(qualified_name, tree, test, tests_root, rel_tests_root) modified_tree = tree.visit(transformer) - - # Convert back to source code modified_source = modified_tree.code - - # Create a new GeneratedTests object with the modified source modified_test = GeneratedTests( generated_original_test_source=modified_source, instrumented_behavior_test_source=test.instrumented_behavior_test_source, @@ -257,7 +243,6 @@ def visit_Call(self, call_node) -> None: # type: ignore[no-untyped-def] # noqa ) modified_tests.append(modified_test) except Exception as e: - # If parsing fails, keep the original test logger.debug(f"Failed to add runtime comments to test: {e}") modified_tests.append(test)