@KRRT7 (Contributor) commented Sep 27, 2025

PR Type

Bug fix, Enhancement, Tests


Description

  • Skip collecting star imports in collector

  • Resolve star imports to concrete symbols

  • Add import resolution into add_needed_imports

  • Introduce tests for star imports handling


Diagram Walkthrough

flowchart LR
  collector["DottedImportCollector skips star imports"] -- "avoids '*'" --> imports["Accurate import set"]
  resolver["resolve_star_import(module, project_root)"] -- "__all__ or public names" --> symbols["Resolved symbols"]
  addImports["add_needed_imports_from_module"] -- "obj == '*'" --> resolver
  symbols -- "add explicit imports" --> addImports
  tests["New tests"] -- "validate behavior" --> collector
  tests -- "validate behavior" --> resolver
  tests -- "validate behavior" --> addImports
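
To make the last arrow in the diagram concrete, here is a minimal sketch of turning an already-resolved symbol set into an explicit import line. `render_explicit_import` is a hypothetical helper written for illustration only; the PR performs this step inside `add_needed_imports_from_module`, whose actual mechanics are not shown on this page.

```python
def render_explicit_import(module_name: str, symbols: set[str]) -> str:
    """Render `from module import a, b` for an already-resolved symbol set.

    Hypothetical helper: returns an empty string when nothing was resolved.
    """
    if not symbols:
        return ""
    # Sort for a deterministic result, since sets have no stable order.
    return f"from {module_name} import {', '.join(sorted(symbols))}"


print(render_explicit_import("pkg.utils", {"load", "dump"}))
# from pkg.utils import dump, load
```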

File Walkthrough

Relevant files
Enhancement
code_extractor.py
Star import resolution and collector skip                               

codeflash/code_utils/code_extractor.py

  • Ignore cst.ImportFrom star imports during collection
  • Add resolve_star_import to map '*' to symbols
  • Enhance add_needed_imports to expand star imports
  • Add logging and robust fallbacks
+86/-3   
Tests
test_add_needed_imports_from_module.py
Tests for star import resolution and skipping                       

tests/test_add_needed_imports_from_module.py

  • Add tests for resolve_star_import (__all__/public)
  • Test non-existent module returns empty set
  • Verify DottedImportCollector skips star imports
  • Test add_needed_imports expands star imports
+143/-0 

@github-actions

PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

⏱️ Estimated effort to review: 3 🔵🔵🔵⚪⚪
🧪 PR contains tests
🔒 No security concerns identified
⚡ Recommended focus areas for review

Robustness

Star import resolution only reads files under the project root and uses simple AST parsing; it will miss symbols re-exported via submodules, dynamic __all__ construction, or imports guarded by try/except. Validate behavior on packages whose __init__.py re-exports from submodules and ensure acceptable fallbacks.

def resolve_star_import(module_name: str, project_root: Path) -> set[str]:
    try:
        module_path = module_name.replace(".", "/")
        possible_paths = [project_root / f"{module_path}.py", project_root / f"{module_path}/__init__.py"]

        module_file = None
        for path in possible_paths:
            if path.exists():
                module_file = path
                break

        if module_file is None:
            logger.warning(f"Could not find module file for {module_name}, skipping star import resolution")
            return set()

        with module_file.open(encoding="utf8") as f:
            module_code = f.read()

        tree = ast.parse(module_code)

        all_names = None
        for node in ast.walk(tree):
            if (
                isinstance(node, ast.Assign)
                and len(node.targets) == 1
                and isinstance(node.targets[0], ast.Name)
                and node.targets[0].id == "__all__"
            ):
                if isinstance(node.value, (ast.List, ast.Tuple)):
                    all_names = []
                    for elt in node.value.elts:
                        if isinstance(elt, ast.Constant) and isinstance(elt.value, str):
                            all_names.append(elt.value)
                        elif isinstance(elt, ast.Str):  # Python < 3.8 compatibility
                            all_names.append(elt.s)
                break

        if all_names is not None:
            return set(all_names)

        public_names = set()
        for node in tree.body:
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
                if not node.name.startswith("_"):
                    public_names.add(node.name)
            elif isinstance(node, ast.Assign):
                for target in node.targets:
                    if isinstance(target, ast.Name) and not target.id.startswith("_"):
                        public_names.add(target.id)
            elif isinstance(node, ast.AnnAssign):
                if isinstance(node.target, ast.Name) and not node.target.id.startswith("_"):
                    public_names.add(node.target.id)
            elif isinstance(node, ast.Import) or (
                isinstance(node, ast.ImportFrom) and not any(alias.name == "*" for alias in node.names)
            ):
                for alias in node.names:
                    name = alias.asname or alias.name
                    if not name.startswith("_"):
                        public_names.add(name)

        return public_names  # noqa: TRY300

    except Exception as e:
        logger.warning(f"Error resolving star import for {module_name}: {e}")
        return set()
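
The try/except limitation noted above can be reproduced with a small sketch. `top_level_public_names` is a simplified, hypothetical stand-in for the fallback branch of `resolve_star_import` (it handles only imports and plain assignments), and `SOURCE` is made-up module text.

```python
import ast

SOURCE = '''
try:
    import ujson as json
except ImportError:
    import json

from .core import run
VERSION = "1.0"
'''


def top_level_public_names(source: str) -> set[str]:
    """Simplified fallback: scan only the direct children of the module."""
    names: set[str] = set()
    for node in ast.parse(source).body:
        if isinstance(node, (ast.Import, ast.ImportFrom)):
            names.update(a.asname or a.name for a in node.names if a.name != "*")
        elif isinstance(node, ast.Assign):
            names.update(t.id for t in node.targets if isinstance(t, ast.Name))
    return {n for n in names if not n.startswith("_")}


print(sorted(top_level_public_names(SOURCE)))
# ['VERSION', 'run']
```

`json` is absent from the result because both of its import statements sit inside an `ast.Try` node, which a scan of `tree.body` never descends into.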
False Positives

Public name inference includes names imported from other modules and treats them as exportable whenever they do not start with an underscore, which may add unnecessary imports. Confirm this heuristic won’t pollute destination modules or create circular imports.

public_names = set()
for node in tree.body:
    if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
        if not node.name.startswith("_"):
            public_names.add(node.name)
    elif isinstance(node, ast.Assign):
        for target in node.targets:
            if isinstance(target, ast.Name) and not target.id.startswith("_"):
                public_names.add(target.id)
    elif isinstance(node, ast.AnnAssign):
        if isinstance(node.target, ast.Name) and not node.target.id.startswith("_"):
            public_names.add(node.target.id)
    elif isinstance(node, ast.Import) or (
        isinstance(node, ast.ImportFrom) and not any(alias.name == "*" for alias in node.names)
    ):
        for alias in node.names:
            name = alias.asname or alias.name
            if not name.startswith("_"):
                public_names.add(name)
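
For context when weighing this concern: at runtime, `from module import *` on a module without `__all__` also copies every module-level name that does not start with an underscore, including names the module itself imported. The sketch below checks that with a throwaway in-memory module (`_demo_star_mod` is a made-up name), so the heuristic appears to match what Python binds, even if some of those names end up unused in the destination.

```python
import sys
import types

# Build a throwaway module in memory; the name is made up for this demo.
demo = types.ModuleType("_demo_star_mod")
exec("import os\ndef helper():\n    return 1\n_private = 2\n", demo.__dict__)
sys.modules[demo.__name__] = demo

namespace: dict = {}
exec("from _demo_star_mod import *", namespace)

# Ignore dunder entries such as __builtins__ that exec() adds on its own.
imported = sorted(n for n in namespace if not n.startswith("__"))
print(imported)  # ['helper', 'os']

del sys.modules[demo.__name__]  # clean up the demo module
```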
Performance

File-system probing and AST parsing for each star import may be expensive on large projects. Consider caching resolve_star_import results per module during a run.


@github-actions

PR Code Suggestions ✨

Explore these optional code suggestions:

Category    Suggestion    Impact
Possible issue
Robustly detect star imports

This only catches the case where names is exactly an ImportStar, but names can be a
sequence containing ImportStar or mixed aliases. Guard against any star within
child.names to avoid iterating it and mis-collecting. Skip the entire import when
any star is present.

codeflash/code_utils/code_extractor.py [275-276]

-if isinstance(child.names, cst.ImportStar):
+if any(isinstance(n, cst.ImportStar) for n in (child.names if isinstance(child.names, (list, tuple)) else [child.names])):
     continue
Suggestion importance[1-10]: 7


Why: Correctly notes that child.names for ImportFrom is a sequence and star could appear within it; the PR currently only checks the top-level type. Improves robustness with minimal risk and aligns with the collector's intent to skip star imports.

Medium
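
One fact that bears on this suggestion: Python's grammar does not allow `*` to be combined with other names in a single `from` import, so a star import is always the only item. The sketch below checks this with the standard-library `ast` module rather than libcst; how libcst types `ImportFrom.names` should be confirmed against its own documentation.

```python
import ast


def is_valid(source: str) -> bool:
    """Return True when the source parses as Python."""
    try:
        ast.parse(source)
    except SyntaxError:
        return False
    return True


print(is_valid("from pkg import *"))      # True
print(is_valid("from pkg import *, os"))  # False
print(is_valid("from pkg import os, *"))  # False

# In the stdlib AST, a star import is a single alias named "*".
star = ast.parse("from pkg import *").body[0]
print([alias.name for alias in star.names])  # ['*']
```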
Normalize imported symbol names

alias.name can be an ast.alias attribute that includes dots (e.g., pkg.mod) and
asname may be None; comparing to "*" as a bare string can fail for non-str types in
older ASTs. Normalize using getattr(alias, "name", "") == "*", and when adding names
prefer alias.asname, else the last segment of alias.name, to match symbols actually
introduced by import. This prevents leaking package-qualified names.

codeflash/code_utils/code_extractor.py [471-477]

 elif isinstance(node, ast.Import) or (
-    isinstance(node, ast.ImportFrom) and not any(alias.name == "*" for alias in node.names)
+    isinstance(node, ast.ImportFrom) and not any(getattr(a, "name", "") == "*" for a in node.names)
 ):
     for alias in node.names:
-        name = alias.asname or alias.name
-        if not name.startswith("_"):
-            public_names.add(name)
+        introduced = alias.asname if alias.asname else (alias.name.rsplit(".", 1)[-1] if isinstance(alias.name, str) else alias.name)
+        if isinstance(introduced, str) and not introduced.startswith("_"):
+            public_names.add(introduced)
Suggestion importance[1-10]: 6


Why: Suggests using last segment of alias.name and safer star detection, reducing leakage of qualified names into public_names. Useful but relatively minor; current code already avoids stars and typically works.

Low
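
Before adopting the last-segment rule, it may be worth checking which name a dotted import binds. For a plain `import a.b` without `as`, Python binds the first segment (`a`), not the last, while `from` imports never carry dots in `alias.name`. The check below uses only the standard library; `bound_name` is a hypothetical helper, not code from the PR.

```python
import ast


def bound_name(alias: ast.alias, is_from_import: bool) -> str:
    """Name that an import alias binds in the importing namespace."""
    if alias.asname:
        return alias.asname
    # `import a.b.c` binds `a`; `from a.b import c` binds `c`.
    return alias.name if is_from_import else alias.name.split(".", 1)[0]


namespace: dict = {}
exec("import os.path", namespace)
print("os" in namespace, "path" in namespace)  # True False

node = ast.parse("import os.path").body[0]
print(bound_name(node.names[0], is_from_import=False))  # os
```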
General
Handle extended __all__ patterns

This misses common patterns like __all__ += [...], __all__.extend([...]), or
annotated assignments __all__: list[str] = [...]. Support AugAssign, AnnAssign, and
simple Call extends to avoid under-resolving symbols. Merge discovered names across
multiple occurrences before falling back to public names.

codeflash/code_utils/code_extractor.py [439-457]

+all_names = None
 for node in ast.walk(tree):
-    if (
-        isinstance(node, ast.Assign)
-        and len(node.targets) == 1
-        and isinstance(node.targets[0], ast.Name)
-        and node.targets[0].id == "__all__"
-    ):
+    # __all__ = [...]
+    if isinstance(node, ast.Assign) and len(node.targets) == 1 and isinstance(node.targets[0], ast.Name) and node.targets[0].id == "__all__":
         if isinstance(node.value, (ast.List, ast.Tuple)):
-            all_names = []
+            names = []
             for elt in node.value.elts:
                 if isinstance(elt, ast.Constant) and isinstance(elt.value, str):
-                    all_names.append(elt.value)
-                elif isinstance(elt, ast.Str):  # Python < 3.8 compatibility
-                    all_names.append(elt.s)
-        break
+                    names.append(elt.value)
+                elif isinstance(elt, ast.Str):
+                    names.append(elt.s)
+            all_names = set(names) if all_names is None else all_names.union(names)
+    # __all__: list[str] = [...]
+    elif isinstance(node, ast.AnnAssign) and isinstance(node.target, ast.Name) and node.target.id == "__all__" and isinstance(node.value, (ast.List, ast.Tuple)):
+        names = []
+        for elt in node.value.elts:
+            if isinstance(elt, ast.Constant) and isinstance(elt.value, str):
+                names.append(elt.value)
+            elif isinstance(elt, ast.Str):
+                names.append(elt.s)
+        all_names = set(names) if all_names is None else all_names.union(names)
+    # __all__ += [...]
+    elif isinstance(node, ast.AugAssign) and isinstance(node.target, ast.Name) and node.target.id == "__all__" and isinstance(node.value, (ast.List, ast.Tuple)):
+        names = []
+        for elt in node.value.elts:
+            if isinstance(elt, ast.Constant) and isinstance(elt.value, str):
+                names.append(elt.value)
+            elif isinstance(elt, ast.Str):
+                names.append(elt.s)
+        all_names = set(names) if all_names is None else all_names.union(names)
+    # __all__.extend([...])
+    elif isinstance(node, ast.Expr) and isinstance(node.value, ast.Call):
+        call = node.value
+        if isinstance(call.func, ast.Attribute) and call.func.attr == "extend" and isinstance(call.func.value, ast.Name) and call.func.value.id == "__all__" and call.args and isinstance(call.args[0], (ast.List, ast.Tuple)):
+            names = []
+            for elt in call.args[0].elts:
+                if isinstance(elt, ast.Constant) and isinstance(elt.value, str):
+                    names.append(elt.value)
+                elif isinstance(elt, ast.Str):
+                    names.append(elt.s)
+            all_names = set(names) if all_names is None else all_names.union(names)
+if all_names is not None:
+    return set(all_names)
Suggestion importance[1-10]: 7


Why: Extends resolve_star_import to handle common __all__ patterns (AnnAssign, AugAssign, .extend), which increases correctness for star resolution. It's a reasonable enhancement though not strictly critical.

Medium
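
The four branches in the suggested diff repeat the same element loop, so a shared helper could shorten it. The following is a sketch under stated assumptions, not the PR's code: `collect_all_names`, `_string_items` and `_is_all` are hypothetical names, only top-level statements are scanned (the PR uses `ast.walk`), a plain or annotated assignment replaces earlier entries instead of merging with them, and `ast.Str` is omitted because it no longer exists on Python 3.12 and later.

```python
from __future__ import annotations

import ast


def _string_items(value: ast.expr) -> set[str]:
    """String constants from a list or tuple literal; empty set otherwise."""
    if not isinstance(value, (ast.List, ast.Tuple)):
        return set()
    return {e.value for e in value.elts if isinstance(e, ast.Constant) and isinstance(e.value, str)}


def _is_all(node: ast.expr) -> bool:
    return isinstance(node, ast.Name) and node.id == "__all__"


def collect_all_names(source: str) -> set[str] | None:
    """Statically visible `__all__` entries; None when `__all__` is never set."""
    found: set[str] | None = None
    for node in ast.parse(source).body:  # top-level statements only
        if isinstance(node, ast.Assign) and len(node.targets) == 1 and _is_all(node.targets[0]):
            found = _string_items(node.value)  # __all__ = [...] replaces
        elif isinstance(node, ast.AnnAssign) and _is_all(node.target) and node.value is not None:
            found = _string_items(node.value)  # __all__: list[str] = [...]
        elif isinstance(node, ast.AugAssign) and isinstance(node.op, ast.Add) and _is_all(node.target):
            found = (found or set()) | _string_items(node.value)  # __all__ += [...]
        elif (
            isinstance(node, ast.Expr)
            and isinstance(node.value, ast.Call)
            and isinstance(node.value.func, ast.Attribute)
            and node.value.func.attr == "extend"
            and _is_all(node.value.func.value)
            and node.value.args
        ):
            found = (found or set()) | _string_items(node.value.args[0])  # __all__.extend([...])
    return found


SOURCE = '''
__all__: list[str] = ["a"]
__all__ += ["b"]
__all__.extend(("c",))
'''
print(sorted(collect_all_names(SOURCE)))  # ['a', 'b', 'c']
print(collect_all_names("x = 1"))         # None
```

Returning `None` when `__all__` never appears keeps the existing fallback to public-name inference distinguishable from an explicitly empty `__all__`.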

misrasaurabh1 previously approved these changes Sep 27, 2025
@KRRT7 KRRT7 requested a review from misrasaurabh1 September 27, 2025 01:11
@KRRT7 KRRT7 enabled auto-merge September 27, 2025 01:11
@KRRT7 KRRT7 merged commit 565d65b into main Sep 29, 2025
18 of 22 checks passed