11from __future__ import annotations
22
33import ast
4- from pathlib import Path
5- from typing import TYPE_CHECKING , Dict , Optional , Set
4+ from typing import TYPE_CHECKING , Optional
65
76import libcst as cst
87import libcst .matchers as m
1110from libcst .helpers import calculate_module_and_package
1211
1312from codeflash .cli_cmds .console import logger
14- from codeflash .models .models import FunctionParent , FunctionSource
13+ from codeflash .models .models import FunctionParent
1514
1615if TYPE_CHECKING :
16+ from pathlib import Path
17+
1718 from libcst .helpers import ModuleNameAndPackage
1819
1920 from codeflash .discovery .functions_to_optimize import FunctionToOptimize
21+ from codeflash .models .models import FunctionSource
2022
21- from typing import List
2223
2324
2425class GlobalAssignmentCollector (cst .CSTVisitor ):
2526 """Collects all global assignment statements."""
2627
27- def __init__ (self ):
28+ def __init__ (self ) -> None :
2829 super ().__init__ ()
29- self .assignments : Dict [str , cst .Assign ] = {}
30- self .assignment_order : List [str ] = []
30+ self .assignments : dict [str , cst .Assign ] = {}
31+ self .assignment_order : list [str ] = []
3132 # Track scope depth to identify global assignments
3233 self .scope_depth = 0
3334 self .if_else_depth = 0
@@ -72,11 +73,11 @@ def visit_Assign(self, node: cst.Assign) -> Optional[bool]:
7273class GlobalAssignmentTransformer (cst .CSTTransformer ):
7374 """Transforms global assignments in the original file with those from the new file."""
7475
75- def __init__ (self , new_assignments : Dict [str , cst .Assign ], new_assignment_order : List [str ]):
76+ def __init__ (self , new_assignments : dict [str , cst .Assign ], new_assignment_order : list [str ]) -> None :
7677 super ().__init__ ()
7778 self .new_assignments = new_assignments
7879 self .new_assignment_order = new_assignment_order
79- self .processed_assignments : Set [str ] = set ()
80+ self .processed_assignments : set [str ] = set ()
8081 self .scope_depth = 0
8182 self .if_else_depth = 0
8283
@@ -145,7 +146,7 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c
145146class GlobalStatementCollector (cst .CSTVisitor ):
146147 """Visitor that collects all global statements (excluding imports and functions/classes)."""
147148
148- def __init__ (self ):
149+ def __init__ (self ) -> None :
149150 super ().__init__ ()
150151 self .global_statements = []
151152 self .in_function_or_class = False
@@ -178,7 +179,7 @@ def visit_SimpleStatementLine(self, node: cst.SimpleStatementLine) -> None:
178179class LastImportFinder (cst .CSTVisitor ):
179180 """Finds the position of the last import statement in the module."""
180181
181- def __init__ (self ):
182+ def __init__ (self ) -> None :
182183 super ().__init__ ()
183184 self .last_import_line = 0
184185 self .current_line = 0
@@ -193,7 +194,7 @@ def visit_SimpleStatementLine(self, node: cst.SimpleStatementLine) -> None:
193194class ImportInserter (cst .CSTTransformer ):
194195 """Transformer that inserts global statements after the last import."""
195196
196- def __init__ (self , global_statements : List [cst .SimpleStatementLine ], last_import_line : int ):
197+ def __init__ (self , global_statements : list [cst .SimpleStatementLine ], last_import_line : int ) -> None :
197198 super ().__init__ ()
198199 self .global_statements = global_statements
199200 self .last_import_line = last_import_line
@@ -208,7 +209,7 @@ def leave_SimpleStatementLine(
208209 # If we're right after the last import and haven't inserted yet
209210 if self .current_line == self .last_import_line and not self .inserted :
210211 self .inserted = True
211- return cst .Module (body = [updated_node ] + self .global_statements )
212+ return cst .Module (body = [updated_node , * self .global_statements ] )
212213
213214 return cst .Module (body = [updated_node ])
214215
@@ -222,7 +223,7 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c
222223 return updated_node
223224
224225
225- def extract_global_statements (source_code : str ) -> List [cst .SimpleStatementLine ]:
226+ def extract_global_statements (source_code : str ) -> list [cst .SimpleStatementLine ]:
226227 """Extract global statements from source code."""
227228 module = cst .parse_module (source_code )
228229 collector = GlobalStatementCollector ()
@@ -285,8 +286,7 @@ def add_global_assignments(src_module_code: str, dst_module_code: str) -> str:
285286 transformer = GlobalAssignmentTransformer (new_collector .assignments , new_collector .assignment_order )
286287 transformed_module = original_module .visit (transformer )
287288
288- dst_module_code = transformed_module .code
289- return dst_module_code
289+ return transformed_module .code
290290
291291
292292def add_needed_imports_from_module (
0 commit comments