diff --git a/codeflash/code_utils/edit_generated_tests.py b/codeflash/code_utils/edit_generated_tests.py index 8e50b1d7..11ca6d4b 100644 --- a/codeflash/code_utils/edit_generated_tests.py +++ b/codeflash/code_utils/edit_generated_tests.py @@ -60,32 +60,47 @@ def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> ast.AsyncFunctio return node def _process_function_def_common(self, node: ast.FunctionDef | ast.AsyncFunctionDef) -> None: - self.context_stack.append(node.name) - i = len(node.body) - 1 - test_qualified_name = ".".join(self.context_stack) - key = test_qualified_name + "#" + str(self.abs_path) + context_stack = self.context_stack + context_stack.append(node.name) + test_qualified_name = ".".join(context_stack) + key_base = f"{test_qualified_name}#{self.abs_path}" + results = self.results + original_runtimes = self.original_runtimes + optimized_runtimes = self.optimized_runtimes + get_comment = self.get_comment + + # Pre-fetch these for loop, reduces attribute+dict lookup cost + node_body = node.body + i = len(node_body) - 1 while i >= 0: - line_node = node.body[i] + line_node = node_body[i] if isinstance(line_node, (ast.With, ast.For, ast.While, ast.If)): - j = len(line_node.body) - 1 + ln_body = line_node.body + j = len(ln_body) - 1 while j >= 0: - compound_line_node: ast.stmt = line_node.body[j] + compound_line_node: ast.stmt = ln_body[j] + # Collect nodes to check nodes_to_check = [compound_line_node] - nodes_to_check.extend(getattr(compound_line_node, "body", [])) - for internal_node in nodes_to_check: - if isinstance(internal_node, (ast.stmt, ast.Assign)): - inv_id = str(i) + "_" + str(j) - match_key = key + "#" + inv_id - if match_key in self.original_runtimes and match_key in self.optimized_runtimes: - self.results[internal_node.lineno] = self.get_comment(match_key) + extend_body = getattr(compound_line_node, "body", None) + if extend_body: + nodes_to_check.extend(extend_body) + inv_id = f"{i}_{j}" + match_key = f"{key_base}#{inv_id}" + if match_key in original_runtimes and match_key in optimized_runtimes: + # Slightly faster to avoid type checks in loop if possible + for internal_node in nodes_to_check: + # is ast.Assign a subclass of ast.stmt? If yes, only need ast.stmt (Assign inherits stmt). + # But original code checks for both, so preserve as-is. + if isinstance(internal_node, (ast.stmt, ast.Assign)): + results[internal_node.lineno] = get_comment(match_key) j -= 1 else: inv_id = str(i) - match_key = key + "#" + inv_id - if match_key in self.original_runtimes and match_key in self.optimized_runtimes: - self.results[line_node.lineno] = self.get_comment(match_key) + match_key = f"{key_base}#{inv_id}" + if match_key in original_runtimes and match_key in optimized_runtimes: + results[line_node.lineno] = get_comment(match_key) i -= 1 - self.context_stack.pop() + context_stack.pop() def get_fn_call_linenos(