Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 28 additions & 18 deletions codeflash/code_utils/edit_generated_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,32 +60,42 @@ def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> ast.AsyncFunctio
return node

def _process_function_def_common(self, node: ast.FunctionDef | ast.AsyncFunctionDef) -> None:
self.context_stack.append(node.name)
i = len(node.body) - 1
test_qualified_name = ".".join(self.context_stack)
key = test_qualified_name + "#" + str(self.abs_path)
# Optimize repeated attribute lookups and joins
context_stack = self.context_stack
key_base = ".".join(context_stack) + "#" + str(self.abs_path)
context_stack.append(node.name)
body = node.body
orig_runtimes = self.original_runtimes
opt_runtimes = self.optimized_runtimes
get_comment = self.get_comment
Comment on lines +64 to +70
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

still don't like these


i = len(body) - 1

while i >= 0:
line_node = node.body[i]
line_node = body[i]
if isinstance(line_node, (ast.With, ast.For, ast.While, ast.If)):
j = len(line_node.body) - 1
while j >= 0:
compound_line_node: ast.stmt = line_node.body[j]
compound_body = line_node.body
compound_body_len = len(compound_body)
for j in range(compound_body_len - 1, -1, -1):
compound_line_node: ast.stmt = compound_body[j]
# Flatten nodes_to_check computation & avoid repeated getattr for body
nodes_to_check = [compound_line_node]
nodes_to_check.extend(getattr(compound_line_node, "body", []))
child_body = getattr(compound_line_node, "body", None)
if child_body:
nodes_to_check += child_body
for internal_node in nodes_to_check:
if isinstance(internal_node, (ast.stmt, ast.Assign)):
inv_id = str(i) + "_" + str(j)
match_key = key + "#" + inv_id
if match_key in self.original_runtimes and match_key in self.optimized_runtimes:
self.results[internal_node.lineno] = self.get_comment(match_key)
j -= 1
inv_id = f"{i}_{j}"
match_key = f"{key_base}#{inv_id}"
if match_key in orig_runtimes and match_key in opt_runtimes:
self.results[internal_node.lineno] = get_comment(match_key)
else:
inv_id = str(i)
match_key = key + "#" + inv_id
if match_key in self.original_runtimes and match_key in self.optimized_runtimes:
self.results[line_node.lineno] = self.get_comment(match_key)
match_key = f"{key_base}#{inv_id}"
if match_key in orig_runtimes and match_key in opt_runtimes:
self.results[line_node.lineno] = get_comment(match_key)
i -= 1
self.context_stack.pop()
context_stack.pop()


def get_fn_call_linenos(
Expand Down
Loading