diff --git a/src/codegen/sdk/core/external_module.py b/src/codegen/sdk/core/external_module.py index efa41d751..9a61b297f 100644 --- a/src/codegen/sdk/core/external_module.py +++ b/src/codegen/sdk/core/external_module.py @@ -35,16 +35,14 @@ class ExternalModule( """ node_type: Literal[NodeType.EXTERNAL] = NodeType.EXTERNAL - _import: Import | None = None - def __init__(self, ts_node: TSNode, file_node_id: NodeId, G: CodebaseGraph, import_name: Name, import_node: Import | None = None) -> None: + def __init__(self, ts_node: TSNode, file_node_id: NodeId, G: CodebaseGraph, import_name: Name) -> None: self.node_id = G.add_node(self) super().__init__(ts_node, file_node_id, G, None) self._name_node = import_name self.return_type = StubPlaceholder(parent=self) assert self._idx_key not in self.G._ext_module_idx self.G._ext_module_idx[self._idx_key] = self.node_id - self._import = import_node @property def _idx_key(self) -> str: @@ -70,7 +68,7 @@ def from_import(cls, imp: Import) -> ExternalModule: Returns: ExternalModule: A new ExternalModule instance representing the external module. """ - return cls(imp.ts_node, imp.file_node_id, imp.G, imp._unique_node, imp) + return cls(imp.ts_node, imp.file_node_id, imp.G, imp._unique_node) @property @reader @@ -138,7 +136,7 @@ def viz(self) -> VizNode: @noapidoc @reader def resolve_attribute(self, name: str) -> ExternalModule | None: - return self._import.resolve_attribute(name) or self + return self @noapidoc @commiter diff --git a/src/codegen/sdk/core/import_resolution.py b/src/codegen/sdk/core/import_resolution.py index 9003b262e..1c35dcf2c 100644 --- a/src/codegen/sdk/core/import_resolution.py +++ b/src/codegen/sdk/core/import_resolution.py @@ -11,7 +11,6 @@ from codegen.sdk.core.expressions.name import Name from codegen.sdk.core.external_module import ExternalModule from codegen.sdk.core.interfaces.chainable import Chainable -from codegen.sdk.core.interfaces.has_attribute import HasAttribute from codegen.sdk.core.interfaces.usable import Usable from codegen.sdk.core.statements.import_statement import ImportStatement from codegen.sdk.enums import EdgeType, ImportType, NodeType @@ -58,7 +57,7 @@ class ImportResolution(Generic[TSourceFile]): @apidoc -class Import(Usable[ImportStatement], Chainable, Generic[TSourceFile], HasAttribute[TSourceFile]): +class Import(Usable[ImportStatement], Chainable, Generic[TSourceFile]): """Represents a single symbol being imported. For example, this is one `Import` in Python (and similar applies to Typescript, etc.): @@ -116,7 +115,7 @@ def __rich_repr__(self) -> rich.repr.Result: @noapidoc @abstractmethod - def resolve_import(self, base_path: str | None = None, *, add_module_name: str | None = None) -> ImportResolution[TSourceFile] | None: + def resolve_import(self, base_path: str | None = None) -> ImportResolution[TSourceFile] | None: """Resolves the import to a symbol defined outside the file. Returns an ImportResolution object. @@ -663,17 +662,6 @@ def remove_if_unused(self) -> None: ): self.remove() - @noapidoc - @reader - def resolve_attribute(self, attribute: str) -> TSourceFile | None: - # Handles implicit namespace imports in python - if not isinstance(self._imported_symbol(), ExternalModule): - return None - resolved = self.resolve_import(add_module_name=attribute) - if resolved: - return resolved.symbol or resolved.from_file - return None - TImport = TypeVar("TImport", bound="Import") diff --git a/src/codegen/sdk/python/file.py b/src/codegen/sdk/python/file.py index a89bfe65d..36595e165 100644 --- a/src/codegen/sdk/python/file.py +++ b/src/codegen/sdk/python/file.py @@ -6,7 +6,7 @@ from codegen.sdk.core.file import SourceFile from codegen.sdk.core.interface import Interface from codegen.sdk.enums import ImportType, ProgrammingLanguage -from codegen.sdk.extensions.utils import cached_property, iter_all_descendants +from codegen.sdk.extensions.utils import iter_all_descendants from codegen.sdk.python import PyAssignment from codegen.sdk.python.class_definition import PyClass from codegen.sdk.python.detached_symbols.code_block import PyCodeBlock @@ -20,7 +20,6 @@ if TYPE_CHECKING: from codegen.sdk.codebase.codebase_graph import CodebaseGraph - from codegen.sdk.core.import_resolution import WildcardImport from codegen.sdk.python.symbol import PySymbol @@ -174,20 +173,3 @@ def add_import_from_import_string(self, import_string: str) -> None: def remove_unused_exports(self) -> None: """Removes unused exports from the file. NO-OP for python""" pass - - @cached_property - @noapidoc - @reader(cache=True) - def valid_import_names(self) -> dict[str, PySymbol | PyImport | WildcardImport[PyImport]]: - """Returns a dict mapping name => Symbol (or import) in this file that can be imported from - another file. - """ - if self.name == "__init__": - ret = {} - if self.directory: - for file in self.directory: - if file.name == "__init__": - continue - ret[file.name] = file - return ret - return super().valid_import_names diff --git a/src/codegen/sdk/python/import_resolution.py b/src/codegen/sdk/python/import_resolution.py index 64e3b5f1c..070d815d6 100644 --- a/src/codegen/sdk/python/import_resolution.py +++ b/src/codegen/sdk/python/import_resolution.py @@ -82,13 +82,10 @@ def imported_exports(self) -> list[Exportable]: @noapidoc @reader - def resolve_import(self, base_path: str | None = None, *, add_module_name: str | None = None) -> ImportResolution[PyFile] | None: + def resolve_import(self, base_path: str | None = None) -> ImportResolution[PyFile] | None: base_path = base_path or self.G.projects[0].base_path or "" module_source = self.module.source if self.module else "" - symbol_name = self.symbol_name.source if self.symbol_name else "" - if add_module_name: - module_source += f".{symbol_name}" - symbol_name = add_module_name + # If import is relative, convert to absolute path if module_source.startswith("."): module_source = self._relative_to_absolute_import(module_source) @@ -102,7 +99,7 @@ def resolve_import(self, base_path: str | None = None, *, add_module_name: str | # `from a.b.c import foo` filepath = os.path.join( base_path, - module_source.replace(".", "/") + "/" + symbol_name + ".py", + module_source.replace(".", "/") + "/" + self.symbol_name.source + ".py", ) if file := self.G.get_file(filepath): return ImportResolution(from_file=file, symbol=None, imports_file=True) @@ -117,22 +114,22 @@ def resolve_import(self, base_path: str | None = None, *, add_module_name: str | filepath = module_source.replace(".", "/") + ".py" filepath = os.path.join(base_path, filepath) if file := self.G.get_file(filepath): - symbol = file.get_node_by_name(symbol_name) + symbol = file.get_node_by_name(self.symbol_name.source) return ImportResolution(from_file=file, symbol=symbol) # =====[ Check if `module/__init__.py` file exists in the graph ]===== filepath = filepath.replace(".py", "/__init__.py") if from_file := self.G.get_file(filepath): - symbol = from_file.get_node_by_name(symbol_name) + symbol = from_file.get_node_by_name(self.symbol_name.source) return ImportResolution(from_file=from_file, symbol=symbol) # =====[ Case: Can't resolve the import ]===== if base_path == "": # Try to resolve with "src" as the base path - return self.resolve_import(base_path="src", add_module_name=add_module_name) + return self.resolve_import(base_path="src") if base_path == "src": # Try "test" next - return self.resolve_import(base_path="test", add_module_name=add_module_name) + return self.resolve_import(base_path="test") # if not G_override: # for resolver in G.import_resolvers: diff --git a/src/codegen/sdk/typescript/import_resolution.py b/src/codegen/sdk/typescript/import_resolution.py index 9084b3e09..810843483 100644 --- a/src/codegen/sdk/typescript/import_resolution.py +++ b/src/codegen/sdk/typescript/import_resolution.py @@ -197,7 +197,7 @@ def resolved_symbol(self) -> Symbol | ExternalModule | TSFile | None: return resolved_symbol @reader - def resolve_import(self, base_path: str | None = None, *, add_module_name: str | None = None) -> ImportResolution[TSFile] | None: + def resolve_import(self, base_path: str | None = None) -> ImportResolution[TSFile] | None: """Resolves an import statement to its target file and symbol. This method is used by GraphBuilder to resolve import statements to their target files and symbols. It handles both relative and absolute imports, diff --git a/tests/unit/codegen/sdk/python/import_resolution/test_import_resolution.py b/tests/unit/codegen/sdk/python/import_resolution/test_import_resolution.py index 07587393a..0da5ae316 100644 --- a/tests/unit/codegen/sdk/python/import_resolution/test_import_resolution.py +++ b/tests/unit/codegen/sdk/python/import_resolution/test_import_resolution.py @@ -249,64 +249,3 @@ def c_sym(): assert "c_sym" in b_file.valid_symbol_names assert "a_sym" in c_file.valid_symbol_names assert "b_sym" in c_file.valid_symbol_names.keys() - - -def test_import_resolution_nested_module(tmpdir: str) -> None: - """Tests import resolution works with nested module imports""" - # language=python - with get_codebase_session( - tmpdir, - files={ - "a/b/c.py": """ -def d(): - pass -""", - "consumer.py": """ -from a import b - -b.c.d() -""", - }, - ) as codebase: - consumer_file: SourceFile = codebase.get_file("consumer.py") - c_file: SourceFile = codebase.get_file("a/b/c.py") - - # Verify import resolution - assert len(consumer_file.imports) == 1 - - # Verify function call resolution - d_func = c_file.get_function("d") - call_sites = d_func.call_sites - assert len(call_sites) == 1 - assert call_sites[0].file == consumer_file - - -def test_import_resolution_nested_module_init(tmpdir: str) -> None: - """Tests import resolution works with nested module imports""" - # language=python - with get_codebase_session( - tmpdir, - files={ - "a/b/c.py": """ -def d(): - pass -""", - "a/b/__init__.py": """""", - "consumer.py": """ -from a import b - -b.c.d() -""", - }, - ) as codebase: - consumer_file: SourceFile = codebase.get_file("consumer.py") - c_file: SourceFile = codebase.get_file("a/b/c.py") - - # Verify import resolution - assert len(consumer_file.imports) == 1 - - # Verify function call resolution - d_func = c_file.get_function("d") - call_sites = d_func.call_sites - assert len(call_sites) == 1 - assert call_sites[0].file == consumer_file