1+ # ruff: noqa: ARG002
12from __future__ import annotations
23
34import ast
4- from pathlib import Path
5- from typing import TYPE_CHECKING , Dict , Optional , Set
5+ from typing import TYPE_CHECKING , Optional
66
77import libcst as cst
88import libcst .matchers as m
1111from libcst .helpers import calculate_module_and_package
1212
1313from codeflash .cli_cmds .console import logger
14- from codeflash .models .models import FunctionParent , FunctionSource
14+ from codeflash .models .models import FunctionParent
1515
1616if TYPE_CHECKING :
17+ from pathlib import Path
18+
1719 from libcst .helpers import ModuleNameAndPackage
1820
1921 from codeflash .discovery .functions_to_optimize import FunctionToOptimize
20-
21- from typing import List
22+ from codeflash .models .models import FunctionSource
2223
2324
2425class GlobalAssignmentCollector (cst .CSTVisitor ):
2526 """Collects all global assignment statements."""
2627
27- def __init__ (self ):
28+ def __init__ (self ) -> None :
2829 super ().__init__ ()
29- self .assignments : Dict [str , cst .Assign ] = {}
30- self .assignment_order : List [str ] = []
30+ self .assignments : dict [str , cst .Assign ] = {}
31+ self .assignment_order : list [str ] = []
3132 # Track scope depth to identify global assignments
3233 self .scope_depth = 0
3334 self .if_else_depth = 0
@@ -72,11 +73,11 @@ def visit_Assign(self, node: cst.Assign) -> Optional[bool]:
7273class GlobalAssignmentTransformer (cst .CSTTransformer ):
7374 """Transforms global assignments in the original file with those from the new file."""
7475
75- def __init__ (self , new_assignments : Dict [str , cst .Assign ], new_assignment_order : List [str ]):
76+ def __init__ (self , new_assignments : dict [str , cst .Assign ], new_assignment_order : list [str ]) -> None :
7677 super ().__init__ ()
7778 self .new_assignments = new_assignments
7879 self .new_assignment_order = new_assignment_order
79- self .processed_assignments : Set [str ] = set ()
80+ self .processed_assignments : set [str ] = set ()
8081 self .scope_depth = 0
8182 self .if_else_depth = 0
8283
@@ -124,10 +125,11 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c
124125 new_statements = list (updated_node .body )
125126
126127 # Find assignments to append
127- assignments_to_append = []
128- for name in self .new_assignment_order :
129- if name not in self .processed_assignments and name in self .new_assignments :
130- assignments_to_append .append (self .new_assignments [name ])
128+ assignments_to_append = [
129+ self .new_assignments [name ]
130+ for name in self .new_assignment_order
131+ if name not in self .processed_assignments and name in self .new_assignments
132+ ]
131133
132134 if assignments_to_append :
133135 # Add a blank line before appending new assignments if needed
@@ -136,16 +138,20 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c
136138 new_statements .pop () # Remove the Pass statement but keep the empty line
137139
138140 # Add the new assignments
139- for assignment in assignments_to_append :
140- new_statements .append (cst .SimpleStatementLine ([assignment ], leading_lines = [cst .EmptyLine ()]))
141+ new_statements .extend (
142+ [
143+ cst .SimpleStatementLine ([assignment ], leading_lines = [cst .EmptyLine ()])
144+ for assignment in assignments_to_append
145+ ]
146+ )
141147
142148 return updated_node .with_changes (body = new_statements )
143149
144150
145151class GlobalStatementCollector (cst .CSTVisitor ):
146152 """Visitor that collects all global statements (excluding imports and functions/classes)."""
147153
148- def __init__ (self ):
154+ def __init__ (self ) -> None :
149155 super ().__init__ ()
150156 self .global_statements = []
151157 self .in_function_or_class = False
@@ -178,7 +184,7 @@ def visit_SimpleStatementLine(self, node: cst.SimpleStatementLine) -> None:
178184class LastImportFinder (cst .CSTVisitor ):
179185 """Finds the position of the last import statement in the module."""
180186
181- def __init__ (self ):
187+ def __init__ (self ) -> None :
182188 super ().__init__ ()
183189 self .last_import_line = 0
184190 self .current_line = 0
@@ -193,7 +199,7 @@ def visit_SimpleStatementLine(self, node: cst.SimpleStatementLine) -> None:
193199class ImportInserter (cst .CSTTransformer ):
194200 """Transformer that inserts global statements after the last import."""
195201
196- def __init__ (self , global_statements : List [cst .SimpleStatementLine ], last_import_line : int ):
202+ def __init__ (self , global_statements : list [cst .SimpleStatementLine ], last_import_line : int ) -> None :
197203 super ().__init__ ()
198204 self .global_statements = global_statements
199205 self .last_import_line = last_import_line
@@ -208,7 +214,7 @@ def leave_SimpleStatementLine(
208214 # If we're right after the last import and haven't inserted yet
209215 if self .current_line == self .last_import_line and not self .inserted :
210216 self .inserted = True
211- return cst .Module (body = [updated_node ] + self .global_statements )
217+ return cst .Module (body = [updated_node , * self .global_statements ] )
212218
213219 return cst .Module (body = [updated_node ])
214220
@@ -222,7 +228,7 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c
222228 return updated_node
223229
224230
225- def extract_global_statements (source_code : str ) -> List [cst .SimpleStatementLine ]:
231+ def extract_global_statements (source_code : str ) -> list [cst .SimpleStatementLine ]:
226232 """Extract global statements from source code."""
227233 module = cst .parse_module (source_code )
228234 collector = GlobalStatementCollector ()
@@ -285,8 +291,7 @@ def add_global_assignments(src_module_code: str, dst_module_code: str) -> str:
285291 transformer = GlobalAssignmentTransformer (new_collector .assignments , new_collector .assignment_order )
286292 transformed_module = original_module .visit (transformer )
287293
288- dst_module_code = transformed_module .code
289- return dst_module_code
294+ return transformed_module .code
290295
291296
292297def add_needed_imports_from_module (
@@ -357,9 +362,10 @@ def add_needed_imports_from_module(
357362
358363
359364def get_code (functions_to_optimize : list [FunctionToOptimize ]) -> tuple [str | None , set [tuple [str , str ]]]:
360- """Return the code for a function or methods in a Python module. functions_to_optimize is either a singleton
361- FunctionToOptimize instance, which represents either a function at the module level or a method of a class at the
362- module level, or it represents a list of methods of the same class.
365+ """Return the code for a function or methods in a Python module.
366+
367+ functions_to_optimize is either a singleton FunctionToOptimize instance, which represents either a function at the
368+ module level or a method of a class at the module level, or it represents a list of methods of the same class.
363369 """
364370 if (
365371 not functions_to_optimize
@@ -427,7 +433,7 @@ def find_target(node_list: list[ast.stmt], name_parts: tuple[str, str] | tuple[s
427433
428434 return find_target (target .body , name_parts [1 :])
429435
430- with open (file_path , encoding = "utf8" ) as file :
436+ with file_path . open (encoding = "utf8" ) as file :
431437 source_code : str = file .read ()
432438 try :
433439 module_node : ast .Module = ast .parse (source_code )
0 commit comments