59 changes: 47 additions & 12 deletions codeflash/code_utils/code_extractor.py
@@ -2,6 +2,7 @@
from __future__ import annotations

import ast
from itertools import chain
from typing import TYPE_CHECKING, Optional

import libcst as cst
@@ -119,6 +120,32 @@ def leave_Assign(self, original_node: cst.Assign, updated_node: cst.Assign) -> cst.Assign:

return updated_node

def _find_insertion_index(self, updated_node: cst.Module) -> int:
"""Find the position of the last import statement in the top-level of the module."""
insert_index = 0
for i, stmt in enumerate(updated_node.body):
is_top_level_import = isinstance(stmt, cst.SimpleStatementLine) and any(
isinstance(child, (cst.Import, cst.ImportFrom)) for child in stmt.body
)

is_conditional_import = isinstance(stmt, cst.If) and all(
isinstance(inner, cst.SimpleStatementLine)
and all(isinstance(child, (cst.Import, cst.ImportFrom)) for child in inner.body)
for inner in stmt.body.body
)

if is_top_level_import or is_conditional_import:
insert_index = i + 1

# Stop scanning once we reach a class or function definition.
# Imports are supposed to be at the top of the file, but they can technically appear anywhere, even at the bottom of the file.
# Without this check, a stray import later in the file
# would incorrectly shift our insertion index below actual code definitions.
if isinstance(stmt, (cst.ClassDef, cst.FunctionDef)):
break

return insert_index
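
For readers unfamiliar with libcst, here is a minimal standalone sketch of the same scan (the toy module and names below are hypothetical; the only assumption is that libcst is installed): it advances the insertion index past every top-level import, counts an `if` block only when every statement inside it is an import (the `if TYPE_CHECKING:` pattern), and stops at the first class or function definition so a stray late import cannot drag the index down.

```python
import libcst as cst

# Hypothetical toy module: plain imports, a TYPE_CHECKING-guarded import,
# a global, and a function containing a stray late import.
source = """\
import os
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from pathlib import Path

X = 1

def f():
    import sys  # stray late import; must not move the insertion point
"""

module = cst.parse_module(source)
insert_index = 0
for i, stmt in enumerate(module.body):
    is_import_line = isinstance(stmt, cst.SimpleStatementLine) and any(
        isinstance(child, (cst.Import, cst.ImportFrom)) for child in stmt.body
    )
    # An `if` block counts only when every statement in its body is an import.
    is_guarded_import = isinstance(stmt, cst.If) and all(
        isinstance(inner, cst.SimpleStatementLine)
        and all(isinstance(child, (cst.Import, cst.ImportFrom)) for child in inner.body)
        for inner in stmt.body.body
    )
    if is_import_line or is_guarded_import:
        insert_index = i + 1
    if isinstance(stmt, (cst.ClassDef, cst.FunctionDef)):
        break  # definitions start here; ignore anything below

print(insert_index)  # 3: just past the TYPE_CHECKING block, before `X = 1`
```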

def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
# Add any new assignments that weren't in the original file
new_statements = list(updated_node.body)
@@ -131,18 +158,26 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
]

if assignments_to_append:
# Add a blank line before appending new assignments if needed
if new_statements and not isinstance(new_statements[-1], cst.EmptyLine):
new_statements.append(cst.SimpleStatementLine([cst.Pass()], leading_lines=[cst.EmptyLine()]))
new_statements.pop() # Remove the Pass statement but keep the empty line

# Add the new assignments
new_statements.extend(
[
cst.SimpleStatementLine([assignment], leading_lines=[cst.EmptyLine()])
for assignment in assignments_to_append
]
)
            # Insert after the last top-level import
insert_index = self._find_insertion_index(updated_node)

assignment_lines = [
cst.SimpleStatementLine([assignment], leading_lines=[cst.EmptyLine()])
for assignment in assignments_to_append
]

new_statements = list(chain(new_statements[:insert_index], assignment_lines, new_statements[insert_index:]))

# Add a blank line after the last assignment if needed
after_index = insert_index + len(assignment_lines)
if after_index < len(new_statements):
next_stmt = new_statements[after_index]
# If there's no empty line, add one
has_empty = any(isinstance(line, cst.EmptyLine) for line in next_stmt.leading_lines)
if not has_empty:
new_statements[after_index] = next_stmt.with_changes(
leading_lines=[cst.EmptyLine(), *next_stmt.leading_lines]
)

return updated_node.with_changes(body=new_statements)
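
A companion sketch of the splice performed above, under the same assumptions (libcst installed; the one-import module is hypothetical): each new assignment is wrapped in a `SimpleStatementLine` with a blank leading line, spliced in at the computed index, and the statement that ends up after it gets a leading blank line if it lacks one.

```python
import libcst as cst

module = cst.parse_module("import os\nprint(os.name)\n")
insert_index = 1  # just past the single import

# New global, rendered with a blank line above it.
assignment = cst.parse_statement("a = 6").body[0]  # the bare Assign node
assignment_lines = [cst.SimpleStatementLine([assignment], leading_lines=[cst.EmptyLine()])]

body = list(module.body)
body[insert_index:insert_index] = assignment_lines

# Ensure a blank line also separates the new global from what follows it.
after_index = insert_index + len(assignment_lines)
if after_index < len(body):
    next_stmt = body[after_index]
    if not any(isinstance(line, cst.EmptyLine) for line in next_stmt.leading_lines):
        body[after_index] = next_stmt.with_changes(
            leading_lines=[cst.EmptyLine(), *next_stmt.leading_lines]
        )

print(module.with_changes(body=body).code)
# import os
#
# a = 6
#
# print(os.name)
```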

229 changes: 227 additions & 2 deletions tests/test_code_replacement.py
@@ -2104,6 +2104,8 @@ def new_function2(value):
"""
expected_code = """import numpy as np

a = 6

print("Hello world")
if 2<3:
a=4
@@ -2126,8 +2128,6 @@ def __call__(self, value):
return "I am still old"
def new_function2(value):
return cst.ensure_type(value, str)

a = 6
"""
code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/global_var_original.py").resolve()
code_path.write_text(original_code, encoding="utf-8")
@@ -3228,3 +3228,228 @@ def _map_tool_definition(f: ToolDefinition) -> ChatCompletionInputTool:
assert not re.search(r"^import aiohttp as aiohttp_\b", new_code, re.MULTILINE) # conditional alias import: import <name> as <alias>
assert not re.search(r"^from math import pi as PI, sin as sine\b", new_code, re.MULTILINE) # conditional multiple aliases imports
assert "from huggingface_hub import AsyncInferenceClient, ChatCompletionInputTool" not in new_code # conditional from import

def test_top_level_global_assignments() -> None:
root_dir = Path(__file__).parent.parent.resolve()
main_file = Path(root_dir / "code_to_optimize/temp_main.py").resolve()

original_code = '''"""
Module for generating GeneratedWorkflowParameters schema from workflow run input_text actions.
"""

from typing import Any, Dict, List, Tuple

import structlog
from pydantic import BaseModel

from skyvern.forge import app
from skyvern.forge.sdk.prompting import PromptEngine
from skyvern.webeye.actions.actions import ActionType

LOG = structlog.get_logger(__name__)

# Initialize prompt engine
prompt_engine = PromptEngine("skyvern")


def hydrate_input_text_actions_with_field_names(
actions_by_task: Dict[str, List[Dict[str, Any]]], field_mappings: Dict[str, str]
) -> Dict[str, List[Dict[str, Any]]]:
"""
Add field_name to input_text actions based on generated mappings.

Args:
actions_by_task: Dictionary mapping task IDs to lists of action dictionaries
field_mappings: Dictionary mapping "task_id:action_id" to field names

Returns:
Updated actions_by_task with field_name added to input_text actions
"""
updated_actions_by_task = {}

for task_id, actions in actions_by_task.items():
updated_actions = []

for action in actions:
action_copy = action.copy()

if action.get("action_type") == ActionType.INPUT_TEXT:
action_id = action.get("action_id", "")
mapping_key = f"{task_id}:{action_id}"

if mapping_key in field_mappings:
action_copy["field_name"] = field_mappings[mapping_key]
else:
# Fallback field name if mapping not found
intention = action.get("intention", "")
if intention:
# Simple field name generation from intention
field_name = intention.lower().replace(" ", "_").replace("?", "").replace("'", "")
field_name = "".join(c for c in field_name if c.isalnum() or c == "_")
action_copy["field_name"] = field_name or "unknown_field"
else:
action_copy["field_name"] = "unknown_field"

updated_actions.append(action_copy)

updated_actions_by_task[task_id] = updated_actions

return updated_actions_by_task
'''
main_file.write_text(original_code, encoding="utf-8")
optim_code = f'''```python:{main_file.relative_to(root_dir)}
from skyvern.webeye.actions.actions import ActionType
from typing import Any, Dict, List
import re

# Precompiled regex for efficiently generating simple field_name from intention
_INTENTION_CLEANUP_RE = re.compile(r"[^a-zA-Z0-9_]+")

def hydrate_input_text_actions_with_field_names(
actions_by_task: Dict[str, List[Dict[str, Any]]], field_mappings: Dict[str, str]
) -> Dict[str, List[Dict[str, Any]]]:
"""
Add field_name to input_text actions based on generated mappings.

Args:
actions_by_task: Dictionary mapping task IDs to lists of action dictionaries
field_mappings: Dictionary mapping "task_id:action_id" to field names

Returns:
Updated actions_by_task with field_name added to input_text actions
"""
updated_actions_by_task = {{}}

input_text_type = ActionType.INPUT_TEXT # local variable for faster access
intention_cleanup = _INTENTION_CLEANUP_RE

for task_id, actions in actions_by_task.items():
updated_actions = []

for action in actions:
action_copy = action.copy()

if action.get("action_type") == input_text_type:
action_id = action.get("action_id", "")
mapping_key = f"{{task_id}}:{{action_id}}"

if mapping_key in field_mappings:
action_copy["field_name"] = field_mappings[mapping_key]
else:
# Fallback field name if mapping not found
intention = action.get("intention", "")
if intention:
# Simple field name generation from intention
field_name = intention.lower().replace(" ", "_").replace("?", "").replace("'", "")
# Use compiled regex instead of "".join(c for ...)
field_name = intention_cleanup.sub("", field_name)
action_copy["field_name"] = field_name or "unknown_field"
else:
action_copy["field_name"] = "unknown_field"

updated_actions.append(action_copy)

updated_actions_by_task[task_id] = updated_actions

return updated_actions_by_task
```
'''
expected = '''"""
Module for generating GeneratedWorkflowParameters schema from workflow run input_text actions.
"""

from typing import Any, Dict, List, Tuple

import structlog
from pydantic import BaseModel

from skyvern.forge import app
from skyvern.forge.sdk.prompting import PromptEngine
from skyvern.webeye.actions.actions import ActionType
import re

_INTENTION_CLEANUP_RE = re.compile(r"[^a-zA-Z0-9_]+")

LOG = structlog.get_logger(__name__)

# Initialize prompt engine
prompt_engine = PromptEngine("skyvern")


def hydrate_input_text_actions_with_field_names(
actions_by_task: Dict[str, List[Dict[str, Any]]], field_mappings: Dict[str, str]
) -> Dict[str, List[Dict[str, Any]]]:
"""
Add field_name to input_text actions based on generated mappings.

Args:
actions_by_task: Dictionary mapping task IDs to lists of action dictionaries
field_mappings: Dictionary mapping "task_id:action_id" to field names

Returns:
Updated actions_by_task with field_name added to input_text actions
"""
updated_actions_by_task = {}

input_text_type = ActionType.INPUT_TEXT # local variable for faster access
intention_cleanup = _INTENTION_CLEANUP_RE

for task_id, actions in actions_by_task.items():
updated_actions = []

for action in actions:
action_copy = action.copy()

if action.get("action_type") == input_text_type:
action_id = action.get("action_id", "")
mapping_key = f"{task_id}:{action_id}"

if mapping_key in field_mappings:
action_copy["field_name"] = field_mappings[mapping_key]
else:
# Fallback field name if mapping not found
intention = action.get("intention", "")
if intention:
# Simple field name generation from intention
field_name = intention.lower().replace(" ", "_").replace("?", "").replace("'", "")
# Use compiled regex instead of "".join(c for ...)
field_name = intention_cleanup.sub("", field_name)
action_copy["field_name"] = field_name or "unknown_field"
else:
action_copy["field_name"] = "unknown_field"

updated_actions.append(action_copy)

updated_actions_by_task[task_id] = updated_actions

return updated_actions_by_task
'''

func = FunctionToOptimize(function_name="hydrate_input_text_actions_with_field_names", parents=[], file_path=main_file)
test_config = TestConfig(
tests_root=root_dir / "tests/pytest",
tests_project_rootdir=root_dir,
project_root_path=root_dir,
test_framework="pytest",
pytest_cmd="pytest",
)
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap()

original_helper_code: dict[Path, str] = {}
helper_function_paths = {hf.file_path for hf in code_context.helper_functions}
for helper_function_path in helper_function_paths:
with helper_function_path.open(encoding="utf8") as f:
helper_code = f.read()
original_helper_code[helper_function_path] = helper_code

func_optimizer.args = Args()
func_optimizer.replace_function_and_helpers_with_optimized_code(
code_context=code_context, optimized_code=CodeStringsMarkdown.parse_markdown_code(optim_code), original_helper_code=original_helper_code
)


new_code = main_file.read_text(encoding="utf-8")
main_file.unlink(missing_ok=True)

assert new_code == expected
14 changes: 6 additions & 8 deletions tests/test_multi_file_code_replacement.py
@@ -18,6 +18,8 @@ def test_multi_file_replcement01() -> None:

from pydantic_ai_slim.pydantic_ai.messages import BinaryContent, UserContent

_TOKEN_SPLIT_RE = re.compile(r'[\\s",.:]+')

def _estimate_string_tokens(content: str | Sequence[UserContent]) -> int:
if not content:
return 0
@@ -34,9 +36,6 @@ def _estimate_string_tokens(content: str | Sequence[UserContent]) -> int:
# TODO(Marcelo): We need to study how we can estimate the tokens for AudioUrl or ImageUrl.

return tokens


_TOKEN_SPLIT_RE = re.compile(r'[\\s",.:]+')
""", encoding="utf-8")

main_file = (root_dir / "code_to_optimize/temp_main.py").resolve()
@@ -131,6 +130,10 @@ def _get_string_usage(text: str) -> Usage:

from pydantic_ai_slim.pydantic_ai.messages import BinaryContent, UserContent

_translate_table = {ord(c): ord(' ') for c in ' \\t\\n\\r\\x0b\\x0c",.:'}

_TOKEN_SPLIT_RE = re.compile(r'[\\s",.:]+')

def _estimate_string_tokens(content: str | Sequence[UserContent]) -> int:
if not content:
return 0
@@ -155,11 +158,6 @@ def _estimate_string_tokens(content: str | Sequence[UserContent]) -> int:
tokens += len(part.data)

return tokens


_TOKEN_SPLIT_RE = re.compile(r'[\\s",.:]+')

_translate_table = {ord(c): ord(' ') for c in ' \\t\\n\\r\\x0b\\x0c",.:'}
"""

assert new_code.rstrip() == original_main.rstrip() # No Change