Skip to content

Commit 2f62cfe

Browse files
tomcodgentkfoss
andauthored
[CG-10936] fix: pypath resolution not resolving for __init__ properly (#717)
# Motivation <!-- Why is this change necessary? --> # Content <!-- Please include a summary of the change --> # Testing <!-- How was the change tested? --> # Please check the following before marking your PR as ready for review - [ ] I have added tests for my changes - [ ] I have updated the documentation or added new documentation as needed --------- Co-authored-by: tomcodgen <[email protected]>
1 parent 276404e commit 2f62cfe

File tree

2 files changed

+71
-1
lines changed

2 files changed

+71
-1
lines changed

src/codegen/sdk/python/import_resolution.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,8 +146,24 @@ def resolve_import(self, base_path: str | None = None, *, add_module_name: str |
146146
else:
147147
return ImportResolution(from_file=file, symbol=symbol)
148148

149-
# =====[ Check if `module/__init__.py` file exists in the graph ]=====
149+
# =====[ Check if `module/__init__.py` file exists in the graph with custom resolve path or sys.path enabled ]=====
150150
filepath = filepath.replace(".py", "/__init__.py")
151+
if len(self.ctx.config.import_resolution_paths) > 0 or self.ctx.config.py_resolve_syspath:
152+
# Handle resolve overrides first if both is set
153+
resolve_paths: list[str] = self.ctx.config.import_resolution_paths + (sys.path if self.ctx.config.py_resolve_syspath else [])
154+
if from_file := self._file_by_custom_resolve_paths(resolve_paths, filepath):
155+
symbol = from_file.get_node_by_name(symbol_name)
156+
if symbol is None:
157+
if from_file.get_node_from_wildcard_chain(symbol_name):
158+
return ImportResolution(from_file=from_file, symbol=None, imports_file=True)
159+
else:
160+
# This is most likely a broken import
161+
return ImportResolution(from_file=from_file, symbol=None)
162+
163+
else:
164+
return ImportResolution(from_file=from_file, symbol=symbol)
165+
166+
# =====[ Check if `module/__init__.py` file exists in the graph ]=====
151167
if from_file := self.ctx.get_file(filepath):
152168
symbol = from_file.get_node_by_name(symbol_name)
153169
if symbol is None:

tests/unit/codegen/sdk/python/import_resolution/test_import_resolution.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from typing import TYPE_CHECKING
22

3+
from codegen.sdk.codebase.config import TestFlags
34
from codegen.sdk.codebase.factory.get_session import get_codebase_session
45

56
if TYPE_CHECKING:
@@ -854,3 +855,56 @@ def test_import_resolution_init_wildcard_chainging_deep(tmpdir: str) -> None:
854855

855856
assert len(symb.usages) == 2
856857
assert symb.symbol_usages == [test1, imp]
858+
859+
860+
def test_import_resolution_paths_init(tmpdir: str) -> None:
861+
cfg = TestFlags.model_copy()
862+
cfg.debug = False ##Disable to ignore binary expression edge duplicate
863+
cfg.import_resolution_paths = ["package"]
864+
865+
# language=python
866+
content1 = """
867+
COMMON_AVAILABLE_STREAMS = [
868+
1,
869+
2,
870+
3,
871+
4,
872+
5,
873+
6,
874+
]
875+
876+
PLAN_MODEL_AVAILABLE_STREAMS = COMMON_AVAILABLE_STREAMS + [
877+
4,
878+
2
879+
]
880+
"""
881+
content2 = """
882+
from main.dir.dir2 import PLAN_MODEL_AVAILABLE_STREAMS
883+
def do_smth():
884+
foo=PLAN_MODEL_AVAILABLE_STREAMS
885+
print(foo)
886+
887+
do_smth()
888+
"""
889+
with get_codebase_session(
890+
tmpdir=tmpdir,
891+
config=cfg,
892+
files={
893+
"package/main/dir/dir2/file1.py": "bar=2",
894+
"package/main/dir/dir2/__init__.py": content1,
895+
"package/main/dir/__init__.py": content2,
896+
"package/main/__init__.py": "",
897+
"start.py": """from main.dir.dir2.file1 import bar
898+
print(bar)
899+
""",
900+
},
901+
) as codebase:
902+
file1: SourceFile = codebase.get_file("package/main/dir/dir2/__init__.py")
903+
p_m = file1.get_symbol("PLAN_MODEL_AVAILABLE_STREAMS")
904+
file2: SourceFile = codebase.get_file("package/main/dir/__init__.py")
905+
dosmth = file2.get_symbol("do_smth")
906+
import_pm = file2.get_import("PLAN_MODEL_AVAILABLE_STREAMS")
907+
908+
assert len(p_m.usages) == 3
909+
assert p_m.symbol_usages == [dosmth, import_pm]
910+
assert len(file1.symbols) != 1

0 commit comments

Comments
 (0)