diff --git a/src/codegen/sdk/python/import_resolution.py b/src/codegen/sdk/python/import_resolution.py index 2686db76a..16f7f876b 100644 --- a/src/codegen/sdk/python/import_resolution.py +++ b/src/codegen/sdk/python/import_resolution.py @@ -146,8 +146,24 @@ def resolve_import(self, base_path: str | None = None, *, add_module_name: str | else: return ImportResolution(from_file=file, symbol=symbol) - # =====[ Check if `module/__init__.py` file exists in the graph ]===== + # =====[ Check if `module/__init__.py` file exists in the graph with custom resolve path or sys.path enabled ]===== filepath = filepath.replace(".py", "/__init__.py") + if len(self.ctx.config.import_resolution_paths) > 0 or self.ctx.config.py_resolve_syspath: + # Handle resolve overrides first if both is set + resolve_paths: list[str] = self.ctx.config.import_resolution_paths + (sys.path if self.ctx.config.py_resolve_syspath else []) + if from_file := self._file_by_custom_resolve_paths(resolve_paths, filepath): + symbol = from_file.get_node_by_name(symbol_name) + if symbol is None: + if from_file.get_node_from_wildcard_chain(symbol_name): + return ImportResolution(from_file=from_file, symbol=None, imports_file=True) + else: + # This is most likely a broken import + return ImportResolution(from_file=from_file, symbol=None) + + else: + return ImportResolution(from_file=from_file, symbol=symbol) + + # =====[ Check if `module/__init__.py` file exists in the graph ]===== if from_file := self.ctx.get_file(filepath): symbol = from_file.get_node_by_name(symbol_name) if symbol is None: diff --git a/tests/unit/codegen/sdk/python/import_resolution/test_import_resolution.py b/tests/unit/codegen/sdk/python/import_resolution/test_import_resolution.py index e10df046b..ebbe5d724 100644 --- a/tests/unit/codegen/sdk/python/import_resolution/test_import_resolution.py +++ b/tests/unit/codegen/sdk/python/import_resolution/test_import_resolution.py @@ -1,5 +1,6 @@ from typing import TYPE_CHECKING +from codegen.sdk.codebase.config import TestFlags from codegen.sdk.codebase.factory.get_session import get_codebase_session if TYPE_CHECKING: @@ -854,3 +855,56 @@ def test_import_resolution_init_wildcard_chainging_deep(tmpdir: str) -> None: assert len(symb.usages) == 2 assert symb.symbol_usages == [test1, imp] + + +def test_import_resolution_paths_init(tmpdir: str) -> None: + cfg = TestFlags.model_copy() + cfg.debug = False ##Disable to ignore binary expression edge duplicate + cfg.import_resolution_paths = ["package"] + + # language=python + content1 = """ + COMMON_AVAILABLE_STREAMS = [ + 1, + 2, + 3, + 4, + 5, + 6, + ] + + PLAN_MODEL_AVAILABLE_STREAMS = COMMON_AVAILABLE_STREAMS + [ + 4, + 2 + ] +""" + content2 = """ + from main.dir.dir2 import PLAN_MODEL_AVAILABLE_STREAMS + def do_smth(): + foo=PLAN_MODEL_AVAILABLE_STREAMS + print(foo) + + do_smth() + """ + with get_codebase_session( + tmpdir=tmpdir, + config=cfg, + files={ + "package/main/dir/dir2/file1.py": "bar=2", + "package/main/dir/dir2/__init__.py": content1, + "package/main/dir/__init__.py": content2, + "package/main/__init__.py": "", + "start.py": """from main.dir.dir2.file1 import bar + print(bar) + """, + }, + ) as codebase: + file1: SourceFile = codebase.get_file("package/main/dir/dir2/__init__.py") + p_m = file1.get_symbol("PLAN_MODEL_AVAILABLE_STREAMS") + file2: SourceFile = codebase.get_file("package/main/dir/__init__.py") + dosmth = file2.get_symbol("do_smth") + import_pm = file2.get_import("PLAN_MODEL_AVAILABLE_STREAMS") + + assert len(p_m.usages) == 3 + assert p_m.symbol_usages == [dosmth, import_pm] + assert len(file1.symbols) != 1