22
33import ast
44from pathlib import Path
5- from typing import TYPE_CHECKING
5+ from typing import TYPE_CHECKING , Dict , Optional , Set
66
77import libcst as cst
88import libcst .matchers as m
1818
1919 from codeflash .discovery .functions_to_optimize import FunctionToOptimize
2020
21+ from typing import List , Union
22+
class GlobalAssignmentCollector(cst.CSTVisitor):
    """Record plain ``name = value`` assignments found at module level.

    An assignment is recorded only when it is not nested inside a function,
    a class, or an ``if`` statement (the ``else`` branch is part of the ``if``
    node, so it is excluded as well).  No other compound statement changes
    the nesting counters.

    Attributes:
        assignments: maps each target name to the last ``cst.Assign`` seen for it.
        assignment_order: target names in order of first appearance.
    """

    def __init__(self):
        super().__init__()
        # Nesting counters; an assignment counts as global only while both are 0.
        self.scope_depth = 0
        self.if_else_depth = 0
        self.assignment_order: List[str] = []
        self.assignments: Dict[str, cst.Assign] = {}

    def visit_FunctionDef(self, node: cst.FunctionDef) -> Optional[bool]:
        self.scope_depth += 1
        return True

    def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None:
        self.scope_depth -= 1

    def visit_ClassDef(self, node: cst.ClassDef) -> Optional[bool]:
        self.scope_depth += 1
        return True

    def leave_ClassDef(self, original_node: cst.ClassDef) -> None:
        self.scope_depth -= 1

    def visit_If(self, node: cst.If) -> Optional[bool]:
        self.if_else_depth += 1
        return True

    def leave_If(self, original_node: cst.If) -> None:
        self.if_else_depth -= 1

    def visit_Else(self, node: cst.Else) -> Optional[bool]:
        # No counter change here: the depth is adjusted in visit_If/leave_If.
        return True

    def visit_Assign(self, node: cst.Assign) -> Optional[bool]:
        """Store ``node`` under every simple-name target when at module level."""
        nested = self.scope_depth != 0 or self.if_else_depth != 0
        if nested:
            return True

        # Only bare names are tracked; attribute, subscript and tuple targets are skipped.
        simple_names = [t.target.value for t in node.targets if isinstance(t.target, cst.Name)]
        for identifier in simple_names:
            self.assignments[identifier] = node
            if identifier not in self.assignment_order:
                self.assignment_order.append(identifier)
        return True
69+
70+
class GlobalAssignmentTransformer(cst.CSTTransformer):
    """Replace module-level assignments with the versions supplied by another module.

    While walking the tree, every assignment that is not nested in a function,
    class, or ``if`` statement and whose target is a bare name present in
    ``new_assignments`` is swapped for the supplied ``cst.Assign`` node.
    Supplied assignments that never matched anything are appended to the end
    of the module, in ``new_assignment_order``.

    Args:
        new_assignments: target name -> replacement ``cst.Assign`` node.
        new_assignment_order: target names in the order they should be appended
            when they are missing from the transformed module.
    """

    def __init__(self, new_assignments: Dict[str, cst.Assign], new_assignment_order: List[str]):
        super().__init__()
        self.new_assignments = new_assignments
        self.new_assignment_order = new_assignment_order
        # Names that were already replaced in place and must not be appended again.
        self.processed_assignments: Set[str] = set()
        # Nesting counters; an assignment counts as global only while both are 0.
        self.scope_depth = 0
        self.if_else_depth = 0

    def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
        self.scope_depth += 1

    def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:
        self.scope_depth -= 1
        return updated_node

    def visit_ClassDef(self, node: cst.ClassDef) -> None:
        self.scope_depth += 1

    def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef:
        self.scope_depth -= 1
        return updated_node

    def visit_If(self, node: cst.If) -> None:
        self.if_else_depth += 1

    def leave_If(self, original_node: cst.If, updated_node: cst.If) -> cst.If:
        self.if_else_depth -= 1
        return updated_node

    def visit_Else(self, node: cst.Else) -> None:
        # No counter change here: the depth is adjusted in visit_If/leave_If.
        pass

    def leave_Assign(self, original_node: cst.Assign, updated_node: cst.Assign) -> cst.CSTNode:
        """Return the replacement for a module-level assignment, if one was supplied."""
        if self.scope_depth > 0 or self.if_else_depth > 0:
            return updated_node

        # The first bare-name target with a replacement decides the outcome.
        for target in original_node.targets:
            if isinstance(target.target, cst.Name):
                name = target.target.value
                if name in self.new_assignments:
                    self.processed_assignments.add(name)
                    return self.new_assignments[name]

        return updated_node

    def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
        """Append supplied assignments that did not replace anything in the module."""
        new_statements = list(updated_node.body)

        # Each appended assignment gets its own statement line preceded by one blank line.
        # (The previous append-then-pop of a placeholder ``pass`` line had no effect
        # and has been dropped.)
        for name in self.new_assignment_order:
            if name not in self.processed_assignments and name in self.new_assignments:
                new_statements.append(
                    cst.SimpleStatementLine(
                        [self.new_assignments[name]],
                        leading_lines=[cst.EmptyLine()],
                    )
                )

        return updated_node.with_changes(body=new_statements)
147+
class GlobalStatementCollector(cst.CSTVisitor):
    """Collect simple statement lines that live outside functions and classes.

    A ``cst.SimpleStatementLine`` is kept when at least one statement on it is
    something other than an ``import``, a ``from ... import``, or a plain
    assignment.  Function and class bodies are never descended into; other
    compound statements are.  Results accumulate in ``global_statements``.
    """

    def __init__(self):
        super().__init__()
        self.in_function_or_class = False
        self.global_statements = []

    def visit_ClassDef(self, node: cst.ClassDef) -> bool:
        self.in_function_or_class = True
        # Returning False skips the class body.
        return False

    def leave_ClassDef(self, original_node: cst.ClassDef) -> None:
        self.in_function_or_class = False

    def visit_FunctionDef(self, node: cst.FunctionDef) -> bool:
        self.in_function_or_class = True
        # Returning False skips the function body.
        return False

    def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None:
        self.in_function_or_class = False

    def visit_SimpleStatementLine(self, node: cst.SimpleStatementLine) -> None:
        if self.in_function_or_class:
            return

        # Imports and plain assignments are excluded from collection.
        skipped_kinds = (cst.Import, cst.ImportFrom, cst.Assign)
        if any(not isinstance(stmt, skipped_kinds) for stmt in node.body):
            self.global_statements.append(node)
179+
180+
class LastImportFinder(cst.CSTVisitor):
    """Locate the last simple statement line that contains an import.

    "Line" here is a 1-based ordinal over the ``cst.SimpleStatementLine`` nodes
    in visit order, not a source line number.  ``last_import_line`` stays 0
    when no import statement is encountered.
    """

    def __init__(self):
        super().__init__()
        self.current_line = 0
        self.last_import_line = 0

    def visit_SimpleStatementLine(self, node: cst.SimpleStatementLine) -> None:
        self.current_line += 1
        holds_import = any(isinstance(stmt, (cst.Import, cst.ImportFrom)) for stmt in node.body)
        if holds_import:
            self.last_import_line = self.current_line
194+
195+
class ImportInserter(cst.CSTTransformer):
    """Insert a list of statements right after the last import of a module.

    Args:
        global_statements: statement lines to insert.
        last_import_line: 1-based ordinal of the ``cst.SimpleStatementLine``
            after which the statements go, counted over every simple statement
            line in traversal order.  0 means "no import found"; the statements
            are then placed at the very top of the module.
    """

    def __init__(self, global_statements: List[cst.SimpleStatementLine], last_import_line: int):
        super().__init__()
        self.global_statements = global_statements
        self.last_import_line = last_import_line
        self.current_line = 0
        self.inserted = False

    def leave_SimpleStatementLine(
        self, original_node: cst.SimpleStatementLine, updated_node: cst.SimpleStatementLine
    ) -> Union[cst.SimpleStatementLine, cst.FlattenSentinel]:
        """Return the line itself, or the line followed by the inserted statements."""
        self.current_line += 1

        # Insert exactly once, directly after the targeted statement line.
        if self.current_line == self.last_import_line and not self.inserted:
            self.inserted = True
            # FlattenSentinel splices the nodes into the enclosing body instead of
            # nesting a Module node where a statement is expected.
            return cst.FlattenSentinel([updated_node, *self.global_statements])

        return updated_node

    def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
        """Prepend the statements to the module when there was no import to anchor on."""
        if self.last_import_line == 0 and not self.inserted:
            return updated_node.with_changes(body=[*self.global_statements, *updated_node.body])
        return updated_node
226+
227+
def extract_global_statements(source_code: str) -> List[cst.SimpleStatementLine]:
    """Parse ``source_code`` and return the lines gathered by ``GlobalStatementCollector``."""
    statement_collector = GlobalStatementCollector()
    cst.parse_module(source_code).visit(statement_collector)
    return statement_collector.global_statements
234+
235+
def find_last_import_line(target_code: str) -> int:
    """Parse ``target_code`` and return the ordinal reported by ``LastImportFinder``.

    The value is the finder's ``last_import_line`` (0 when no import was seen).
    """
    import_finder = LastImportFinder()
    cst.parse_module(target_code).visit(import_finder)
    return import_finder.last_import_line
21242
22243class FutureAliasedImportTransformer (cst .CSTTransformer ):
23244 def leave_ImportFrom (
@@ -38,6 +259,38 @@ def delete___future___aliased_imports(module_code: str) -> str:
38259 return cst .parse_module (module_code ).visit (FutureAliasedImportTransformer ()).code
39260
40261
def add_global_assignments(src_module_code: str, dst_module_code: str) -> str:
    """Carry global statements and global assignments from one module's source into another's.

    Args:
        src_module_code: source text providing the statements and assignments.
        dst_module_code: source text to be updated.

    Returns:
        The updated destination source text.
    """
    # Pass 1: insert the source's non-import, non-assignment statements after the
    # destination's last import.
    statements_to_insert = extract_global_statements(src_module_code)
    insertion_point = find_last_import_line(dst_module_code)
    destination_tree = cst.parse_module(dst_module_code)
    inserter = ImportInserter(statements_to_insert, insertion_point)
    code_with_statements = destination_tree.visit(inserter).code

    # Pass 2: re-parse the intermediate text, then replace or append the source's
    # module-level assignments.
    reparsed_destination = cst.parse_module(code_with_statements)
    source_tree = cst.parse_module(src_module_code)

    assignment_collector = GlobalAssignmentCollector()
    source_tree.visit(assignment_collector)

    replacer = GlobalAssignmentTransformer(
        assignment_collector.assignments, assignment_collector.assignment_order
    )
    return reparsed_destination.visit(replacer).code
292+
293+
41294def add_needed_imports_from_module (
42295 src_module_code : str ,
43296 dst_module_code : str ,
0 commit comments