Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 33 additions & 18 deletions codeflash/code_utils/edit_generated_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,32 +60,47 @@ def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> ast.AsyncFunctio
return node

def _process_function_def_common(self, node: ast.FunctionDef | ast.AsyncFunctionDef) -> None:
self.context_stack.append(node.name)
i = len(node.body) - 1
test_qualified_name = ".".join(self.context_stack)
key = test_qualified_name + "#" + str(self.abs_path)
context_stack = self.context_stack
context_stack.append(node.name)
test_qualified_name = ".".join(context_stack)
key_base = f"{test_qualified_name}#{self.abs_path}"
results = self.results
original_runtimes = self.original_runtimes
optimized_runtimes = self.optimized_runtimes
get_comment = self.get_comment

# Pre-fetch these for loop, reduces attribute+dict lookup cost
node_body = node.body
i = len(node_body) - 1
while i >= 0:
line_node = node.body[i]
line_node = node_body[i]
if isinstance(line_node, (ast.With, ast.For, ast.While, ast.If)):
j = len(line_node.body) - 1
ln_body = line_node.body
j = len(ln_body) - 1
while j >= 0:
compound_line_node: ast.stmt = line_node.body[j]
compound_line_node: ast.stmt = ln_body[j]
# Collect nodes to check
nodes_to_check = [compound_line_node]
nodes_to_check.extend(getattr(compound_line_node, "body", []))
for internal_node in nodes_to_check:
if isinstance(internal_node, (ast.stmt, ast.Assign)):
inv_id = str(i) + "_" + str(j)
match_key = key + "#" + inv_id
if match_key in self.original_runtimes and match_key in self.optimized_runtimes:
self.results[internal_node.lineno] = self.get_comment(match_key)
extend_body = getattr(compound_line_node, "body", None)
if extend_body:
nodes_to_check.extend(extend_body)
inv_id = f"{i}_{j}"
match_key = f"{key_base}#{inv_id}"
if match_key in original_runtimes and match_key in optimized_runtimes:
# Slightly faster to avoid type checks in loop if possible
for internal_node in nodes_to_check:
# is ast.Assign a subclass of ast.stmt? If yes, only need ast.stmt (Assign inherits stmt).
# But original code checks for both, so preserve as-is.
if isinstance(internal_node, (ast.stmt, ast.Assign)):
results[internal_node.lineno] = get_comment(match_key)
j -= 1
else:
inv_id = str(i)
match_key = key + "#" + inv_id
if match_key in self.original_runtimes and match_key in self.optimized_runtimes:
self.results[line_node.lineno] = self.get_comment(match_key)
match_key = f"{key_base}#{inv_id}"
if match_key in original_runtimes and match_key in optimized_runtimes:
results[line_node.lineno] = get_comment(match_key)
i -= 1
self.context_stack.pop()
context_stack.pop()


def get_fn_call_linenos(
Expand Down
Loading