Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 31 additions & 17 deletions codeflash/code_utils/code_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,8 +195,19 @@ def visit_SimpleStatementLine(self, node: cst.SimpleStatementLine) -> None:
self.last_import_line = self.current_line


class ConditionalImportCollector(cst.CSTVisitor):
"""Collect imports inside top-level conditionals (e.g., if TYPE_CHECKING, try/except)."""
class DottedImportCollector(cst.CSTVisitor):
"""Collects all top-level imports from a Python module in normalized dotted format, including top-level conditional imports like `if TYPE_CHECKING:`.

Examples
--------
import os ==> "os"
import dbt.adapters.factory ==> "dbt.adapters.factory"
from pathlib import Path ==> "pathlib.Path"
from recce.adapter.base import BaseAdapter ==> "recce.adapter.base.BaseAdapter"
from typing import Any, List, Optional ==> "typing.Any", "typing.List", "typing.Optional"
from recce.util.lineage import ( build_column_key, filter_dependency_maps) ==> "recce.util.lineage.build_column_key", "recce.util.lineage.filter_dependency_maps"

"""

def __init__(self) -> None:
self.imports: set[str] = set()
Expand All @@ -217,7 +228,10 @@ def _collect_imports_from_block(self, block: cst.IndentedBlock) -> None:
for alias in child.names:
module = self.get_full_dotted_name(alias.name)
asname = alias.asname.name.value if alias.asname else alias.name.value
self.imports.add(module if module == asname else f"{module}.{asname}")
if isinstance(asname, cst.Attribute):
self.imports.add(module)
else:
self.imports.add(module if module == asname else f"{module}.{asname}")

elif isinstance(child, cst.ImportFrom):
if child.module is None:
Expand All @@ -231,6 +245,7 @@ def _collect_imports_from_block(self, block: cst.IndentedBlock) -> None:

def visit_Module(self, node: cst.Module) -> None:
self.depth = 0
self._collect_imports_from_block(node)

def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
self.depth += 1
Expand Down Expand Up @@ -388,45 +403,44 @@ def add_needed_imports_from_module(
logger.error(f"Error parsing source module code: {e}")
return dst_module_code

cond_import_collector = ConditionalImportCollector()
dotted_import_collector = DottedImportCollector()
try:
parsed_dst_module = cst.parse_module(dst_module_code)
parsed_dst_module.visit(cond_import_collector)
parsed_dst_module.visit(dotted_import_collector)
except cst.ParserSyntaxError as e:
logger.exception(f"Syntax error in destination module code: {e}")
return dst_module_code # Return the original code if there's a syntax error

try:
for mod in gatherer.module_imports:
if mod in cond_import_collector.imports:
continue
AddImportsVisitor.add_needed_import(dst_context, mod)
if mod not in dotted_import_collector.imports:
AddImportsVisitor.add_needed_import(dst_context, mod)
RemoveImportsVisitor.remove_unused_import(dst_context, mod)
for mod, obj_seq in gatherer.object_mapping.items():
for obj in obj_seq:
if (
f"{mod}.{obj}" in helper_functions_fqn or dst_context.full_module_name == mod # avoid circular deps
):
continue # Skip adding imports for helper functions already in the context
if f"{mod}.{obj}" in cond_import_collector.imports:
continue
AddImportsVisitor.add_needed_import(dst_context, mod, obj)
if f"{mod}.{obj}" not in dotted_import_collector.imports:
AddImportsVisitor.add_needed_import(dst_context, mod, obj)
RemoveImportsVisitor.remove_unused_import(dst_context, mod, obj)
except Exception as e:
logger.exception(f"Error adding imports to destination module code: {e}")
return dst_module_code

for mod, asname in gatherer.module_aliases.items():
if f"{mod}.{asname}" in cond_import_collector.imports:
continue
AddImportsVisitor.add_needed_import(dst_context, mod, asname=asname)
if f"{mod}.{asname}" not in dotted_import_collector.imports:
AddImportsVisitor.add_needed_import(dst_context, mod, asname=asname)
RemoveImportsVisitor.remove_unused_import(dst_context, mod, asname=asname)

for mod, alias_pairs in gatherer.alias_mapping.items():
for alias_pair in alias_pairs:
if f"{mod}.{alias_pair[0]}" in helper_functions_fqn:
continue
if f"{mod}.{alias_pair[1]}" in cond_import_collector.imports:
continue
AddImportsVisitor.add_needed_import(dst_context, mod, alias_pair[0], asname=alias_pair[1])

if f"{mod}.{alias_pair[1]}" not in dotted_import_collector.imports:
AddImportsVisitor.add_needed_import(dst_context, mod, alias_pair[0], asname=alias_pair[1])
RemoveImportsVisitor.remove_unused_import(dst_context, mod, alias_pair[0], asname=alias_pair[1])

try:
Expand Down
230 changes: 229 additions & 1 deletion tests/test_add_needed_imports_from_module.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from pathlib import Path

from codeflash.code_utils.code_extractor import add_needed_imports_from_module
from codeflash.code_utils.code_extractor import add_needed_imports_from_module, find_preexisting_objects
from codeflash.code_utils.code_replacer import replace_functions_and_add_imports


def test_add_needed_imports_from_module0() -> None:
Expand Down Expand Up @@ -121,3 +122,230 @@ def belongs_to_function(name: Name, function_name: str) -> bool:
project_root = Path("/home/roger/repos/codeflash")
new_module = add_needed_imports_from_module(src_module, dst_module, src_path, dst_path, project_root)
assert new_module == expected

def test_duplicated_imports() -> None:
optim_code = '''from dataclasses import dataclass
from recce.adapter.base import BaseAdapter
from typing import Dict, List, Optional

@dataclass
class DbtAdapter(BaseAdapter):

def build_parent_map(self, nodes: Dict, base: Optional[bool] = False) -> Dict[str, List[str]]:
manifest = self.curr_manifest if base is False else self.base_manifest

try:
parent_map_source = manifest.parent_map
except AttributeError:
parent_map_source = manifest.to_dict()["parent_map"]

node_ids = set(nodes)
parent_map = {}
for k, parents in parent_map_source.items():
if k not in node_ids:
continue
parent_map[k] = [parent for parent in parents if parent in node_ids]

return parent_map
'''

original_code = '''import json
import logging
import os
import uuid
from contextlib import contextmanager
from copy import deepcopy
from dataclasses import dataclass, fields
from errno import ENOENT
from functools import lru_cache
from pathlib import Path
from typing import (
Any,
Callable,
Dict,
Iterator,
List,
Literal,
Optional,
Set,
Tuple,
Type,
Union,
)

from recce.event import log_performance
from recce.exceptions import RecceException
from recce.util.cll import CLLPerformanceTracking, cll
from recce.util.lineage import (
build_column_key,
filter_dependency_maps,
find_downstream,
find_upstream,
)
from recce.util.perf_tracking import LineagePerfTracker

from ...tasks.profile import ProfileTask
from ...util.breaking import BreakingPerformanceTracking, parse_change_category

try:
import agate
import dbt.adapters.factory
from dbt.contracts.state import PreviousState
except ImportError as e:
print("Error: dbt module not found. Please install it by running:")
print("pip install dbt-core dbt-<adapter>")
raise e
from watchdog.events import FileSystemEventHandler
from watchdog.observers import Observer

from recce.adapter.base import BaseAdapter
from recce.state import ArtifactsRoot

from ...models import RunType
from ...models.types import (
CllColumn,
CllData,
CllNode,
LineageDiff,
NodeChange,
NodeDiff,
)
from ...tasks import (
HistogramDiffTask,
ProfileDiffTask,
QueryBaseTask,
QueryDiffTask,
QueryTask,
RowCountDiffTask,
RowCountTask,
Task,
TopKDiffTask,
ValueDiffDetailTask,
ValueDiffTask,
)
from .dbt_version import DbtVersion

@dataclass
class DbtAdapter(BaseAdapter):

def build_parent_map(self, nodes: Dict, base: Optional[bool] = False) -> Dict[str, List[str]]:
manifest = self.curr_manifest if base is False else self.base_manifest
manifest_dict = manifest.to_dict()

node_ids = nodes.keys()
parent_map = {}
for k, parents in manifest_dict["parent_map"].items():
if k not in node_ids:
continue
parent_map[k] = [parent for parent in parents if parent in node_ids]

return parent_map
'''
expected = '''import json
import logging
import os
import uuid
from contextlib import contextmanager
from copy import deepcopy
from dataclasses import dataclass, fields
from errno import ENOENT
from functools import lru_cache
from pathlib import Path
from typing import (
Any,
Callable,
Dict,
Iterator,
List,
Literal,
Optional,
Set,
Tuple,
Type,
Union,
)

from recce.event import log_performance
from recce.exceptions import RecceException
from recce.util.cll import CLLPerformanceTracking, cll
from recce.util.lineage import (
build_column_key,
filter_dependency_maps,
find_downstream,
find_upstream,
)
from recce.util.perf_tracking import LineagePerfTracker

from ...tasks.profile import ProfileTask
from ...util.breaking import BreakingPerformanceTracking, parse_change_category

try:
import agate
import dbt.adapters.factory
from dbt.contracts.state import PreviousState
except ImportError as e:
print("Error: dbt module not found. Please install it by running:")
print("pip install dbt-core dbt-<adapter>")
raise e
from watchdog.events import FileSystemEventHandler
from watchdog.observers import Observer

from recce.adapter.base import BaseAdapter
from recce.state import ArtifactsRoot

from ...models import RunType
from ...models.types import (
CllColumn,
CllData,
CllNode,
LineageDiff,
NodeChange,
NodeDiff,
)
from ...tasks import (
HistogramDiffTask,
ProfileDiffTask,
QueryBaseTask,
QueryDiffTask,
QueryTask,
RowCountDiffTask,
RowCountTask,
Task,
TopKDiffTask,
ValueDiffDetailTask,
ValueDiffTask,
)
from .dbt_version import DbtVersion

@dataclass
class DbtAdapter(BaseAdapter):

def build_parent_map(self, nodes: Dict, base: Optional[bool] = False) -> Dict[str, List[str]]:
manifest = self.curr_manifest if base is False else self.base_manifest

try:
parent_map_source = manifest.parent_map
except AttributeError:
parent_map_source = manifest.to_dict()["parent_map"]

node_ids = set(nodes)
parent_map = {}
for k, parents in parent_map_source.items():
if k not in node_ids:
continue
parent_map[k] = [parent for parent in parents if parent in node_ids]

return parent_map
'''

function_name: str = "DbtAdapter.build_parent_map"
preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = find_preexisting_objects(original_code)
new_code: str = replace_functions_and_add_imports(
source_code=original_code,
function_names=[function_name],
optimized_code=optim_code,
module_abspath=Path(__file__).resolve(),
preexisting_objects=preexisting_objects,
project_root_path=Path(__file__).resolve().parent.resolve(),
)
assert new_code == expected
Loading