Skip to content

Commit c5b1d30

Browse files
add global assignments after the imports
Signed-off-by: ali <[email protected]>
1 parent b77f50e commit c5b1d30

File tree

3 files changed

+47
-22
lines changed

3 files changed

+47
-22
lines changed

codeflash/code_utils/code_extractor.py

Lines changed: 39 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from __future__ import annotations
33

44
import ast
5+
from itertools import chain
56
from typing import TYPE_CHECKING, Optional
67

78
import libcst as cst
@@ -119,6 +120,26 @@ def leave_Assign(self, original_node: cst.Assign, updated_node: cst.Assign) -> c
119120

120121
return updated_node
121122

123+
def _find_insertion_index(self, updated_node: cst.Module) -> int:
124+
insert_index = 0
125+
for i, stmt in enumerate(updated_node.body):
126+
is_top_level_import = isinstance(stmt, cst.SimpleStatementLine) and any(
127+
isinstance(child, (cst.Import, cst.ImportFrom)) for child in stmt.body
128+
)
129+
130+
is_conditional_import = isinstance(stmt, cst.If) and all(
131+
isinstance(inner, cst.SimpleStatementLine)
132+
and all(isinstance(child, (cst.Import, cst.ImportFrom)) for child in inner.body)
133+
for inner in stmt.body.body
134+
)
135+
136+
if is_top_level_import or is_conditional_import:
137+
insert_index = i + 1
138+
else:
139+
# stop when we find the first non-import statement
140+
break
141+
return insert_index
142+
122143
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
123144
# Add any new assignments that weren't in the original file
124145
new_statements = list(updated_node.body)
@@ -131,18 +152,24 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c
131152
]
132153

133154
if assignments_to_append:
134-
# Add a blank line before appending new assignments if needed
135-
if new_statements and not isinstance(new_statements[-1], cst.EmptyLine):
136-
new_statements.append(cst.SimpleStatementLine([cst.Pass()], leading_lines=[cst.EmptyLine()]))
137-
new_statements.pop() # Remove the Pass statement but keep the empty line
138-
139-
# Add the new assignments
140-
new_statements.extend(
141-
[
142-
cst.SimpleStatementLine([assignment], leading_lines=[cst.EmptyLine()])
143-
for assignment in assignments_to_append
144-
]
145-
)
155+
# after last top-level imports
156+
insert_index = self._find_insertion_index(updated_node)
157+
158+
assignment_lines = [
159+
cst.SimpleStatementLine([assignment], leading_lines=[cst.EmptyLine()])
160+
for assignment in assignments_to_append
161+
]
162+
163+
new_statements = list(chain(new_statements[:insert_index], assignment_lines, new_statements[insert_index:]))
164+
165+
# Add a blank line after the last assignment if needed
166+
after_index = insert_index + len(assignment_lines)
167+
if after_index < len(new_statements):
168+
next_statement = new_statements[after_index]
169+
if not next_statement.leading_lines or not isinstance(next_statement.leading_lines[-1], cst.EmptyLine):
170+
new_statements[after_index] = next_statement.with_changes(
171+
leading_lines=[cst.EmptyLine(), *next_statement.leading_lines]
172+
)
146173

147174
return updated_node.with_changes(body=new_statements)
148175

tests/test_code_replacement.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2104,6 +2104,8 @@ def new_function2(value):
21042104
"""
21052105
expected_code = """import numpy as np
21062106
2107+
a = 6
2108+
21072109
print("Hello world")
21082110
if 2<3:
21092111
a=4
@@ -2126,8 +2128,6 @@ def __call__(self, value):
21262128
return "I am still old"
21272129
def new_function2(value):
21282130
return cst.ensure_type(value, str)
2129-
2130-
a = 6
21312131
"""
21322132
code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/global_var_original.py").resolve()
21332133
code_path.write_text(original_code, encoding="utf-8")

tests/test_multi_file_code_replacement.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ def test_multi_file_replcement01() -> None:
1818
1919
from pydantic_ai_slim.pydantic_ai.messages import BinaryContent, UserContent
2020
21+
_TOKEN_SPLIT_RE = re.compile(r'[\\s",.:]+')
22+
2123
def _estimate_string_tokens(content: str | Sequence[UserContent]) -> int:
2224
if not content:
2325
return 0
@@ -34,9 +36,6 @@ def _estimate_string_tokens(content: str | Sequence[UserContent]) -> int:
3436
# TODO(Marcelo): We need to study how we can estimate the tokens for AudioUrl or ImageUrl.
3537
3638
return tokens
37-
38-
39-
_TOKEN_SPLIT_RE = re.compile(r'[\\s",.:]+')
4039
""", encoding="utf-8")
4140

4241
main_file = (root_dir / "code_to_optimize/temp_main.py").resolve()
@@ -131,6 +130,10 @@ def _get_string_usage(text: str) -> Usage:
131130
132131
from pydantic_ai_slim.pydantic_ai.messages import BinaryContent, UserContent
133132
133+
_translate_table = {ord(c): ord(' ') for c in ' \\t\\n\\r\\x0b\\x0c",.:'}
134+
135+
_TOKEN_SPLIT_RE = re.compile(r'[\\s",.:]+')
136+
134137
def _estimate_string_tokens(content: str | Sequence[UserContent]) -> int:
135138
if not content:
136139
return 0
@@ -155,11 +158,6 @@ def _estimate_string_tokens(content: str | Sequence[UserContent]) -> int:
155158
tokens += len(part.data)
156159
157160
return tokens
158-
159-
160-
_TOKEN_SPLIT_RE = re.compile(r'[\\s",.:]+')
161-
162-
_translate_table = {ord(c): ord(' ') for c in ' \\t\\n\\r\\x0b\\x0c",.:'}
163161
"""
164162

165163
assert new_code.rstrip() == original_main.rstrip() # No Change

0 commit comments

Comments
 (0)