Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 86 additions & 3 deletions codeflash/code_utils/code_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,8 @@ def _collect_imports_from_block(self, block: cst.IndentedBlock) -> None:
if child.module is None:
continue
module = self.get_full_dotted_name(child.module)
if isinstance(child.names, cst.ImportStar):
continue
for alias in child.names:
if isinstance(alias, cst.ImportAlias):
name = alias.name.value
Expand Down Expand Up @@ -414,6 +416,73 @@ def add_global_assignments(src_module_code: str, dst_module_code: str) -> str:
return transformed_module.code


def resolve_star_import(module_name: str, project_root: Path) -> set[str]:
try:
module_path = module_name.replace(".", "/")
possible_paths = [project_root / f"{module_path}.py", project_root / f"{module_path}/__init__.py"]

module_file = None
for path in possible_paths:
if path.exists():
module_file = path
break

if module_file is None:
logger.warning(f"Could not find module file for {module_name}, skipping star import resolution")
return set()

with module_file.open(encoding="utf8") as f:
module_code = f.read()

tree = ast.parse(module_code)

all_names = None
for node in ast.walk(tree):
if (
isinstance(node, ast.Assign)
and len(node.targets) == 1
and isinstance(node.targets[0], ast.Name)
and node.targets[0].id == "__all__"
):
if isinstance(node.value, (ast.List, ast.Tuple)):
all_names = []
for elt in node.value.elts:
if isinstance(elt, ast.Constant) and isinstance(elt.value, str):
all_names.append(elt.value)
elif isinstance(elt, ast.Str): # Python < 3.8 compatibility
all_names.append(elt.s)
break

if all_names is not None:
return set(all_names)

public_names = set()
for node in tree.body:
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
if not node.name.startswith("_"):
public_names.add(node.name)
elif isinstance(node, ast.Assign):
for target in node.targets:
if isinstance(target, ast.Name) and not target.id.startswith("_"):
public_names.add(target.id)
elif isinstance(node, ast.AnnAssign):
if isinstance(node.target, ast.Name) and not node.target.id.startswith("_"):
public_names.add(node.target.id)
elif isinstance(node, ast.Import) or (
isinstance(node, ast.ImportFrom) and not any(alias.name == "*" for alias in node.names)
):
for alias in node.names:
name = alias.asname or alias.name
if not name.startswith("_"):
public_names.add(name)

return public_names # noqa: TRY300

except Exception as e:
logger.warning(f"Error resolving star import for {module_name}: {e}")
return set()


def add_needed_imports_from_module(
src_module_code: str,
dst_module_code: str,
Expand Down Expand Up @@ -468,9 +537,23 @@ def add_needed_imports_from_module(
f"{mod}.{obj}" in helper_functions_fqn or dst_context.full_module_name == mod # avoid circular deps
):
continue # Skip adding imports for helper functions already in the context
if f"{mod}.{obj}" not in dotted_import_collector.imports:
AddImportsVisitor.add_needed_import(dst_context, mod, obj)
RemoveImportsVisitor.remove_unused_import(dst_context, mod, obj)

# Handle star imports by resolving them to actual symbol names
if obj == "*":
resolved_symbols = resolve_star_import(mod, project_root)
logger.debug(f"Resolved star import from {mod}: {resolved_symbols}")

for symbol in resolved_symbols:
if (
f"{mod}.{symbol}" not in helper_functions_fqn
and f"{mod}.{symbol}" not in dotted_import_collector.imports
):
AddImportsVisitor.add_needed_import(dst_context, mod, symbol)
RemoveImportsVisitor.remove_unused_import(dst_context, mod, symbol)
else:
if f"{mod}.{obj}" not in dotted_import_collector.imports:
AddImportsVisitor.add_needed_import(dst_context, mod, obj)
RemoveImportsVisitor.remove_unused_import(dst_context, mod, obj)
except Exception as e:
logger.exception(f"Error adding imports to destination module code: {e}")
return dst_module_code
Expand Down
142 changes: 142 additions & 0 deletions tests/test_add_needed_imports_from_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
from codeflash.code_utils.code_extractor import add_needed_imports_from_module, find_preexisting_objects
from codeflash.code_utils.code_replacer import replace_functions_and_add_imports

import tempfile
from codeflash.code_utils.code_extractor import resolve_star_import, DottedImportCollector
import libcst as cst
from codeflash.models.models import FunctionParent

def test_add_needed_imports_from_module0() -> None:
src_module = '''import ast
Expand Down Expand Up @@ -349,3 +353,141 @@ def build_parent_map(self, nodes: Dict, base: Optional[bool] = False) -> Dict[st
project_root_path=Path(__file__).resolve().parent.resolve(),
)
assert new_code == expected




def test_resolve_star_import_with_all_defined():
"""Test resolve_star_import when __all__ is explicitly defined."""
with tempfile.TemporaryDirectory() as tmpdir:
project_root = Path(tmpdir)
test_module = project_root / 'test_module.py'

# Create a test module with __all__ definition
test_module.write_text('''
__all__ = ['public_function', 'PublicClass']

def public_function():
pass

def _private_function():
pass

class PublicClass:
pass

class AnotherPublicClass:
"""Not in __all__ so should be excluded."""
pass
''')

symbols = resolve_star_import('test_module', project_root)
expected_symbols = {'public_function', 'PublicClass'}
assert symbols == expected_symbols


def test_resolve_star_import_without_all_defined():
"""Test resolve_star_import when __all__ is not defined - should include all public symbols."""
with tempfile.TemporaryDirectory() as tmpdir:
project_root = Path(tmpdir)
test_module = project_root / 'test_module.py'

# Create a test module without __all__ definition
test_module.write_text('''
def public_func():
pass

def _private_func():
pass

class PublicClass:
pass

PUBLIC_VAR = 42
_private_var = 'secret'
''')

symbols = resolve_star_import('test_module', project_root)
expected_symbols = {'public_func', 'PublicClass', 'PUBLIC_VAR'}
assert symbols == expected_symbols


def test_resolve_star_import_nonexistent_module():
"""Test resolve_star_import with non-existent module - should return empty set."""
with tempfile.TemporaryDirectory() as tmpdir:
project_root = Path(tmpdir)

symbols = resolve_star_import('nonexistent_module', project_root)
assert symbols == set()


def test_dotted_import_collector_skips_star_imports():
"""Test that DottedImportCollector correctly skips star imports."""
code_with_star_import = '''
from typing import *
from pathlib import Path
from collections import defaultdict
import os
'''

module = cst.parse_module(code_with_star_import)
collector = DottedImportCollector()
module.visit(collector)

# Should collect regular imports but skip the star import
expected_imports = {'collections.defaultdict', 'os', 'pathlib.Path'}
assert collector.imports == expected_imports


def test_add_needed_imports_with_star_import_resolution():
"""Test add_needed_imports_from_module correctly handles star imports by resolving them."""
with tempfile.TemporaryDirectory() as tmpdir:
project_root = Path(tmpdir)

# Create a source module that exports symbols
src_module = project_root / 'source_module.py'
src_module.write_text('''
__all__ = ['UtilFunction', 'HelperClass']

def UtilFunction():
pass

class HelperClass:
pass
''')

# Create source code that uses star import
src_code = '''
from source_module import *

def my_function():
helper = HelperClass()
UtilFunction()
return helper
'''

# Destination code that needs the imports resolved
dst_code = '''
def my_function():
helper = HelperClass()
UtilFunction()
return helper
'''

src_path = project_root / 'src.py'
dst_path = project_root / 'dst.py'
src_path.write_text(src_code)

result = add_needed_imports_from_module(
src_code, dst_code, src_path, dst_path, project_root
)

# The result should have individual imports instead of star import
expected_result = '''from source_module import HelperClass, UtilFunction

def my_function():
helper = HelperClass()
UtilFunction()
return helper
'''
assert result == expected_result
Loading