From cc0034be5db4262f9e674ee61c2b1f7cbbec340a Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Fri, 26 Sep 2025 17:46:46 -0700 Subject: [PATCH 1/2] star imports bug --- codeflash/code_utils/code_extractor.py | 89 +++++++++++- tests/test_add_needed_imports_from_module.py | 143 +++++++++++++++++++ 2 files changed, 229 insertions(+), 3 deletions(-) diff --git a/codeflash/code_utils/code_extractor.py b/codeflash/code_utils/code_extractor.py index 52cb80a41..6e032290f 100644 --- a/codeflash/code_utils/code_extractor.py +++ b/codeflash/code_utils/code_extractor.py @@ -272,6 +272,8 @@ def _collect_imports_from_block(self, block: cst.IndentedBlock) -> None: if child.module is None: continue module = self.get_full_dotted_name(child.module) + if isinstance(child.names, cst.ImportStar): + continue for alias in child.names: if isinstance(alias, cst.ImportAlias): name = alias.name.value @@ -414,6 +416,73 @@ def add_global_assignments(src_module_code: str, dst_module_code: str) -> str: return transformed_module.code +def resolve_star_import(module_name: str, project_root: Path) -> set[str]: + try: + module_path = module_name.replace(".", "/") + possible_paths = [project_root / f"{module_path}.py", project_root / f"{module_path}/__init__.py"] + + module_file = None + for path in possible_paths: + if path.exists(): + module_file = path + break + + if module_file is None: + logger.warning(f"Could not find module file for {module_name}, skipping star import resolution") + return set() + + with module_file.open(encoding="utf8") as f: + module_code = f.read() + + tree = ast.parse(module_code) + + all_names = None + for node in ast.walk(tree): + if ( + isinstance(node, ast.Assign) + and len(node.targets) == 1 + and isinstance(node.targets[0], ast.Name) + and node.targets[0].id == "__all__" + ): + if isinstance(node.value, (ast.List, ast.Tuple)): + all_names = [] + for elt in node.value.elts: + if isinstance(elt, ast.Constant) and isinstance(elt.value, str): + all_names.append(elt.value) + elif isinstance(elt, ast.Str): # Python < 3.8 compatibility + all_names.append(elt.s) + break + + if all_names is not None: + return set(all_names) + + public_names = set() + for node in tree.body: + if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)): + if not node.name.startswith("_"): + public_names.add(node.name) + elif isinstance(node, ast.Assign): + for target in node.targets: + if isinstance(target, ast.Name) and not target.id.startswith("_"): + public_names.add(target.id) + elif isinstance(node, ast.AnnAssign): + if isinstance(node.target, ast.Name) and not node.target.id.startswith("_"): + public_names.add(node.target.id) + elif isinstance(node, ast.Import) or ( + isinstance(node, ast.ImportFrom) and not any(alias.name == "*" for alias in node.names) + ): + for alias in node.names: + name = alias.asname or alias.name + if not name.startswith("_"): + public_names.add(name) + + return public_names # noqa: TRY300 + + except Exception as e: + logger.warning(f"Error resolving star import for {module_name}: {e}") + return set() + + def add_needed_imports_from_module( src_module_code: str, dst_module_code: str, @@ -468,9 +537,23 @@ def add_needed_imports_from_module( f"{mod}.{obj}" in helper_functions_fqn or dst_context.full_module_name == mod # avoid circular deps ): continue # Skip adding imports for helper functions already in the context - if f"{mod}.{obj}" not in dotted_import_collector.imports: - AddImportsVisitor.add_needed_import(dst_context, mod, obj) - RemoveImportsVisitor.remove_unused_import(dst_context, mod, obj) + + # Handle star imports by resolving them to actual symbol names + if obj == "*": + resolved_symbols = resolve_star_import(mod, project_root) + logger.debug(f"Resolved star import from {mod}: {resolved_symbols}") + + for symbol in resolved_symbols: + if ( + f"{mod}.{symbol}" not in helper_functions_fqn + and f"{mod}.{symbol}" not in dotted_import_collector.imports + ): + AddImportsVisitor.add_needed_import(dst_context, mod, symbol) + RemoveImportsVisitor.remove_unused_import(dst_context, mod, symbol) + else: + if f"{mod}.{obj}" not in dotted_import_collector.imports: + AddImportsVisitor.add_needed_import(dst_context, mod, obj) + RemoveImportsVisitor.remove_unused_import(dst_context, mod, obj) except Exception as e: logger.exception(f"Error adding imports to destination module code: {e}") return dst_module_code diff --git a/tests/test_add_needed_imports_from_module.py b/tests/test_add_needed_imports_from_module.py index 4f04948a5..374b353a7 100644 --- a/tests/test_add_needed_imports_from_module.py +++ b/tests/test_add_needed_imports_from_module.py @@ -3,6 +3,10 @@ from codeflash.code_utils.code_extractor import add_needed_imports_from_module, find_preexisting_objects from codeflash.code_utils.code_replacer import replace_functions_and_add_imports +import tempfile +from codeflash.code_utils.code_extractor import resolve_star_import, DottedImportCollector +import libcst as cst +from codeflash.models.models import FunctionParent def test_add_needed_imports_from_module0() -> None: src_module = '''import ast @@ -349,3 +353,142 @@ def build_parent_map(self, nodes: Dict, base: Optional[bool] = False) -> Dict[st project_root_path=Path(__file__).resolve().parent.resolve(), ) assert new_code == expected + + + + +def test_resolve_star_import_with_all_defined(): + """Test resolve_star_import when __all__ is explicitly defined.""" + with tempfile.TemporaryDirectory() as tmpdir: + project_root = Path(tmpdir) + test_module = project_root / 'test_module.py' + + # Create a test module with __all__ definition + test_module.write_text(''' +__all__ = ['public_function', 'PublicClass'] + +def public_function(): + pass + +def _private_function(): + pass + +class PublicClass: + pass + +class AnotherPublicClass: + """Not in __all__ so should be excluded.""" + pass +''') + + symbols = resolve_star_import('test_module', project_root) + expected_symbols = {'public_function', 'PublicClass'} + assert symbols == expected_symbols + + +def test_resolve_star_import_without_all_defined(): + """Test resolve_star_import when __all__ is not defined - should include all public symbols.""" + with tempfile.TemporaryDirectory() as tmpdir: + project_root = Path(tmpdir) + test_module = project_root / 'test_module.py' + + # Create a test module without __all__ definition + test_module.write_text(''' +def public_func(): + pass + +def _private_func(): + pass + +class PublicClass: + pass + +PUBLIC_VAR = 42 +_private_var = 'secret' +''') + + symbols = resolve_star_import('test_module', project_root) + expected_symbols = {'public_func', 'PublicClass', 'PUBLIC_VAR'} + assert symbols == expected_symbols + + +def test_resolve_star_import_nonexistent_module(): + """Test resolve_star_import with non-existent module - should return empty set.""" + with tempfile.TemporaryDirectory() as tmpdir: + project_root = Path(tmpdir) + + symbols = resolve_star_import('nonexistent_module', project_root) + assert symbols == set() + + +def test_dotted_import_collector_skips_star_imports(): + """Test that DottedImportCollector correctly skips star imports.""" + code_with_star_import = ''' +from typing import * +from pathlib import Path +from collections import defaultdict +import os +''' + + module = cst.parse_module(code_with_star_import) + collector = DottedImportCollector() + module.visit(collector) + + # Should collect regular imports but skip the star import + expected_imports = { + 'pathlib.Path', + 'collections.defaultdict', + 'os' + } + assert collector.imports == expected_imports + # Ensure the star import from typing is not collected + assert not any('typing' in imp for imp in collector.imports) + + +def test_add_needed_imports_with_star_import_resolution(): + """Test add_needed_imports_from_module correctly handles star imports by resolving them.""" + with tempfile.TemporaryDirectory() as tmpdir: + project_root = Path(tmpdir) + + # Create a source module that exports symbols + src_module = project_root / 'source_module.py' + src_module.write_text(''' +__all__ = ['UtilFunction', 'HelperClass'] + +def UtilFunction(): + pass + +class HelperClass: + pass +''') + + # Create source code that uses star import + src_code = ''' +from source_module import * + +def my_function(): + helper = HelperClass() + UtilFunction() + return helper +''' + + # Destination code that needs the imports resolved + dst_code = ''' +def my_function(): + helper = HelperClass() + UtilFunction() + return helper +''' + + src_path = project_root / 'src.py' + dst_path = project_root / 'dst.py' + src_path.write_text(src_code) + + result = add_needed_imports_from_module( + src_code, dst_code, src_path, dst_path, project_root + ) + + # The result should have individual imports instead of star import + assert 'from source_module import' in result + assert 'HelperClass' in result and 'UtilFunction' in result + assert 'from source_module import *' not in result From 455a86de489ac0d31430ae6d99e5be76f0a5a322 Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Fri, 26 Sep 2025 18:04:14 -0700 Subject: [PATCH 2/2] exact test --- tests/test_add_needed_imports_from_module.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/tests/test_add_needed_imports_from_module.py b/tests/test_add_needed_imports_from_module.py index 374b353a7..cb24cbc50 100644 --- a/tests/test_add_needed_imports_from_module.py +++ b/tests/test_add_needed_imports_from_module.py @@ -435,14 +435,8 @@ def test_dotted_import_collector_skips_star_imports(): module.visit(collector) # Should collect regular imports but skip the star import - expected_imports = { - 'pathlib.Path', - 'collections.defaultdict', - 'os' - } + expected_imports = {'collections.defaultdict', 'os', 'pathlib.Path'} assert collector.imports == expected_imports - # Ensure the star import from typing is not collected - assert not any('typing' in imp for imp in collector.imports) def test_add_needed_imports_with_star_import_resolution(): @@ -489,6 +483,11 @@ def my_function(): ) # The result should have individual imports instead of star import - assert 'from source_module import' in result - assert 'HelperClass' in result and 'UtilFunction' in result - assert 'from source_module import *' not in result + expected_result = '''from source_module import HelperClass, UtilFunction + +def my_function(): + helper = HelperClass() + UtilFunction() + return helper +''' + assert result == expected_result