Skip to content

Commit f7adce8

Browse files
Merge pull request #600 from codeflash-ai/fix/existing-top-level-cond-imports
[FIX] Respect top-level conditional imports e.g. (Conditions like TYPE_CHECKING & try/except &) (CF-382)
2 parents ad7dbe4 + 6ef3ec4 commit f7adce8

File tree

2 files changed

+213
-6
lines changed

2 files changed

+213
-6
lines changed

codeflash/code_utils/code_extractor.py

Lines changed: 76 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,64 @@ def visit_SimpleStatementLine(self, node: cst.SimpleStatementLine) -> None:
195195
self.last_import_line = self.current_line
196196

197197

198+
class ConditionalImportCollector(cst.CSTVisitor):
199+
"""Collect imports inside top-level conditionals (e.g., if TYPE_CHECKING, try/except)."""
200+
201+
def __init__(self) -> None:
202+
self.imports: set[str] = set()
203+
self.depth = 0 # top-level
204+
205+
def get_full_dotted_name(self, expr: cst.BaseExpression) -> str:
206+
if isinstance(expr, cst.Name):
207+
return expr.value
208+
if isinstance(expr, cst.Attribute):
209+
return f"{self.get_full_dotted_name(expr.value)}.{expr.attr.value}"
210+
return ""
211+
212+
def _collect_imports_from_block(self, block: cst.IndentedBlock) -> None:
213+
for statement in block.body:
214+
if isinstance(statement, cst.SimpleStatementLine):
215+
for child in statement.body:
216+
if isinstance(child, cst.Import):
217+
for alias in child.names:
218+
module = self.get_full_dotted_name(alias.name)
219+
asname = alias.asname.name.value if alias.asname else alias.name.value
220+
self.imports.add(module if module == asname else f"{module}.{asname}")
221+
222+
elif isinstance(child, cst.ImportFrom):
223+
if child.module is None:
224+
continue
225+
module = self.get_full_dotted_name(child.module)
226+
for alias in child.names:
227+
if isinstance(alias, cst.ImportAlias):
228+
name = alias.name.value
229+
asname = alias.asname.name.value if alias.asname else name
230+
self.imports.add(f"{module}.{asname}")
231+
232+
def visit_Module(self, node: cst.Module) -> None:
233+
self.depth = 0
234+
235+
def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
236+
self.depth += 1
237+
238+
def leave_FunctionDef(self, node: cst.FunctionDef) -> None:
239+
self.depth -= 1
240+
241+
def visit_ClassDef(self, node: cst.ClassDef) -> None:
242+
self.depth += 1
243+
244+
def leave_ClassDef(self, node: cst.ClassDef) -> None:
245+
self.depth -= 1
246+
247+
def visit_If(self, node: cst.If) -> None:
248+
if self.depth == 0:
249+
self._collect_imports_from_block(node.body)
250+
251+
def visit_Try(self, node: cst.Try) -> None:
252+
if self.depth == 0:
253+
self._collect_imports_from_block(node.body)
254+
255+
198256
class ImportInserter(cst.CSTTransformer):
199257
"""Transformer that inserts global statements after the last import."""
200258

@@ -329,8 +387,19 @@ def add_needed_imports_from_module(
329387
except Exception as e:
330388
logger.error(f"Error parsing source module code: {e}")
331389
return dst_module_code
390+
391+
cond_import_collector = ConditionalImportCollector()
392+
try:
393+
parsed_dst_module = cst.parse_module(dst_module_code)
394+
parsed_dst_module.visit(cond_import_collector)
395+
except cst.ParserSyntaxError as e:
396+
logger.exception(f"Syntax error in destination module code: {e}")
397+
return dst_module_code # Return the original code if there's a syntax error
398+
332399
try:
333400
for mod in gatherer.module_imports:
401+
if mod in cond_import_collector.imports:
402+
continue
334403
AddImportsVisitor.add_needed_import(dst_context, mod)
335404
RemoveImportsVisitor.remove_unused_import(dst_context, mod)
336405
for mod, obj_seq in gatherer.object_mapping.items():
@@ -339,28 +408,29 @@ def add_needed_imports_from_module(
339408
f"{mod}.{obj}" in helper_functions_fqn or dst_context.full_module_name == mod # avoid circular deps
340409
):
341410
continue # Skip adding imports for helper functions already in the context
411+
if f"{mod}.{obj}" in cond_import_collector.imports:
412+
continue
342413
AddImportsVisitor.add_needed_import(dst_context, mod, obj)
343414
RemoveImportsVisitor.remove_unused_import(dst_context, mod, obj)
344415
except Exception as e:
345416
logger.exception(f"Error adding imports to destination module code: {e}")
346417
return dst_module_code
347418
for mod, asname in gatherer.module_aliases.items():
419+
if f"{mod}.{asname}" in cond_import_collector.imports:
420+
continue
348421
AddImportsVisitor.add_needed_import(dst_context, mod, asname=asname)
349422
RemoveImportsVisitor.remove_unused_import(dst_context, mod, asname=asname)
350423
for mod, alias_pairs in gatherer.alias_mapping.items():
351424
for alias_pair in alias_pairs:
352425
if f"{mod}.{alias_pair[0]}" in helper_functions_fqn:
353426
continue
427+
if f"{mod}.{alias_pair[1]}" in cond_import_collector.imports:
428+
continue
354429
AddImportsVisitor.add_needed_import(dst_context, mod, alias_pair[0], asname=alias_pair[1])
355430
RemoveImportsVisitor.remove_unused_import(dst_context, mod, alias_pair[0], asname=alias_pair[1])
356431

357432
try:
358-
parsed_module = cst.parse_module(dst_module_code)
359-
except cst.ParserSyntaxError as e:
360-
logger.exception(f"Syntax error in destination module code: {e}")
361-
return dst_module_code # Return the original code if there's a syntax error
362-
try:
363-
transformed_module = AddImportsVisitor(dst_context).transform_module(parsed_module)
433+
transformed_module = AddImportsVisitor(dst_context).transform_module(parsed_dst_module)
364434
transformed_module = RemoveImportsVisitor(dst_context).transform_module(transformed_module)
365435
return transformed_module.code.lstrip("\n")
366436
except Exception as e:

tests/test_code_replacement.py

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from __future__ import annotations
2+
import re
23
import libcst as cst
34
from codeflash.code_utils.code_replacer import AutouseFixtureModifier, PytestMarkAdder, AddRequestArgument
45
import dataclasses
@@ -3091,3 +3092,139 @@ def my_fixture(request):
30913092
modified_module = module.visit(transformer)
30923093

30933094
assert modified_module.code.strip() == expected.strip()
3095+
3096+
3097+
def test_type_checking_imports():
3098+
optim_code = """from dataclasses import dataclass
3099+
from pydantic_ai.providers import Provider, infer_provider
3100+
from pydantic_ai_slim.pydantic_ai.models import Model
3101+
from pydantic_ai_slim.pydantic_ai.tools import ToolDefinition
3102+
from typing import Literal
3103+
3104+
#### problamatic imports ####
3105+
from huggingface_hub import AsyncInferenceClient, ChatCompletionInputTool
3106+
import requests
3107+
import aiohttp as aiohttp_
3108+
from math import pi as PI, sin as sine
3109+
3110+
@dataclass(init=False)
3111+
class HuggingFaceModel(Model):
3112+
def __init__(
3113+
self,
3114+
model_name: str,
3115+
*,
3116+
provider: Literal['huggingface'] | Provider[AsyncInferenceClient] = 'huggingface',
3117+
):
3118+
print(requests.__name__)
3119+
print(aiohttp_.__name__)
3120+
print(PI)
3121+
print(sine)
3122+
# Fast branch: avoid repeating provider assignment
3123+
if isinstance(provider, str):
3124+
provider_obj = infer_provider(provider)
3125+
else:
3126+
provider_obj = provider
3127+
self._provider = provider
3128+
self._model_name = model_name
3129+
self.client = provider_obj.client
3130+
3131+
@staticmethod
3132+
def _map_tool_definition(f: ToolDefinition) -> ChatCompletionInputTool:
3133+
# Inline dict creation and single pass for possible strict attribute
3134+
tool_dict = {
3135+
'type': 'function',
3136+
'function': {
3137+
'name': f.name,
3138+
'description': f.description,
3139+
'parameters': f.parameters_json_schema,
3140+
},
3141+
}
3142+
if f.strict is not None:
3143+
tool_dict['function']['strict'] = f.strict
3144+
return ChatCompletionInputTool.parse_obj_as_instance(tool_dict) # type: ignore
3145+
"""
3146+
3147+
original_code = """from dataclasses import dataclass
3148+
from pydantic_ai.providers import Provider, infer_provider
3149+
from pydantic_ai_slim.pydantic_ai.models import Model
3150+
from pydantic_ai_slim.pydantic_ai.tools import ToolDefinition
3151+
from typing import Literal
3152+
3153+
try:
3154+
import aiohttp as aiohttp_
3155+
from math import pi as PI, sin as sine
3156+
from huggingface_hub import (
3157+
AsyncInferenceClient,
3158+
ChatCompletionInputMessage,
3159+
ChatCompletionInputMessageChunk,
3160+
ChatCompletionInputTool,
3161+
ChatCompletionInputToolCall,
3162+
ChatCompletionInputURL,
3163+
ChatCompletionOutput,
3164+
ChatCompletionOutputMessage,
3165+
ChatCompletionStreamOutput,
3166+
)
3167+
from huggingface_hub.errors import HfHubHTTPError
3168+
3169+
except ImportError as _import_error:
3170+
raise ImportError(
3171+
'Please install `huggingface_hub` to use Hugging Face Inference Providers, '
3172+
'you can use the `huggingface` optional group — `pip install "pydantic-ai-slim[huggingface]"`'
3173+
) from _import_error
3174+
3175+
if True:
3176+
import requests
3177+
3178+
__all__ = (
3179+
'HuggingFaceModel',
3180+
'HuggingFaceModelSettings',
3181+
)
3182+
3183+
@dataclass(init=False)
3184+
class HuggingFaceModel(Model):
3185+
3186+
def __init__(
3187+
self,
3188+
model_name: str,
3189+
*,
3190+
provider: Literal['huggingface'] | Provider[AsyncInferenceClient] = 'huggingface',
3191+
):
3192+
self._model_name = model_name
3193+
self._provider = provider
3194+
if isinstance(provider, str):
3195+
provider = infer_provider(provider)
3196+
self.client = provider.client
3197+
3198+
@staticmethod
3199+
def _map_tool_definition(f: ToolDefinition) -> ChatCompletionInputTool:
3200+
tool_param: ChatCompletionInputTool = ChatCompletionInputTool.parse_obj_as_instance( # type: ignore
3201+
{
3202+
'type': 'function',
3203+
'function': {
3204+
'name': f.name,
3205+
'description': f.description,
3206+
'parameters': f.parameters_json_schema,
3207+
},
3208+
}
3209+
)
3210+
if f.strict is not None:
3211+
tool_param['function']['strict'] = f.strict
3212+
return tool_param
3213+
"""
3214+
3215+
3216+
function_name: str = "HuggingFaceModel._map_tool_definition"
3217+
preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = find_preexisting_objects(original_code)
3218+
new_code: str = replace_functions_and_add_imports(
3219+
source_code=original_code,
3220+
function_names=[function_name],
3221+
optimized_code=optim_code,
3222+
module_abspath=Path(__file__).resolve(),
3223+
preexisting_objects=preexisting_objects,
3224+
project_root_path=Path(__file__).resolve().parent.resolve(),
3225+
)
3226+
3227+
assert not re.search(r"^import requests\b", new_code, re.MULTILINE) # conditional simple import: import <name>
3228+
assert not re.search(r"^import aiohttp as aiohttp_\b", new_code, re.MULTILINE) # conditional alias import: import <name> as <alias>
3229+
assert not re.search(r"^from math import pi as PI, sin as sine\b", new_code, re.MULTILINE) # conditional multiple aliases imports
3230+
assert "from huggingface_hub import AsyncInferenceClient, ChatCompletionInputTool" not in new_code # conditional from import

0 commit comments

Comments
 (0)