diff --git a/codeflash/code_utils/edit_generated_tests.py b/codeflash/code_utils/edit_generated_tests.py index 8e50b1d71..cf7317e75 100644 --- a/codeflash/code_utils/edit_generated_tests.py +++ b/codeflash/code_utils/edit_generated_tests.py @@ -60,32 +60,42 @@ def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> ast.AsyncFunctio return node def _process_function_def_common(self, node: ast.FunctionDef | ast.AsyncFunctionDef) -> None: - self.context_stack.append(node.name) - i = len(node.body) - 1 - test_qualified_name = ".".join(self.context_stack) - key = test_qualified_name + "#" + str(self.abs_path) + # Optimize repeated attribute lookups and joins + context_stack = self.context_stack + key_base = ".".join(context_stack) + "#" + str(self.abs_path) + context_stack.append(node.name) + body = node.body + orig_runtimes = self.original_runtimes + opt_runtimes = self.optimized_runtimes + get_comment = self.get_comment + + i = len(body) - 1 + while i >= 0: - line_node = node.body[i] + line_node = body[i] if isinstance(line_node, (ast.With, ast.For, ast.While, ast.If)): - j = len(line_node.body) - 1 - while j >= 0: - compound_line_node: ast.stmt = line_node.body[j] + compound_body = line_node.body + compound_body_len = len(compound_body) + for j in range(compound_body_len - 1, -1, -1): + compound_line_node: ast.stmt = compound_body[j] + # Flatten nodes_to_check computation & avoid repeated getattr for body nodes_to_check = [compound_line_node] - nodes_to_check.extend(getattr(compound_line_node, "body", [])) + child_body = getattr(compound_line_node, "body", None) + if child_body: + nodes_to_check += child_body for internal_node in nodes_to_check: if isinstance(internal_node, (ast.stmt, ast.Assign)): - inv_id = str(i) + "_" + str(j) - match_key = key + "#" + inv_id - if match_key in self.original_runtimes and match_key in self.optimized_runtimes: - self.results[internal_node.lineno] = self.get_comment(match_key) - j -= 1 + inv_id = f"{i}_{j}" + match_key = f"{key_base}#{inv_id}" + if match_key in orig_runtimes and match_key in opt_runtimes: + self.results[internal_node.lineno] = get_comment(match_key) else: inv_id = str(i) - match_key = key + "#" + inv_id - if match_key in self.original_runtimes and match_key in self.optimized_runtimes: - self.results[line_node.lineno] = self.get_comment(match_key) + match_key = f"{key_base}#{inv_id}" + if match_key in orig_runtimes and match_key in opt_runtimes: + self.results[line_node.lineno] = get_comment(match_key) i -= 1 - self.context_stack.pop() + context_stack.pop() def get_fn_call_linenos(