Skip to content

Commit c3b775f

Browse files
Avoid adding new imports that already exist in a top-level try/except or `if TYPE_CHECKING` block
1 parent e806c5d commit c3b775f

File tree

2 files changed

+213
-6
lines changed

2 files changed

+213
-6
lines changed

codeflash/code_utils/code_extractor.py

Lines changed: 76 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,64 @@ def visit_SimpleStatementLine(self, node: cst.SimpleStatementLine) -> None:
195195
self.last_import_line = self.current_line
196196

197197

198+
class ConditionalImportCollector(cst.CSTVisitor):
    """Collect imports found inside top-level ``if`` and ``try`` statements.

    Only statements that are not nested inside a function or class are
    considered.  Every branch of the statement is scanned: the ``if`` body and
    its ``else``, and the ``try`` body together with its ``except``, ``else``
    and ``finally`` suites.

    Each import is recorded in ``self.imports`` as a string key:

    * ``import a.b``            -> ``"a.b"``
    * ``import a.b as c``       -> ``"a.b.c"``
    * ``from a.b import c``     -> ``"a.b.c"``
    * ``from a.b import c as d`` -> ``"a.b.d"``

    ``from . import x`` (no module name) and ``from a import *`` are ignored.
    """

    def __init__(self) -> None:
        self.imports: set[str] = set()
        self.depth = 0  # 0 while outside every function/class body

    def get_full_dotted_name(self, expr: cst.BaseExpression) -> str:
        """Return the dotted-name string for a ``Name``/``Attribute`` chain.

        Any other expression type yields an empty string.
        """
        if isinstance(expr, cst.Name):
            return expr.value
        if isinstance(expr, cst.Attribute):
            return f"{self.get_full_dotted_name(expr.value)}.{expr.attr.value}"
        return ""

    def _record_import(self, node: cst.Import) -> None:
        """Record the keys for one ``import ...`` statement."""
        for alias in node.names:
            module = self.get_full_dotted_name(alias.name)
            if alias.asname is None:
                # Use the full dotted path: ``alias.name`` is an Attribute for
                # ``import a.b`` and its ``.value`` is a node, not a string.
                self.imports.add(module)
                continue
            asname = self.get_full_dotted_name(alias.asname.name)
            self.imports.add(f"{module}.{asname}")
            if asname == module:
                # ``import x as x`` binds the same name as ``import x``;
                # keep the plain-module key as well.
                self.imports.add(module)

    def _record_import_from(self, node: cst.ImportFrom) -> None:
        """Record the keys for one ``from ... import ...`` statement."""
        if node.module is None:
            return
        if isinstance(node.names, cst.ImportStar):
            # A star import is a single node, not a sequence of aliases.
            return
        module = self.get_full_dotted_name(node.module)
        for alias in node.names:
            if isinstance(alias, cst.ImportAlias):
                bound = alias.asname.name if alias.asname else alias.name
                self.imports.add(f"{module}.{self.get_full_dotted_name(bound)}")

    def _collect_imports_from_block(self, block: cst.BaseSuite) -> None:
        """Record every import that is a direct statement of ``block``.

        Works for both an indented block and a one-line suite such as
        ``if TYPE_CHECKING: import x``.
        """
        for statement in block.body:
            if isinstance(statement, cst.SimpleStatementLine):
                small_statements = statement.body
            else:
                # Either a small statement of a one-line suite, or a compound
                # statement (which is never an import and is skipped below).
                small_statements = (statement,)
            for child in small_statements:
                if isinstance(child, cst.Import):
                    self._record_import(child)
                elif isinstance(child, cst.ImportFrom):
                    self._record_import_from(child)

    def visit_Module(self, node: cst.Module) -> None:
        self.depth = 0

    def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
        self.depth += 1

    def leave_FunctionDef(self, node: cst.FunctionDef) -> None:
        self.depth -= 1

    def visit_ClassDef(self, node: cst.ClassDef) -> None:
        self.depth += 1

    def leave_ClassDef(self, node: cst.ClassDef) -> None:
        self.depth -= 1

    def visit_If(self, node: cst.If) -> None:
        if self.depth != 0:
            return
        self._collect_imports_from_block(node.body)
        # An ``elif`` is a nested ``If`` and gets its own visit_If call during
        # traversal; only a plain ``else`` has to be handled here.
        if isinstance(node.orelse, cst.Else):
            self._collect_imports_from_block(node.orelse.body)

    def visit_Try(self, node: cst.Try) -> None:
        if self.depth != 0:
            return
        self._collect_imports_from_block(node.body)
        for handler in node.handlers:
            self._collect_imports_from_block(handler.body)
        if node.orelse is not None:
            self._collect_imports_from_block(node.orelse.body)
        if node.finalbody is not None:
            self._collect_imports_from_block(node.finalbody.body)
254+
255+
198256
class ImportInserter(cst.CSTTransformer):
199257
"""Transformer that inserts global statements after the last import."""
200258

@@ -329,8 +387,19 @@ def add_needed_imports_from_module(
329387
except Exception as e:
330388
logger.error(f"Error parsing source module code: {e}")
331389
return dst_module_code
390+
391+
cond_import_collector = ConditionalImportCollector()
392+
try:
393+
parsed_dst_module = cst.parse_module(dst_module_code)
394+
parsed_dst_module.visit(cond_import_collector)
395+
except cst.ParserSyntaxError as e:
396+
logger.exception(f"Syntax error in destination module code: {e}")
397+
return dst_module_code # Return the original code if there's a syntax error
398+
332399
try:
333400
for mod in gatherer.module_imports:
401+
if mod in cond_import_collector.imports:
402+
continue
334403
AddImportsVisitor.add_needed_import(dst_context, mod)
335404
RemoveImportsVisitor.remove_unused_import(dst_context, mod)
336405
for mod, obj_seq in gatherer.object_mapping.items():
@@ -339,28 +408,29 @@ def add_needed_imports_from_module(
339408
f"{mod}.{obj}" in helper_functions_fqn or dst_context.full_module_name == mod # avoid circular deps
340409
):
341410
continue # Skip adding imports for helper functions already in the context
411+
if f"{mod}.{obj}" in cond_import_collector.imports:
412+
continue
342413
AddImportsVisitor.add_needed_import(dst_context, mod, obj)
343414
RemoveImportsVisitor.remove_unused_import(dst_context, mod, obj)
344415
except Exception as e:
345416
logger.exception(f"Error adding imports to destination module code: {e}")
346417
return dst_module_code
347418
for mod, asname in gatherer.module_aliases.items():
419+
if f"{mod}.{asname}" in cond_import_collector.imports:
420+
continue
348421
AddImportsVisitor.add_needed_import(dst_context, mod, asname=asname)
349422
RemoveImportsVisitor.remove_unused_import(dst_context, mod, asname=asname)
350423
for mod, alias_pairs in gatherer.alias_mapping.items():
351424
for alias_pair in alias_pairs:
352425
if f"{mod}.{alias_pair[0]}" in helper_functions_fqn:
353426
continue
427+
if f"{mod}.{alias_pair[1]}" in cond_import_collector.imports:
428+
continue
354429
AddImportsVisitor.add_needed_import(dst_context, mod, alias_pair[0], asname=alias_pair[1])
355430
RemoveImportsVisitor.remove_unused_import(dst_context, mod, alias_pair[0], asname=alias_pair[1])
356431

357432
try:
358-
parsed_module = cst.parse_module(dst_module_code)
359-
except cst.ParserSyntaxError as e:
360-
logger.exception(f"Syntax error in destination module code: {e}")
361-
return dst_module_code # Return the original code if there's a syntax error
362-
try:
363-
transformed_module = AddImportsVisitor(dst_context).transform_module(parsed_module)
433+
transformed_module = AddImportsVisitor(dst_context).transform_module(parsed_dst_module)
364434
transformed_module = RemoveImportsVisitor(dst_context).transform_module(transformed_module)
365435
return transformed_module.code.lstrip("\n")
366436
except Exception as e:

tests/test_code_replacement.py

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from __future__ import annotations
2+
import re
23
import libcst as cst
34
from codeflash.code_utils.code_replacer import AutouseFixtureModifier, PytestMarkAdder, AddRequestArgument
45
import dataclasses
@@ -3070,3 +3071,139 @@ def my_fixture(request):
30703071
modified_module = module.visit(transformer)
30713072

30723073
assert modified_module.code.strip() == expected.strip()
3074+
3075+
3076+
def test_type_checking_imports():
    """Imports already present in top-level ``try``/``if`` blocks are not re-added.

    The optimized code imports several names at module level that the original
    module only imports conditionally.  After the replacement, none of those
    imports may show up as new, unindented import statements.
    """
    optim_code = """from dataclasses import dataclass
from pydantic_ai.providers import Provider, infer_provider
from pydantic_ai_slim.pydantic_ai.models import Model
from pydantic_ai_slim.pydantic_ai.tools import ToolDefinition
from typing import Literal

#### problamatic imports ####
from huggingface_hub import AsyncInferenceClient, ChatCompletionInputTool
import requests
import aiohttp as aiohttp_
from math import pi as PI, sin as sine

@dataclass(init=False)
class HuggingFaceModel(Model):
    def __init__(
        self,
        model_name: str,
        *,
        provider: Literal['huggingface'] | Provider[AsyncInferenceClient] = 'huggingface',
    ):
        print(requests.__name__)
        print(aiohttp_.__name__)
        print(PI)
        print(sine)
        # Fast branch: avoid repeating provider assignment
        if isinstance(provider, str):
            provider_obj = infer_provider(provider)
        else:
            provider_obj = provider
        self._provider = provider
        self._model_name = model_name
        self.client = provider_obj.client

    @staticmethod
    def _map_tool_definition(f: ToolDefinition) -> ChatCompletionInputTool:
        # Inline dict creation and single pass for possible strict attribute
        tool_dict = {
            'type': 'function',
            'function': {
                'name': f.name,
                'description': f.description,
                'parameters': f.parameters_json_schema,
            },
        }
        if f.strict is not None:
            tool_dict['function']['strict'] = f.strict
        return ChatCompletionInputTool.parse_obj_as_instance(tool_dict) # type: ignore
"""

    original_code = """from dataclasses import dataclass
from pydantic_ai.providers import Provider, infer_provider
from pydantic_ai_slim.pydantic_ai.models import Model
from pydantic_ai_slim.pydantic_ai.tools import ToolDefinition
from typing import Literal

try:
    import aiohttp as aiohttp_
    from math import pi as PI, sin as sine
    from huggingface_hub import (
        AsyncInferenceClient,
        ChatCompletionInputMessage,
        ChatCompletionInputMessageChunk,
        ChatCompletionInputTool,
        ChatCompletionInputToolCall,
        ChatCompletionInputURL,
        ChatCompletionOutput,
        ChatCompletionOutputMessage,
        ChatCompletionStreamOutput,
    )
    from huggingface_hub.errors import HfHubHTTPError

except ImportError as _import_error:
    raise ImportError(
        'Please install `huggingface_hub` to use Hugging Face Inference Providers, '
        'you can use the `huggingface` optional group — `pip install "pydantic-ai-slim[huggingface]"`'
    ) from _import_error

if True:
    import requests

__all__ = (
    'HuggingFaceModel',
    'HuggingFaceModelSettings',
)

@dataclass(init=False)
class HuggingFaceModel(Model):

    def __init__(
        self,
        model_name: str,
        *,
        provider: Literal['huggingface'] | Provider[AsyncInferenceClient] = 'huggingface',
    ):
        self._model_name = model_name
        self._provider = provider
        if isinstance(provider, str):
            provider = infer_provider(provider)
        self.client = provider.client

    @staticmethod
    def _map_tool_definition(f: ToolDefinition) -> ChatCompletionInputTool:
        tool_param: ChatCompletionInputTool = ChatCompletionInputTool.parse_obj_as_instance( # type: ignore
            {
                'type': 'function',
                'function': {
                    'name': f.name,
                    'description': f.description,
                    'parameters': f.parameters_json_schema,
                },
            }
        )
        if f.strict is not None:
            tool_param['function']['strict'] = f.strict
        return tool_param
"""

    # Replace a single method of the class and let the import merger run.
    target: str = "HuggingFaceModel._map_tool_definition"
    known_objects: set[tuple[str, tuple[FunctionParent, ...]]] = find_preexisting_objects(original_code)
    this_file = Path(__file__).resolve()
    new_code: str = replace_functions_and_add_imports(
        source_code=original_code,
        function_names=[target],
        optimized_code=optim_code,
        module_abspath=this_file,
        preexisting_objects=known_objects,
        project_root_path=this_file.parent.resolve(),
    )

    # Each pattern is anchored at column 0, so only a newly added top-level
    # import matches; the indented conditional imports do not.
    forbidden_top_level_imports = (
        r"^import requests\b",  # conditional simple import: import <name>
        r"^import aiohttp as aiohttp_\b",  # conditional alias import: import <name> as <alias>
        r"^from math import pi as PI, sin as sine\b",  # conditional multiple aliases imports
    )
    for pattern in forbidden_top_level_imports:
        assert not re.search(pattern, new_code, re.MULTILINE)

    # conditional from import
    assert "from huggingface_hub import AsyncInferenceClient, ChatCompletionInputTool" not in new_code

0 commit comments

Comments
 (0)