Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 76 additions & 6 deletions codeflash/code_utils/code_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,64 @@ def visit_SimpleStatementLine(self, node: cst.SimpleStatementLine) -> None:
self.last_import_line = self.current_line


class ConditionalImportCollector(cst.CSTVisitor):
    """Collect imports found in the bodies of module-level ``if`` and ``try`` statements.

    Only the main body of each ``if`` / ``try`` is inspected; ``else`` branches and
    ``except`` handlers are not. Statements nested inside a function or class
    definition are ignored.

    After visiting a module, ``imports`` holds one string key per collected import:

    * ``import a.b``            -> ``"a.b"``
    * ``import a.b as c``       -> ``"a.b.c"`` (or just ``"a"`` for ``import a as a``)
    * ``from a.b import c``     -> ``"a.b.c"``
    * ``from a.b import c as d`` -> ``"a.b.d"``

    Relative imports without a module (``from . import x``) and star imports
    (``from a import *``) contribute no key.
    """

    def __init__(self) -> None:
        self.imports: set[str] = set()
        self.depth = 0  # number of enclosing function/class definitions; 0 means module level

    def get_full_dotted_name(self, expr: cst.BaseExpression) -> str:
        """Return the dotted path for a ``Name``/``Attribute`` chain, or ``""`` for any other node."""
        if isinstance(expr, cst.Name):
            return expr.value
        if isinstance(expr, cst.Attribute):
            return f"{self.get_full_dotted_name(expr.value)}.{expr.attr.value}"
        return ""

    def _record_import(self, statement: cst.BaseSmallStatement) -> None:
        """Add the key(s) for a single ``import`` / ``from ... import`` statement; ignore anything else."""
        if isinstance(statement, cst.Import):
            for alias in statement.names:
                module = self.get_full_dotted_name(alias.name)
                if alias.asname is None:
                    # No alias: the key is the dotted module path itself. Reading
                    # ``alias.name.value`` here would yield a CST node (not a string)
                    # for dotted names such as ``os.path``.
                    self.imports.add(module)
                    continue
                asname = alias.asname.name.value
                self.imports.add(module if module == asname else f"{module}.{asname}")
            return

        if isinstance(statement, cst.ImportFrom):
            if statement.module is None:
                return
            if isinstance(statement.names, cst.ImportStar):
                # ``from x import *`` binds no explicit names and ImportStar is not iterable.
                return
            module = self.get_full_dotted_name(statement.module)
            for alias in statement.names:
                if isinstance(alias, cst.ImportAlias):
                    bound_name = alias.asname.name.value if alias.asname else alias.name.value
                    self.imports.add(f"{module}.{bound_name}")

    def _collect_imports_from_block(self, block: cst.BaseSuite) -> None:
        """Record every import that is a direct statement of ``block``."""
        if isinstance(block, cst.SimpleStatementSuite):
            # One-line form, e.g. ``if TYPE_CHECKING: import x`` - the suite holds
            # the small statements directly.
            for small_statement in block.body:
                self._record_import(small_statement)
            return

        for statement in block.body:
            if isinstance(statement, cst.SimpleStatementLine):
                for small_statement in statement.body:
                    self._record_import(small_statement)

    def visit_Module(self, node: cst.Module) -> None:
        # Start every traversal at module level.
        self.depth = 0

    def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
        self.depth += 1

    def leave_FunctionDef(self, node: cst.FunctionDef) -> None:
        self.depth -= 1

    def visit_ClassDef(self, node: cst.ClassDef) -> None:
        self.depth += 1

    def leave_ClassDef(self, node: cst.ClassDef) -> None:
        self.depth -= 1

    def visit_If(self, node: cst.If) -> None:
        # Only conditionals outside any function/class definition are of interest.
        if self.depth == 0:
            self._collect_imports_from_block(node.body)

    def visit_Try(self, node: cst.Try) -> None:
        # Only the ``try`` body is scanned, not its handlers.
        if self.depth == 0:
            self._collect_imports_from_block(node.body)


class ImportInserter(cst.CSTTransformer):
"""Transformer that inserts global statements after the last import."""

Expand Down Expand Up @@ -329,8 +387,19 @@ def add_needed_imports_from_module(
except Exception as e:
logger.error(f"Error parsing source module code: {e}")
return dst_module_code

cond_import_collector = ConditionalImportCollector()
try:
parsed_dst_module = cst.parse_module(dst_module_code)
parsed_dst_module.visit(cond_import_collector)
except cst.ParserSyntaxError as e:
logger.exception(f"Syntax error in destination module code: {e}")
return dst_module_code # Return the original code if there's a syntax error

try:
for mod in gatherer.module_imports:
if mod in cond_import_collector.imports:
continue
AddImportsVisitor.add_needed_import(dst_context, mod)
RemoveImportsVisitor.remove_unused_import(dst_context, mod)
for mod, obj_seq in gatherer.object_mapping.items():
Expand All @@ -339,28 +408,29 @@ def add_needed_imports_from_module(
f"{mod}.{obj}" in helper_functions_fqn or dst_context.full_module_name == mod # avoid circular deps
):
continue # Skip adding imports for helper functions already in the context
if f"{mod}.{obj}" in cond_import_collector.imports:
continue
AddImportsVisitor.add_needed_import(dst_context, mod, obj)
RemoveImportsVisitor.remove_unused_import(dst_context, mod, obj)
except Exception as e:
logger.exception(f"Error adding imports to destination module code: {e}")
return dst_module_code
for mod, asname in gatherer.module_aliases.items():
if f"{mod}.{asname}" in cond_import_collector.imports:
continue
AddImportsVisitor.add_needed_import(dst_context, mod, asname=asname)
RemoveImportsVisitor.remove_unused_import(dst_context, mod, asname=asname)
for mod, alias_pairs in gatherer.alias_mapping.items():
for alias_pair in alias_pairs:
if f"{mod}.{alias_pair[0]}" in helper_functions_fqn:
continue
if f"{mod}.{alias_pair[1]}" in cond_import_collector.imports:
continue
AddImportsVisitor.add_needed_import(dst_context, mod, alias_pair[0], asname=alias_pair[1])
RemoveImportsVisitor.remove_unused_import(dst_context, mod, alias_pair[0], asname=alias_pair[1])

try:
parsed_module = cst.parse_module(dst_module_code)
except cst.ParserSyntaxError as e:
logger.exception(f"Syntax error in destination module code: {e}")
return dst_module_code # Return the original code if there's a syntax error
try:
transformed_module = AddImportsVisitor(dst_context).transform_module(parsed_module)
transformed_module = AddImportsVisitor(dst_context).transform_module(parsed_dst_module)
transformed_module = RemoveImportsVisitor(dst_context).transform_module(transformed_module)
return transformed_module.code.lstrip("\n")
except Exception as e:
Expand Down
137 changes: 137 additions & 0 deletions tests/test_code_replacement.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from __future__ import annotations
import re
import libcst as cst
from codeflash.code_utils.code_replacer import AutouseFixtureModifier, PytestMarkAdder, AddRequestArgument
import dataclasses
Expand Down Expand Up @@ -3091,3 +3092,139 @@ def my_fixture(request):
modified_module = module.visit(transformer)

assert modified_module.code.strip() == expected.strip()


def test_type_checking_imports():
    """Check that imports already made inside top-level ``try`` / ``if`` blocks are not re-added.

    ``optim_code`` imports several names at the top of the module, while
    ``original_code`` imports the same names inside a ``try`` block and an
    ``if True:`` block. After replacing ``HuggingFaceModel._map_tool_definition``
    with the optimized version, the resulting module must not contain those
    imports as line-starting statements.
    """
    # NOTE(review): the two embedded module sources below appear without any
    # indentation in this view; confirm their exact whitespace against the
    # repository before relying on them.

    # Optimized module: the "problematic" imports are plain top-level imports here.
    optim_code = """from dataclasses import dataclass
from pydantic_ai.providers import Provider, infer_provider
from pydantic_ai_slim.pydantic_ai.models import Model
from pydantic_ai_slim.pydantic_ai.tools import ToolDefinition
from typing import Literal

#### problamatic imports ####
from huggingface_hub import AsyncInferenceClient, ChatCompletionInputTool
import requests
import aiohttp as aiohttp_
from math import pi as PI, sin as sine

@dataclass(init=False)
class HuggingFaceModel(Model):
def __init__(
self,
model_name: str,
*,
provider: Literal['huggingface'] | Provider[AsyncInferenceClient] = 'huggingface',
):
print(requests.__name__)
print(aiohttp_.__name__)
print(PI)
print(sine)
# Fast branch: avoid repeating provider assignment
if isinstance(provider, str):
provider_obj = infer_provider(provider)
else:
provider_obj = provider
self._provider = provider
self._model_name = model_name
self.client = provider_obj.client

@staticmethod
def _map_tool_definition(f: ToolDefinition) -> ChatCompletionInputTool:
# Inline dict creation and single pass for possible strict attribute
tool_dict = {
'type': 'function',
'function': {
'name': f.name,
'description': f.description,
'parameters': f.parameters_json_schema,
},
}
if f.strict is not None:
tool_dict['function']['strict'] = f.strict
return ChatCompletionInputTool.parse_obj_as_instance(tool_dict) # type: ignore
"""

    # Original module: the same names are imported inside a `try` block and an `if True:` block.
    original_code = """from dataclasses import dataclass
from pydantic_ai.providers import Provider, infer_provider
from pydantic_ai_slim.pydantic_ai.models import Model
from pydantic_ai_slim.pydantic_ai.tools import ToolDefinition
from typing import Literal

try:
import aiohttp as aiohttp_
from math import pi as PI, sin as sine
from huggingface_hub import (
AsyncInferenceClient,
ChatCompletionInputMessage,
ChatCompletionInputMessageChunk,
ChatCompletionInputTool,
ChatCompletionInputToolCall,
ChatCompletionInputURL,
ChatCompletionOutput,
ChatCompletionOutputMessage,
ChatCompletionStreamOutput,
)
from huggingface_hub.errors import HfHubHTTPError

except ImportError as _import_error:
raise ImportError(
'Please install `huggingface_hub` to use Hugging Face Inference Providers, '
'you can use the `huggingface` optional group — `pip install "pydantic-ai-slim[huggingface]"`'
) from _import_error

if True:
import requests

__all__ = (
'HuggingFaceModel',
'HuggingFaceModelSettings',
)

@dataclass(init=False)
class HuggingFaceModel(Model):

def __init__(
self,
model_name: str,
*,
provider: Literal['huggingface'] | Provider[AsyncInferenceClient] = 'huggingface',
):
self._model_name = model_name
self._provider = provider
if isinstance(provider, str):
provider = infer_provider(provider)
self.client = provider.client

@staticmethod
def _map_tool_definition(f: ToolDefinition) -> ChatCompletionInputTool:
tool_param: ChatCompletionInputTool = ChatCompletionInputTool.parse_obj_as_instance( # type: ignore
{
'type': 'function',
'function': {
'name': f.name,
'description': f.description,
'parameters': f.parameters_json_schema,
},
}
)
if f.strict is not None:
tool_param['function']['strict'] = f.strict
return tool_param
"""


    # Replace only the static method; the import handling is what is under test.
    function_name: str = "HuggingFaceModel._map_tool_definition"
    preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = find_preexisting_objects(original_code)
    new_code: str = replace_functions_and_add_imports(
        source_code=original_code,
        function_names=[function_name],
        optimized_code=optim_code,
        module_abspath=Path(__file__).resolve(),
        preexisting_objects=preexisting_objects,
        project_root_path=Path(__file__).resolve().parent.resolve(),
    )

    # The regexes are anchored with `^` under re.MULTILINE, so they only match an
    # import statement that begins at the very start of a line.
    assert not re.search(r"^import requests\b", new_code, re.MULTILINE) # conditional simple import: import <name>
    assert not re.search(r"^import aiohttp as aiohttp_\b", new_code, re.MULTILINE) # conditional alias import: import <name> as <alias>
    assert not re.search(r"^from math import pi as PI, sin as sine\b", new_code, re.MULTILINE) # conditional multiple aliases imports
    assert "from huggingface_hub import AsyncInferenceClient, ChatCompletionInputTool" not in new_code # conditional from import
Loading