Commit 24fb636

Merge pull request #677 from codeflash-ai/fix/global-assignments-after-imports
[Enhancement] Add global assignments after imports
2 parents cccca40 + 628e004 · commit 24fb636
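
In effect, new module-level assignments produced during optimization are now inserted directly after the module's import block (including `if TYPE_CHECKING:`-style conditional imports) instead of being appended at the end of the file. A minimal before/after sketch of the intended behavior, using a hypothetical module and a hypothetical `_TOKEN_RE` global:

```python
# Before this change: a new global added during code replacement was appended at EOF.
import re

def tokenize(text):
    return _TOKEN_RE.split(text)

_TOKEN_RE = re.compile(r"\s+")  # landed after all definitions, in source order

# After this change: the same global lands right after the imports.
import re

_TOKEN_RE = re.compile(r"\s+")

def tokenize(text):
    return _TOKEN_RE.split(text)
```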

File tree

3 files changed: +280 −22 lines

codeflash/code_utils/code_extractor.py

Lines changed: 47 additions & 12 deletions
@@ -2,6 +2,7 @@
 from __future__ import annotations
 
 import ast
+from itertools import chain
 from typing import TYPE_CHECKING, Optional
 
 import libcst as cst
@@ -119,6 +120,32 @@ def leave_Assign(self, original_node: cst.Assign, updated_node: cst.Assign) -> c
 
         return updated_node
 
+    def _find_insertion_index(self, updated_node: cst.Module) -> int:
+        """Find the position of the last import statement in the top-level of the module."""
+        insert_index = 0
+        for i, stmt in enumerate(updated_node.body):
+            is_top_level_import = isinstance(stmt, cst.SimpleStatementLine) and any(
+                isinstance(child, (cst.Import, cst.ImportFrom)) for child in stmt.body
+            )
+
+            is_conditional_import = isinstance(stmt, cst.If) and all(
+                isinstance(inner, cst.SimpleStatementLine)
+                and all(isinstance(child, (cst.Import, cst.ImportFrom)) for child in inner.body)
+                for inner in stmt.body.body
+            )
+
+            if is_top_level_import or is_conditional_import:
+                insert_index = i + 1
+
+            # Stop scanning once we reach a class or function definition.
+            # Imports are supposed to be at the top of the file, but they can
+            # technically appear anywhere, even at the bottom of the file.
+            # Without this check, a stray import later in the file would
+            # incorrectly shift our insertion index below actual code definitions.
+            if isinstance(stmt, (cst.ClassDef, cst.FunctionDef)):
+                break
+
+        return insert_index
+
     def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
         # Add any new assignments that weren't in the original file
         new_statements = list(updated_node.body)
@@ -131,18 +158,26 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c
         ]
 
         if assignments_to_append:
-            # Add a blank line before appending new assignments if needed
-            if new_statements and not isinstance(new_statements[-1], cst.EmptyLine):
-                new_statements.append(cst.SimpleStatementLine([cst.Pass()], leading_lines=[cst.EmptyLine()]))
-                new_statements.pop()  # Remove the Pass statement but keep the empty line
-
-            # Add the new assignments
-            new_statements.extend(
-                [
-                    cst.SimpleStatementLine([assignment], leading_lines=[cst.EmptyLine()])
-                    for assignment in assignments_to_append
-                ]
-            )
+            # after last top-level imports
+            insert_index = self._find_insertion_index(updated_node)
+
+            assignment_lines = [
+                cst.SimpleStatementLine([assignment], leading_lines=[cst.EmptyLine()])
+                for assignment in assignments_to_append
+            ]
+
+            new_statements = list(chain(new_statements[:insert_index], assignment_lines, new_statements[insert_index:]))
+
+            # Add a blank line after the last assignment if needed
+            after_index = insert_index + len(assignment_lines)
+            if after_index < len(new_statements):
+                next_stmt = new_statements[after_index]
+                # If there's no empty line, add one
+                has_empty = any(isinstance(line, cst.EmptyLine) for line in next_stmt.leading_lines)
+                if not has_empty:
+                    new_statements[after_index] = next_stmt.with_changes(
+                        leading_lines=[cst.EmptyLine(), *next_stmt.leading_lines]
+                    )
 
         return updated_node.with_changes(body=new_statements)

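To make the new insertion point concrete, here is a minimal standalone sketch (not part of the diff; the `source` module is hypothetical) of the same scan `_find_insertion_index` performs: top-level imports and all-import `if` blocks advance the index, and scanning stops at the first class or function definition so a stray late import cannot drag the index down:

```python
import libcst as cst

source = """\
import os
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from collections import OrderedDict

def f():
    return os.name

import sys  # a stray late import; must not move the insertion point
"""

module = cst.parse_module(source)

insert_index = 0
for i, stmt in enumerate(module.body):
    is_top_level_import = isinstance(stmt, cst.SimpleStatementLine) and any(
        isinstance(child, (cst.Import, cst.ImportFrom)) for child in stmt.body
    )
    is_conditional_import = isinstance(stmt, cst.If) and all(
        isinstance(inner, cst.SimpleStatementLine)
        and all(isinstance(child, (cst.Import, cst.ImportFrom)) for child in inner.body)
        for inner in stmt.body.body
    )
    if is_top_level_import or is_conditional_import:
        insert_index = i + 1
    if isinstance(stmt, (cst.ClassDef, cst.FunctionDef)):
        break  # stop so the late `import sys` cannot shift the index below f()

print(insert_index)  # 3: new globals go right after the `if TYPE_CHECKING:` block
```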
tests/test_code_replacement.py

Lines changed: 227 additions & 2 deletions
@@ -2104,6 +2104,8 @@ def new_function2(value):
 """
     expected_code = """import numpy as np
 
+a = 6
+
 print("Hello world")
 if 2<3:
     a=4
@@ -2126,8 +2128,6 @@ def __call__(self, value):
         return "I am still old"
 def new_function2(value):
     return cst.ensure_type(value, str)
-
-a = 6
 """
     code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/global_var_original.py").resolve()
     code_path.write_text(original_code, encoding="utf-8")
@@ -3228,3 +3228,228 @@ def _map_tool_definition(f: ToolDefinition) -> ChatCompletionInputTool:
     assert not re.search(r"^import aiohttp as aiohttp_\b", new_code, re.MULTILINE) # conditional alias import: import <name> as <alias>
     assert not re.search(r"^from math import pi as PI, sin as sine\b", new_code, re.MULTILINE) # conditional multiple aliases imports
     assert "from huggingface_hub import AsyncInferenceClient, ChatCompletionInputTool" not in new_code # conditional from import
+
+
+def test_top_level_global_assignments() -> None:
+    root_dir = Path(__file__).parent.parent.resolve()
+    main_file = Path(root_dir / "code_to_optimize/temp_main.py").resolve()
+
+    original_code = '''"""
+Module for generating GeneratedWorkflowParameters schema from workflow run input_text actions.
+"""
+
+from typing import Any, Dict, List, Tuple
+
+import structlog
+from pydantic import BaseModel
+
+from skyvern.forge import app
+from skyvern.forge.sdk.prompting import PromptEngine
+from skyvern.webeye.actions.actions import ActionType
+
+LOG = structlog.get_logger(__name__)
+
+# Initialize prompt engine
+prompt_engine = PromptEngine("skyvern")
+
+
+def hydrate_input_text_actions_with_field_names(
+    actions_by_task: Dict[str, List[Dict[str, Any]]], field_mappings: Dict[str, str]
+) -> Dict[str, List[Dict[str, Any]]]:
+    """
+    Add field_name to input_text actions based on generated mappings.
+
+    Args:
+        actions_by_task: Dictionary mapping task IDs to lists of action dictionaries
+        field_mappings: Dictionary mapping "task_id:action_id" to field names
+
+    Returns:
+        Updated actions_by_task with field_name added to input_text actions
+    """
+    updated_actions_by_task = {}
+
+    for task_id, actions in actions_by_task.items():
+        updated_actions = []
+
+        for action in actions:
+            action_copy = action.copy()
+
+            if action.get("action_type") == ActionType.INPUT_TEXT:
+                action_id = action.get("action_id", "")
+                mapping_key = f"{task_id}:{action_id}"
+
+                if mapping_key in field_mappings:
+                    action_copy["field_name"] = field_mappings[mapping_key]
+                else:
+                    # Fallback field name if mapping not found
+                    intention = action.get("intention", "")
+                    if intention:
+                        # Simple field name generation from intention
+                        field_name = intention.lower().replace(" ", "_").replace("?", "").replace("'", "")
+                        field_name = "".join(c for c in field_name if c.isalnum() or c == "_")
+                        action_copy["field_name"] = field_name or "unknown_field"
+                    else:
+                        action_copy["field_name"] = "unknown_field"
+
+            updated_actions.append(action_copy)
+
+        updated_actions_by_task[task_id] = updated_actions
+
+    return updated_actions_by_task
+'''
+    main_file.write_text(original_code, encoding="utf-8")
+    optim_code = f'''```python:{main_file.relative_to(root_dir)}
+from skyvern.webeye.actions.actions import ActionType
+from typing import Any, Dict, List
+import re
+
+# Precompiled regex for efficiently generating simple field_name from intention
+_INTENTION_CLEANUP_RE = re.compile(r"[^a-zA-Z0-9_]+")
+
+def hydrate_input_text_actions_with_field_names(
+    actions_by_task: Dict[str, List[Dict[str, Any]]], field_mappings: Dict[str, str]
+) -> Dict[str, List[Dict[str, Any]]]:
+    """
+    Add field_name to input_text actions based on generated mappings.
+
+    Args:
+        actions_by_task: Dictionary mapping task IDs to lists of action dictionaries
+        field_mappings: Dictionary mapping "task_id:action_id" to field names
+
+    Returns:
+        Updated actions_by_task with field_name added to input_text actions
+    """
+    updated_actions_by_task = {{}}
+
+    input_text_type = ActionType.INPUT_TEXT # local variable for faster access
+    intention_cleanup = _INTENTION_CLEANUP_RE
+
+    for task_id, actions in actions_by_task.items():
+        updated_actions = []
+
+        for action in actions:
+            action_copy = action.copy()
+
+            if action.get("action_type") == input_text_type:
+                action_id = action.get("action_id", "")
+                mapping_key = f"{{task_id}}:{{action_id}}"
+
+                if mapping_key in field_mappings:
+                    action_copy["field_name"] = field_mappings[mapping_key]
+                else:
+                    # Fallback field name if mapping not found
+                    intention = action.get("intention", "")
+                    if intention:
+                        # Simple field name generation from intention
+                        field_name = intention.lower().replace(" ", "_").replace("?", "").replace("'", "")
+                        # Use compiled regex instead of "".join(c for ...)
+                        field_name = intention_cleanup.sub("", field_name)
+                        action_copy["field_name"] = field_name or "unknown_field"
+                    else:
+                        action_copy["field_name"] = "unknown_field"
+
+            updated_actions.append(action_copy)
+
+        updated_actions_by_task[task_id] = updated_actions
+
+    return updated_actions_by_task
+```
+'''
+    expected = '''"""
+Module for generating GeneratedWorkflowParameters schema from workflow run input_text actions.
+"""
+
+from typing import Any, Dict, List, Tuple
+
+import structlog
+from pydantic import BaseModel
+
+from skyvern.forge import app
+from skyvern.forge.sdk.prompting import PromptEngine
+from skyvern.webeye.actions.actions import ActionType
+import re
+
+_INTENTION_CLEANUP_RE = re.compile(r"[^a-zA-Z0-9_]+")
+
+LOG = structlog.get_logger(__name__)
+
+# Initialize prompt engine
+prompt_engine = PromptEngine("skyvern")
+
+
+def hydrate_input_text_actions_with_field_names(
+    actions_by_task: Dict[str, List[Dict[str, Any]]], field_mappings: Dict[str, str]
+) -> Dict[str, List[Dict[str, Any]]]:
+    """
+    Add field_name to input_text actions based on generated mappings.
+
+    Args:
+        actions_by_task: Dictionary mapping task IDs to lists of action dictionaries
+        field_mappings: Dictionary mapping "task_id:action_id" to field names
+
+    Returns:
+        Updated actions_by_task with field_name added to input_text actions
+    """
+    updated_actions_by_task = {}
+
+    input_text_type = ActionType.INPUT_TEXT # local variable for faster access
+    intention_cleanup = _INTENTION_CLEANUP_RE
+
+    for task_id, actions in actions_by_task.items():
+        updated_actions = []
+
+        for action in actions:
+            action_copy = action.copy()
+
+            if action.get("action_type") == input_text_type:
+                action_id = action.get("action_id", "")
+                mapping_key = f"{task_id}:{action_id}"
+
+                if mapping_key in field_mappings:
+                    action_copy["field_name"] = field_mappings[mapping_key]
+                else:
+                    # Fallback field name if mapping not found
+                    intention = action.get("intention", "")
+                    if intention:
+                        # Simple field name generation from intention
+                        field_name = intention.lower().replace(" ", "_").replace("?", "").replace("'", "")
+                        # Use compiled regex instead of "".join(c for ...)
+                        field_name = intention_cleanup.sub("", field_name)
+                        action_copy["field_name"] = field_name or "unknown_field"
+                    else:
+                        action_copy["field_name"] = "unknown_field"
+
+            updated_actions.append(action_copy)
+
+        updated_actions_by_task[task_id] = updated_actions
+
+    return updated_actions_by_task
+'''
+
+    func = FunctionToOptimize(function_name="hydrate_input_text_actions_with_field_names", parents=[], file_path=main_file)
+    test_config = TestConfig(
+        tests_root=root_dir / "tests/pytest",
+        tests_project_rootdir=root_dir,
+        project_root_path=root_dir,
+        test_framework="pytest",
+        pytest_cmd="pytest",
+    )
+    func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
+    code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap()
+
+    original_helper_code: dict[Path, str] = {}
+    helper_function_paths = {hf.file_path for hf in code_context.helper_functions}
+    for helper_function_path in helper_function_paths:
+        with helper_function_path.open(encoding="utf8") as f:
+            helper_code = f.read()
+        original_helper_code[helper_function_path] = helper_code
+
+    func_optimizer.args = Args()
+    func_optimizer.replace_function_and_helpers_with_optimized_code(
+        code_context=code_context, optimized_code=CodeStringsMarkdown.parse_markdown_code(optim_code), original_helper_code=original_helper_code
+    )
+
+
+    new_code = main_file.read_text(encoding="utf-8")
+    main_file.unlink(missing_ok=True)
+
+    assert new_code == expected

tests/test_multi_file_code_replacement.py

Lines changed: 6 additions & 8 deletions
@@ -18,6 +18,8 @@ def test_multi_file_replcement01() -> None:
 
 from pydantic_ai_slim.pydantic_ai.messages import BinaryContent, UserContent
 
+_TOKEN_SPLIT_RE = re.compile(r'[\\s",.:]+')
+
 def _estimate_string_tokens(content: str | Sequence[UserContent]) -> int:
     if not content:
         return 0
@@ -34,9 +36,6 @@ def _estimate_string_tokens(content: str | Sequence[UserContent]) -> int:
     # TODO(Marcelo): We need to study how we can estimate the tokens for AudioUrl or ImageUrl.
 
     return tokens
-
-
-_TOKEN_SPLIT_RE = re.compile(r'[\\s",.:]+')
 """, encoding="utf-8")
 
     main_file = (root_dir / "code_to_optimize/temp_main.py").resolve()
@@ -131,6 +130,10 @@ def _get_string_usage(text: str) -> Usage:
 
 from pydantic_ai_slim.pydantic_ai.messages import BinaryContent, UserContent
 
+_translate_table = {ord(c): ord(' ') for c in ' \\t\\n\\r\\x0b\\x0c",.:'}
+
+_TOKEN_SPLIT_RE = re.compile(r'[\\s",.:]+')
+
 def _estimate_string_tokens(content: str | Sequence[UserContent]) -> int:
     if not content:
         return 0
@@ -155,11 +158,6 @@ def _estimate_string_tokens(content: str | Sequence[UserContent]) -> int:
                 tokens += len(part.data)
 
     return tokens
-
-
-_TOKEN_SPLIT_RE = re.compile(r'[\\s",.:]+')
-
-_translate_table = {ord(c): ord(' ') for c in ' \\t\\n\\r\\x0b\\x0c",.:'}
 """
 
     assert new_code.rstrip() == original_main.rstrip() # No Change

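The single blank lines around the relocated globals in the expected outputs above come from libcst's whitespace model: vertical spacing is attached to the following statement's `leading_lines`, which is why the patch inserts assignments with a `cst.EmptyLine()` and then normalizes the next statement. A minimal sketch (hypothetical module and assignment, not code from this PR):

```python
import libcst as cst

module = cst.parse_module("import re\ndef f():\n    return 1\n")

# New global carries one blank line above it via leading_lines.
assignment = cst.parse_statement("PATTERN = re.compile(r'x')")
assignment = assignment.with_changes(leading_lines=[cst.EmptyLine()])

# Insert after the import (index 0) and give the statement that now follows
# a blank line too, mirroring the "add a blank line after the last
# assignment if needed" step in leave_Module.
body = list(module.body)
following = body[1].with_changes(leading_lines=[cst.EmptyLine(), *body[1].leading_lines])
body = [body[0], assignment, following, *body[2:]]

print(module.with_changes(body=body).code)
# import re
#
# PATTERN = re.compile(r'x')
#
# def f():
#     return 1
```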