diff --git a/codeflash/code_utils/code_extractor.py b/codeflash/code_utils/code_extractor.py index bc149e5a9..0f69bed7a 100644 --- a/codeflash/code_utils/code_extractor.py +++ b/codeflash/code_utils/code_extractor.py @@ -195,8 +195,19 @@ def visit_SimpleStatementLine(self, node: cst.SimpleStatementLine) -> None: self.last_import_line = self.current_line -class ConditionalImportCollector(cst.CSTVisitor): - """Collect imports inside top-level conditionals (e.g., if TYPE_CHECKING, try/except).""" +class DottedImportCollector(cst.CSTVisitor): + """Collects all top-level imports from a Python module in normalized dotted format, including top-level conditional imports like `if TYPE_CHECKING:`. + + Examples + -------- + import os ==> "os" + import dbt.adapters.factory ==> "dbt.adapters.factory" + from pathlib import Path ==> "pathlib.Path" + from recce.adapter.base import BaseAdapter ==> "recce.adapter.base.BaseAdapter" + from typing import Any, List, Optional ==> "typing.Any", "typing.List", "typing.Optional" + from recce.util.lineage import ( build_column_key, filter_dependency_maps) ==> "recce.util.lineage.build_column_key", "recce.util.lineage.filter_dependency_maps" + + """ def __init__(self) -> None: self.imports: set[str] = set() @@ -217,7 +228,10 @@ def _collect_imports_from_block(self, block: cst.IndentedBlock) -> None: for alias in child.names: module = self.get_full_dotted_name(alias.name) asname = alias.asname.name.value if alias.asname else alias.name.value - self.imports.add(module if module == asname else f"{module}.{asname}") + if isinstance(asname, cst.Attribute): + self.imports.add(module) + else: + self.imports.add(module if module == asname else f"{module}.{asname}") elif isinstance(child, cst.ImportFrom): if child.module is None: @@ -231,6 +245,7 @@ def _collect_imports_from_block(self, block: cst.IndentedBlock) -> None: def visit_Module(self, node: cst.Module) -> None: self.depth = 0 + self._collect_imports_from_block(node) def visit_FunctionDef(self, node: cst.FunctionDef) -> None: self.depth += 1 @@ -388,19 +403,18 @@ def add_needed_imports_from_module( logger.error(f"Error parsing source module code: {e}") return dst_module_code - cond_import_collector = ConditionalImportCollector() + dotted_import_collector = DottedImportCollector() try: parsed_dst_module = cst.parse_module(dst_module_code) - parsed_dst_module.visit(cond_import_collector) + parsed_dst_module.visit(dotted_import_collector) except cst.ParserSyntaxError as e: logger.exception(f"Syntax error in destination module code: {e}") return dst_module_code # Return the original code if there's a syntax error try: for mod in gatherer.module_imports: - if mod in cond_import_collector.imports: - continue - AddImportsVisitor.add_needed_import(dst_context, mod) + if mod not in dotted_import_collector.imports: + AddImportsVisitor.add_needed_import(dst_context, mod) RemoveImportsVisitor.remove_unused_import(dst_context, mod) for mod, obj_seq in gatherer.object_mapping.items(): for obj in obj_seq: @@ -408,25 +422,25 @@ def add_needed_imports_from_module( f"{mod}.{obj}" in helper_functions_fqn or dst_context.full_module_name == mod # avoid circular deps ): continue # Skip adding imports for helper functions already in the context - if f"{mod}.{obj}" in cond_import_collector.imports: - continue - AddImportsVisitor.add_needed_import(dst_context, mod, obj) + if f"{mod}.{obj}" not in dotted_import_collector.imports: + AddImportsVisitor.add_needed_import(dst_context, mod, obj) RemoveImportsVisitor.remove_unused_import(dst_context, mod, obj) except Exception as e: logger.exception(f"Error adding imports to destination module code: {e}") return dst_module_code + for mod, asname in gatherer.module_aliases.items(): - if f"{mod}.{asname}" in cond_import_collector.imports: - continue - AddImportsVisitor.add_needed_import(dst_context, mod, asname=asname) + if f"{mod}.{asname}" not in dotted_import_collector.imports: + AddImportsVisitor.add_needed_import(dst_context, mod, asname=asname) RemoveImportsVisitor.remove_unused_import(dst_context, mod, asname=asname) + for mod, alias_pairs in gatherer.alias_mapping.items(): for alias_pair in alias_pairs: if f"{mod}.{alias_pair[0]}" in helper_functions_fqn: continue - if f"{mod}.{alias_pair[1]}" in cond_import_collector.imports: - continue - AddImportsVisitor.add_needed_import(dst_context, mod, alias_pair[0], asname=alias_pair[1]) + + if f"{mod}.{alias_pair[1]}" not in dotted_import_collector.imports: + AddImportsVisitor.add_needed_import(dst_context, mod, alias_pair[0], asname=alias_pair[1]) RemoveImportsVisitor.remove_unused_import(dst_context, mod, alias_pair[0], asname=alias_pair[1]) try: diff --git a/tests/test_add_needed_imports_from_module.py b/tests/test_add_needed_imports_from_module.py index f8e52e630..4f04948a5 100644 --- a/tests/test_add_needed_imports_from_module.py +++ b/tests/test_add_needed_imports_from_module.py @@ -1,6 +1,7 @@ from pathlib import Path -from codeflash.code_utils.code_extractor import add_needed_imports_from_module +from codeflash.code_utils.code_extractor import add_needed_imports_from_module, find_preexisting_objects +from codeflash.code_utils.code_replacer import replace_functions_and_add_imports def test_add_needed_imports_from_module0() -> None: @@ -121,3 +122,230 @@ def belongs_to_function(name: Name, function_name: str) -> bool: project_root = Path("/home/roger/repos/codeflash") new_module = add_needed_imports_from_module(src_module, dst_module, src_path, dst_path, project_root) assert new_module == expected + +def test_duplicated_imports() -> None: + optim_code = '''from dataclasses import dataclass +from recce.adapter.base import BaseAdapter +from typing import Dict, List, Optional + +@dataclass +class DbtAdapter(BaseAdapter): + + def build_parent_map(self, nodes: Dict, base: Optional[bool] = False) -> Dict[str, List[str]]: + manifest = self.curr_manifest if base is False else self.base_manifest + + try: + parent_map_source = manifest.parent_map + except AttributeError: + parent_map_source = manifest.to_dict()["parent_map"] + + node_ids = set(nodes) + parent_map = {} + for k, parents in parent_map_source.items(): + if k not in node_ids: + continue + parent_map[k] = [parent for parent in parents if parent in node_ids] + + return parent_map +''' + + original_code = '''import json +import logging +import os +import uuid +from contextlib import contextmanager +from copy import deepcopy +from dataclasses import dataclass, fields +from errno import ENOENT +from functools import lru_cache +from pathlib import Path +from typing import ( + Any, + Callable, + Dict, + Iterator, + List, + Literal, + Optional, + Set, + Tuple, + Type, + Union, +) + +from recce.event import log_performance +from recce.exceptions import RecceException +from recce.util.cll import CLLPerformanceTracking, cll +from recce.util.lineage import ( + build_column_key, + filter_dependency_maps, + find_downstream, + find_upstream, +) +from recce.util.perf_tracking import LineagePerfTracker + +from ...tasks.profile import ProfileTask +from ...util.breaking import BreakingPerformanceTracking, parse_change_category + +try: + import agate + import dbt.adapters.factory + from dbt.contracts.state import PreviousState +except ImportError as e: + print("Error: dbt module not found. Please install it by running:") + print("pip install dbt-core dbt-") + raise e +from watchdog.events import FileSystemEventHandler +from watchdog.observers import Observer + +from recce.adapter.base import BaseAdapter +from recce.state import ArtifactsRoot + +from ...models import RunType +from ...models.types import ( + CllColumn, + CllData, + CllNode, + LineageDiff, + NodeChange, + NodeDiff, +) +from ...tasks import ( + HistogramDiffTask, + ProfileDiffTask, + QueryBaseTask, + QueryDiffTask, + QueryTask, + RowCountDiffTask, + RowCountTask, + Task, + TopKDiffTask, + ValueDiffDetailTask, + ValueDiffTask, +) +from .dbt_version import DbtVersion + +@dataclass +class DbtAdapter(BaseAdapter): + + def build_parent_map(self, nodes: Dict, base: Optional[bool] = False) -> Dict[str, List[str]]: + manifest = self.curr_manifest if base is False else self.base_manifest + manifest_dict = manifest.to_dict() + + node_ids = nodes.keys() + parent_map = {} + for k, parents in manifest_dict["parent_map"].items(): + if k not in node_ids: + continue + parent_map[k] = [parent for parent in parents if parent in node_ids] + + return parent_map +''' + expected = '''import json +import logging +import os +import uuid +from contextlib import contextmanager +from copy import deepcopy +from dataclasses import dataclass, fields +from errno import ENOENT +from functools import lru_cache +from pathlib import Path +from typing import ( + Any, + Callable, + Dict, + Iterator, + List, + Literal, + Optional, + Set, + Tuple, + Type, + Union, +) + +from recce.event import log_performance +from recce.exceptions import RecceException +from recce.util.cll import CLLPerformanceTracking, cll +from recce.util.lineage import ( + build_column_key, + filter_dependency_maps, + find_downstream, + find_upstream, +) +from recce.util.perf_tracking import LineagePerfTracker + +from ...tasks.profile import ProfileTask +from ...util.breaking import BreakingPerformanceTracking, parse_change_category + +try: + import agate + import dbt.adapters.factory + from dbt.contracts.state import PreviousState +except ImportError as e: + print("Error: dbt module not found. Please install it by running:") + print("pip install dbt-core dbt-") + raise e +from watchdog.events import FileSystemEventHandler +from watchdog.observers import Observer + +from recce.adapter.base import BaseAdapter +from recce.state import ArtifactsRoot + +from ...models import RunType +from ...models.types import ( + CllColumn, + CllData, + CllNode, + LineageDiff, + NodeChange, + NodeDiff, +) +from ...tasks import ( + HistogramDiffTask, + ProfileDiffTask, + QueryBaseTask, + QueryDiffTask, + QueryTask, + RowCountDiffTask, + RowCountTask, + Task, + TopKDiffTask, + ValueDiffDetailTask, + ValueDiffTask, +) +from .dbt_version import DbtVersion + +@dataclass +class DbtAdapter(BaseAdapter): + + def build_parent_map(self, nodes: Dict, base: Optional[bool] = False) -> Dict[str, List[str]]: + manifest = self.curr_manifest if base is False else self.base_manifest + + try: + parent_map_source = manifest.parent_map + except AttributeError: + parent_map_source = manifest.to_dict()["parent_map"] + + node_ids = set(nodes) + parent_map = {} + for k, parents in parent_map_source.items(): + if k not in node_ids: + continue + parent_map[k] = [parent for parent in parents if parent in node_ids] + + return parent_map +''' + + function_name: str = "DbtAdapter.build_parent_map" + preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = find_preexisting_objects(original_code) + new_code: str = replace_functions_and_add_imports( + source_code=original_code, + function_names=[function_name], + optimized_code=optim_code, + module_abspath=Path(__file__).resolve(), + preexisting_objects=preexisting_objects, + project_root_path=Path(__file__).resolve().parent.resolve(), + ) + assert new_code == expected