Skip to content

Commit 3a2dfe7

Browse files
Merge pull request #179 from codeflash-ai/cf-616
Global assignments in optimization incorporated in replaced code (CF-616)
2 parents 1eb674d + 6db4199 commit 3a2dfe7

File tree

4 files changed

+812
-44
lines changed

4 files changed

+812
-44
lines changed

codeflash/code_utils/code_extractor.py

Lines changed: 254 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import ast
44
from pathlib import Path
5-
from typing import TYPE_CHECKING
5+
from typing import TYPE_CHECKING, Dict, Optional, Set
66

77
import libcst as cst
88
import libcst.matchers as m
@@ -18,6 +18,227 @@
1818

1919
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
2020

21+
from typing import List, Union
22+
23+
class GlobalAssignmentCollector(cst.CSTVisitor):
24+
"""Collects all global assignment statements."""
25+
26+
def __init__(self):
27+
super().__init__()
28+
self.assignments: Dict[str, cst.Assign] = {}
29+
self.assignment_order: List[str] = []
30+
# Track scope depth to identify global assignments
31+
self.scope_depth = 0
32+
self.if_else_depth = 0
33+
34+
def visit_FunctionDef(self, node: cst.FunctionDef) -> Optional[bool]:
35+
self.scope_depth += 1
36+
return True
37+
38+
def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None:
39+
self.scope_depth -= 1
40+
41+
def visit_ClassDef(self, node: cst.ClassDef) -> Optional[bool]:
42+
self.scope_depth += 1
43+
return True
44+
45+
def leave_ClassDef(self, original_node: cst.ClassDef) -> None:
46+
self.scope_depth -= 1
47+
48+
def visit_If(self, node: cst.If) -> Optional[bool]:
49+
self.if_else_depth += 1
50+
return True
51+
52+
def leave_If(self, original_node: cst.If) -> None:
53+
self.if_else_depth -= 1
54+
55+
def visit_Else(self, node: cst.Else) -> Optional[bool]:
56+
# Else blocks are already counted as part of the if statement
57+
return True
58+
59+
def visit_Assign(self, node: cst.Assign) -> Optional[bool]:
60+
# Only process global assignments (not inside functions, classes, etc.)
61+
if self.scope_depth == 0 and self.if_else_depth == 0: # We're at module level
62+
for target in node.targets:
63+
if isinstance(target.target, cst.Name):
64+
name = target.target.value
65+
self.assignments[name] = node
66+
if name not in self.assignment_order:
67+
self.assignment_order.append(name)
68+
return True
69+
70+
71+
class GlobalAssignmentTransformer(cst.CSTTransformer):
72+
"""Transforms global assignments in the original file with those from the new file."""
73+
74+
def __init__(self, new_assignments: Dict[str, cst.Assign], new_assignment_order: List[str]):
75+
super().__init__()
76+
self.new_assignments = new_assignments
77+
self.new_assignment_order = new_assignment_order
78+
self.processed_assignments: Set[str] = set()
79+
self.scope_depth = 0
80+
self.if_else_depth = 0
81+
82+
def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
83+
self.scope_depth += 1
84+
85+
def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:
86+
self.scope_depth -= 1
87+
return updated_node
88+
89+
def visit_ClassDef(self, node: cst.ClassDef) -> None:
90+
self.scope_depth += 1
91+
92+
def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef:
93+
self.scope_depth -= 1
94+
return updated_node
95+
96+
def visit_If(self, node: cst.If) -> None:
97+
self.if_else_depth += 1
98+
99+
def leave_If(self, original_node: cst.If, updated_node: cst.If) -> cst.If:
100+
self.if_else_depth -= 1
101+
return updated_node
102+
103+
def visit_Else(self, node: cst.Else) -> None:
104+
# Else blocks are already counted as part of the if statement
105+
pass
106+
107+
def leave_Assign(self, original_node: cst.Assign, updated_node: cst.Assign) -> cst.CSTNode:
108+
if self.scope_depth > 0 or self.if_else_depth > 0:
109+
return updated_node
110+
111+
# Check if this is a global assignment we need to replace
112+
for target in original_node.targets:
113+
if isinstance(target.target, cst.Name):
114+
name = target.target.value
115+
if name in self.new_assignments:
116+
self.processed_assignments.add(name)
117+
return self.new_assignments[name]
118+
119+
return updated_node
120+
121+
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
122+
# Add any new assignments that weren't in the original file
123+
new_statements = list(updated_node.body)
124+
125+
# Find assignments to append
126+
assignments_to_append = []
127+
for name in self.new_assignment_order:
128+
if name not in self.processed_assignments and name in self.new_assignments:
129+
assignments_to_append.append(self.new_assignments[name])
130+
131+
if assignments_to_append:
132+
# Add a blank line before appending new assignments if needed
133+
if new_statements and not isinstance(new_statements[-1], cst.EmptyLine):
134+
new_statements.append(cst.SimpleStatementLine([cst.Pass()], leading_lines=[cst.EmptyLine()]))
135+
new_statements.pop() # Remove the Pass statement but keep the empty line
136+
137+
# Add the new assignments
138+
for assignment in assignments_to_append:
139+
new_statements.append(
140+
cst.SimpleStatementLine(
141+
[assignment],
142+
leading_lines=[cst.EmptyLine()]
143+
)
144+
)
145+
146+
return updated_node.with_changes(body=new_statements)
147+
148+
class GlobalStatementCollector(cst.CSTVisitor):
149+
"""Visitor that collects all global statements (excluding imports and functions/classes)."""
150+
151+
def __init__(self):
152+
super().__init__()
153+
self.global_statements = []
154+
self.in_function_or_class = False
155+
156+
def visit_ClassDef(self, node: cst.ClassDef) -> bool:
157+
# Don't visit inside classes
158+
self.in_function_or_class = True
159+
return False
160+
161+
def leave_ClassDef(self, original_node: cst.ClassDef) -> None:
162+
self.in_function_or_class = False
163+
164+
def visit_FunctionDef(self, node: cst.FunctionDef) -> bool:
165+
# Don't visit inside functions
166+
self.in_function_or_class = True
167+
return False
168+
169+
def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None:
170+
self.in_function_or_class = False
171+
172+
def visit_SimpleStatementLine(self, node: cst.SimpleStatementLine) -> None:
173+
if not self.in_function_or_class:
174+
for statement in node.body:
175+
# Skip imports
176+
if not isinstance(statement, (cst.Import, cst.ImportFrom, cst.Assign)):
177+
self.global_statements.append(node)
178+
break
179+
180+
181+
class LastImportFinder(cst.CSTVisitor):
182+
"""Finds the position of the last import statement in the module."""
183+
184+
def __init__(self):
185+
super().__init__()
186+
self.last_import_line = 0
187+
self.current_line = 0
188+
189+
def visit_SimpleStatementLine(self, node: cst.SimpleStatementLine) -> None:
190+
self.current_line += 1
191+
for statement in node.body:
192+
if isinstance(statement, (cst.Import, cst.ImportFrom)):
193+
self.last_import_line = self.current_line
194+
195+
196+
class ImportInserter(cst.CSTTransformer):
197+
"""Transformer that inserts global statements after the last import."""
198+
199+
def __init__(self, global_statements: List[cst.SimpleStatementLine], last_import_line: int):
200+
super().__init__()
201+
self.global_statements = global_statements
202+
self.last_import_line = last_import_line
203+
self.current_line = 0
204+
self.inserted = False
205+
206+
def leave_SimpleStatementLine(
207+
self, original_node: cst.SimpleStatementLine, updated_node: cst.SimpleStatementLine
208+
) -> cst.Module:
209+
self.current_line += 1
210+
211+
# If we're right after the last import and haven't inserted yet
212+
if self.current_line == self.last_import_line and not self.inserted:
213+
self.inserted = True
214+
return cst.Module(body=[updated_node] + self.global_statements)
215+
216+
return cst.Module(body=[updated_node])
217+
218+
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
219+
# If there were no imports, add at the beginning of the module
220+
if self.last_import_line == 0 and not self.inserted:
221+
updated_body = list(updated_node.body)
222+
for stmt in reversed(self.global_statements):
223+
updated_body.insert(0, stmt)
224+
return updated_node.with_changes(body=updated_body)
225+
return updated_node
226+
227+
228+
def extract_global_statements(source_code: str) -> List[cst.SimpleStatementLine]:
229+
"""Extract global statements from source code."""
230+
module = cst.parse_module(source_code)
231+
collector = GlobalStatementCollector()
232+
module.visit(collector)
233+
return collector.global_statements
234+
235+
236+
def find_last_import_line(target_code: str) -> int:
237+
"""Find the line number of the last import statement."""
238+
module = cst.parse_module(target_code)
239+
finder = LastImportFinder()
240+
module.visit(finder)
241+
return finder.last_import_line
21242

22243
class FutureAliasedImportTransformer(cst.CSTTransformer):
23244
def leave_ImportFrom(
@@ -38,6 +259,38 @@ def delete___future___aliased_imports(module_code: str) -> str:
38259
return cst.parse_module(module_code).visit(FutureAliasedImportTransformer()).code
39260

40261

262+
def add_global_assignments(src_module_code: str, dst_module_code: str) -> str:
263+
non_assignment_global_statements = extract_global_statements(src_module_code)
264+
265+
# Find the last import line in target
266+
last_import_line = find_last_import_line(dst_module_code)
267+
268+
# Parse the target code
269+
target_module = cst.parse_module(dst_module_code)
270+
271+
# Create transformer to insert non_assignment_global_statements
272+
transformer = ImportInserter(non_assignment_global_statements, last_import_line)
273+
#
274+
# # Apply transformation
275+
modified_module = target_module.visit(transformer)
276+
dst_module_code = modified_module.code
277+
278+
# Parse the code
279+
original_module = cst.parse_module(dst_module_code)
280+
new_module = cst.parse_module(src_module_code)
281+
282+
# Collect assignments from the new file
283+
new_collector = GlobalAssignmentCollector()
284+
new_module.visit(new_collector)
285+
286+
# Transform the original file
287+
transformer = GlobalAssignmentTransformer(new_collector.assignments, new_collector.assignment_order)
288+
transformed_module = original_module.visit(transformer)
289+
290+
dst_module_code = transformed_module.code
291+
return dst_module_code
292+
293+
41294
def add_needed_imports_from_module(
42295
src_module_code: str,
43296
dst_module_code: str,

codeflash/code_utils/code_replacer.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import libcst as cst
99

1010
from codeflash.cli_cmds.console import logger
11-
from codeflash.code_utils.code_extractor import add_needed_imports_from_module
11+
from codeflash.code_utils.code_extractor import add_needed_imports_from_module, add_global_assignments
1212
from codeflash.models.models import FunctionParent
1313

1414
if TYPE_CHECKING:
@@ -220,7 +220,8 @@ def replace_function_definitions_in_module(
220220
)
221221
if is_zero_diff(source_code, new_code):
222222
return False
223-
module_abspath.write_text(new_code, encoding="utf8")
223+
code_with_global_assignments = add_global_assignments(optimized_code, new_code)
224+
module_abspath.write_text(code_with_global_assignments, encoding="utf8")
224225
return True
225226

226227

codeflash/context/code_context_extractor.py

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -360,23 +360,26 @@ def get_function_to_optimize_as_function_source(
360360

361361
# Find the name that matches our function
362362
for name in names:
363-
if (
364-
name.type == "function"
365-
and name.full_name
366-
and name.name == function_to_optimize.function_name
367-
and name.full_name.startswith(name.module_name)
368-
and get_qualified_name(name.module_name, name.full_name) == function_to_optimize.qualified_name
369-
):
370-
function_source = FunctionSource(
371-
file_path=function_to_optimize.file_path,
372-
qualified_name=function_to_optimize.qualified_name,
373-
fully_qualified_name=name.full_name,
374-
only_function_name=name.name,
375-
source_code=name.get_line_code(),
376-
jedi_definition=name,
377-
)
378-
return function_source
379-
363+
try:
364+
if (
365+
name.type == "function"
366+
and name.full_name
367+
and name.name == function_to_optimize.function_name
368+
and name.full_name.startswith(name.module_name)
369+
and get_qualified_name(name.module_name, name.full_name) == function_to_optimize.qualified_name
370+
):
371+
function_source = FunctionSource(
372+
file_path=function_to_optimize.file_path,
373+
qualified_name=function_to_optimize.qualified_name,
374+
fully_qualified_name=name.full_name,
375+
only_function_name=name.name,
376+
source_code=name.get_line_code(),
377+
jedi_definition=name,
378+
)
379+
return function_source
380+
except Exception as e:
381+
logger.exception(f"Error while getting function source: {e}")
382+
continue
380383
raise ValueError(
381384
f"Could not find function {function_to_optimize.function_name} in {function_to_optimize.file_path}"
382385
)

0 commit comments

Comments
 (0)