Skip to content

Commit 1641147

Browse files
committed
Update code_extractor.py
1 parent 3ef3320 commit 1641147

File tree

1 file changed

+13
-7
lines changed

1 file changed

+13
-7
lines changed

codeflash/code_utils/code_extractor.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
# ruff: noqa: ARG002
12
from __future__ import annotations
23

34
import ast
@@ -124,10 +125,11 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c
124125
new_statements = list(updated_node.body)
125126

126127
# Find assignments to append
127-
assignments_to_append = []
128-
for name in self.new_assignment_order:
129-
if name not in self.processed_assignments and name in self.new_assignments:
130-
assignments_to_append.append(self.new_assignments[name])
128+
assignments_to_append = [
129+
self.new_assignments[name]
130+
for name in self.new_assignment_order
131+
if name not in self.processed_assignments and name in self.new_assignments
132+
]
131133

132134
if assignments_to_append:
133135
# Add a blank line before appending new assignments if needed
@@ -136,8 +138,12 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c
136138
new_statements.pop() # Remove the Pass statement but keep the empty line
137139

138140
# Add the new assignments
139-
for assignment in assignments_to_append:
140-
new_statements.append(cst.SimpleStatementLine([assignment], leading_lines=[cst.EmptyLine()]))
141+
new_statements.extend(
142+
[
143+
cst.SimpleStatementLine([assignment], leading_lines=[cst.EmptyLine()])
144+
for assignment in assignments_to_append
145+
]
146+
)
141147

142148
return updated_node.with_changes(body=new_statements)
143149

@@ -426,7 +432,7 @@ def find_target(node_list: list[ast.stmt], name_parts: tuple[str, str] | tuple[s
426432

427433
return find_target(target.body, name_parts[1:])
428434

429-
with open(file_path, encoding="utf8") as file:
435+
with file_path.open(encoding="utf8") as file:
430436
source_code: str = file.read()
431437
try:
432438
module_node: ast.Module = ast.parse(source_code)

0 commit comments

Comments
 (0)