diff --git a/codeflash/code_utils/edit_generated_tests.py b/codeflash/code_utils/edit_generated_tests.py index 8e50b1d71..3ba824803 100644 --- a/codeflash/code_utils/edit_generated_tests.py +++ b/codeflash/code_utils/edit_generated_tests.py @@ -61,29 +61,52 @@ def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> ast.AsyncFunctio def _process_function_def_common(self, node: ast.FunctionDef | ast.AsyncFunctionDef) -> None: self.context_stack.append(node.name) - i = len(node.body) - 1 + node_body = node.body + node_body_len = len(node_body) test_qualified_name = ".".join(self.context_stack) - key = test_qualified_name + "#" + str(self.abs_path) + + # Precompute key prefix + key_prefix = f"{test_qualified_name}#{self.abs_path}" + + i = node_body_len - 1 + orig_rt = self.original_runtimes + opt_rt = self.optimized_runtimes + get_comment = self.get_comment + + # Hoist isinstance tuple constants out of loop + compound_types = (ast.With, ast.For, ast.While, ast.If) + valid_types = (ast.stmt, ast.Assign) + while i >= 0: - line_node = node.body[i] - if isinstance(line_node, (ast.With, ast.For, ast.While, ast.If)): - j = len(line_node.body) - 1 + line_node = node_body[i] + if isinstance(line_node, compound_types): + line_node_body = line_node.body + j = len(line_node_body) - 1 + nodes_to_check_append = nodes_to_check_extend = None # Avoids local lookups while j >= 0: - compound_line_node: ast.stmt = line_node.body[j] - nodes_to_check = [compound_line_node] - nodes_to_check.extend(getattr(compound_line_node, "body", [])) - for internal_node in nodes_to_check: - if isinstance(internal_node, (ast.stmt, ast.Assign)): - inv_id = str(i) + "_" + str(j) - match_key = key + "#" + inv_id - if match_key in self.original_runtimes and match_key in self.optimized_runtimes: - self.results[internal_node.lineno] = self.get_comment(match_key) + compound_line_node: ast.stmt = line_node_body[j] + # Pre-extend only if there's a .body attribute, avoid repeated getattr cost + compound_line_node_body = getattr(compound_line_node, "body", None) + if compound_line_node_body: + nodes_to_check = [compound_line_node, *compound_line_node_body] + else: + nodes_to_check = [compound_line_node] + + inv_id = f"{i}_{j}" + match_key = f"{key_prefix}#{inv_id}" + + if match_key in orig_rt and match_key in opt_rt: + comment = get_comment(match_key) + # Avoid repeated isinstance - enumerate only actual stmt/Assign + for internal_node in nodes_to_check: + if isinstance(internal_node, valid_types): + self.results[internal_node.lineno] = comment j -= 1 else: inv_id = str(i) - match_key = key + "#" + inv_id - if match_key in self.original_runtimes and match_key in self.optimized_runtimes: - self.results[line_node.lineno] = self.get_comment(match_key) + match_key = f"{key_prefix}#{inv_id}" + if match_key in orig_rt and match_key in opt_rt: + self.results[line_node.lineno] = get_comment(match_key) i -= 1 self.context_stack.pop()