Skip to content

Commit f10b462

Browse files
authored
Stricter handling of submodules as attributes (#20179)
Fixes #20174 The idea is quite straightforward: we only allow `foo.bar` without explicit re-export if `foo.bar` was imported in any transitive dependency (and not in some unrelated module). Note: only `import foo.bar` takes effect, effect of using `from foo.bar import ...` is not propagated for two reasons: * It will cost large performance penalty * It is relatively obscure Python feature that may be considered by some as "implementation detail"
1 parent 45aa599 commit f10b462

File tree

4 files changed

+168
-34
lines changed

4 files changed

+168
-34
lines changed

mypy/build.py

Lines changed: 27 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -603,6 +603,7 @@ def __init__(
603603
self.options = options
604604
self.version_id = version_id
605605
self.modules: dict[str, MypyFile] = {}
606+
self.import_map: dict[str, set[str]] = {}
606607
self.missing_modules: set[str] = set()
607608
self.fg_deps_meta: dict[str, FgDepMeta] = {}
608609
# fg_deps holds the dependencies of every module that has been
@@ -623,6 +624,7 @@ def __init__(
623624
self.incomplete_namespaces,
624625
self.errors,
625626
self.plugin,
627+
self.import_map,
626628
)
627629
self.all_types: dict[Expression, Type] = {} # Enabled by export_types
628630
self.indirection_detector = TypeIndirectionVisitor()
@@ -742,6 +744,26 @@ def getmtime(self, path: str) -> int:
742744
else:
743745
return int(self.metastore.getmtime(path))
744746

747+
def correct_rel_imp(self, file: MypyFile, imp: ImportFrom | ImportAll) -> str:
748+
"""Function to correct for relative imports."""
749+
file_id = file.fullname
750+
rel = imp.relative
751+
if rel == 0:
752+
return imp.id
753+
if os.path.basename(file.path).startswith("__init__."):
754+
rel -= 1
755+
if rel != 0:
756+
file_id = ".".join(file_id.split(".")[:-rel])
757+
new_id = file_id + "." + imp.id if imp.id else file_id
758+
759+
if not new_id:
760+
self.errors.set_file(file.path, file.name, self.options)
761+
self.errors.report(
762+
imp.line, 0, "No parent module -- cannot perform relative import", blocker=True
763+
)
764+
765+
return new_id
766+
745767
def all_imported_modules_in_file(self, file: MypyFile) -> list[tuple[int, str, int]]:
746768
"""Find all reachable import statements in a file.
747769
@@ -750,27 +772,6 @@ def all_imported_modules_in_file(self, file: MypyFile) -> list[tuple[int, str, i
750772
751773
Can generate blocking errors on bogus relative imports.
752774
"""
753-
754-
def correct_rel_imp(imp: ImportFrom | ImportAll) -> str:
755-
"""Function to correct for relative imports."""
756-
file_id = file.fullname
757-
rel = imp.relative
758-
if rel == 0:
759-
return imp.id
760-
if os.path.basename(file.path).startswith("__init__."):
761-
rel -= 1
762-
if rel != 0:
763-
file_id = ".".join(file_id.split(".")[:-rel])
764-
new_id = file_id + "." + imp.id if imp.id else file_id
765-
766-
if not new_id:
767-
self.errors.set_file(file.path, file.name, self.options)
768-
self.errors.report(
769-
imp.line, 0, "No parent module -- cannot perform relative import", blocker=True
770-
)
771-
772-
return new_id
773-
774775
res: list[tuple[int, str, int]] = []
775776
for imp in file.imports:
776777
if not imp.is_unreachable:
@@ -785,7 +786,7 @@ def correct_rel_imp(imp: ImportFrom | ImportAll) -> str:
785786
ancestors.append(part)
786787
res.append((ancestor_pri, ".".join(ancestors), imp.line))
787788
elif isinstance(imp, ImportFrom):
788-
cur_id = correct_rel_imp(imp)
789+
cur_id = self.correct_rel_imp(file, imp)
789790
all_are_submodules = True
790791
# Also add any imported names that are submodules.
791792
pri = import_priority(imp, PRI_MED)
@@ -805,7 +806,7 @@ def correct_rel_imp(imp: ImportFrom | ImportAll) -> str:
805806
res.append((pri, cur_id, imp.line))
806807
elif isinstance(imp, ImportAll):
807808
pri = import_priority(imp, PRI_HIGH)
808-
res.append((pri, correct_rel_imp(imp), imp.line))
809+
res.append((pri, self.correct_rel_imp(file, imp), imp.line))
809810

810811
# Sort such that module (e.g. foo.bar.baz) comes before its ancestors (e.g. foo
811812
# and foo.bar) so that, if FindModuleCache finds the target module in a
@@ -2898,6 +2899,9 @@ def dispatch(sources: list[BuildSource], manager: BuildManager, stdout: TextIO)
28982899
manager.cache_enabled = False
28992900
graph = load_graph(sources, manager)
29002901

2902+
for id in graph:
2903+
manager.import_map[id] = set(graph[id].dependencies + graph[id].suppressed)
2904+
29012905
t1 = time.time()
29022906
manager.add_stats(
29032907
graph_size=len(graph),

mypy/nodes.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5002,27 +5002,27 @@ def local_definitions(
50025002
SYMBOL_TABLE_NODE: Final[Tag] = 61
50035003

50045004

5005-
def read_symbol(data: Buffer) -> mypy.nodes.SymbolNode:
5005+
def read_symbol(data: Buffer) -> SymbolNode:
50065006
tag = read_tag(data)
50075007
# The branches here are ordered manually by type "popularity".
50085008
if tag == VAR:
5009-
return mypy.nodes.Var.read(data)
5009+
return Var.read(data)
50105010
if tag == FUNC_DEF:
5011-
return mypy.nodes.FuncDef.read(data)
5011+
return FuncDef.read(data)
50125012
if tag == DECORATOR:
5013-
return mypy.nodes.Decorator.read(data)
5013+
return Decorator.read(data)
50145014
if tag == TYPE_INFO:
5015-
return mypy.nodes.TypeInfo.read(data)
5015+
return TypeInfo.read(data)
50165016
if tag == OVERLOADED_FUNC_DEF:
5017-
return mypy.nodes.OverloadedFuncDef.read(data)
5017+
return OverloadedFuncDef.read(data)
50185018
if tag == TYPE_VAR_EXPR:
5019-
return mypy.nodes.TypeVarExpr.read(data)
5019+
return TypeVarExpr.read(data)
50205020
if tag == TYPE_ALIAS:
5021-
return mypy.nodes.TypeAlias.read(data)
5021+
return TypeAlias.read(data)
50225022
if tag == PARAM_SPEC_EXPR:
5023-
return mypy.nodes.ParamSpecExpr.read(data)
5023+
return ParamSpecExpr.read(data)
50245024
if tag == TYPE_VAR_TUPLE_EXPR:
5025-
return mypy.nodes.TypeVarTupleExpr.read(data)
5025+
return TypeVarTupleExpr.read(data)
50265026
assert False, f"Unknown symbol tag {tag}"
50275027

50285028

mypy/semanal.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -451,6 +451,7 @@ def __init__(
451451
incomplete_namespaces: set[str],
452452
errors: Errors,
453453
plugin: Plugin,
454+
import_map: dict[str, set[str]],
454455
) -> None:
455456
"""Construct semantic analyzer.
456457
@@ -483,6 +484,7 @@ def __init__(
483484
self.loop_depth = [0]
484485
self.errors = errors
485486
self.modules = modules
487+
self.import_map = import_map
486488
self.msg = MessageBuilder(errors, modules)
487489
self.missing_modules = missing_modules
488490
self.missing_names = [set()]
@@ -534,6 +536,16 @@ def __init__(
534536
self.type_expression_full_parse_success_count: int = 0 # Successful full parses
535537
self.type_expression_full_parse_failure_count: int = 0 # Failed full parses
536538

539+
# Imports of submodules transitively visible from given module.
540+
# This is needed to support patterns like this
541+
# [a.py]
542+
# import b
543+
# import foo
544+
# foo.bar # <- this should work even if bar is not re-exported in foo
545+
# [b.py]
546+
# import foo.bar
547+
self.transitive_submodule_imports: dict[str, set[str]] = {}
548+
537549
# mypyc doesn't properly handle implementing an abstractproperty
538550
# with a regular attribute so we make them properties
539551
@property
@@ -6687,7 +6699,7 @@ def get_module_symbol(self, node: MypyFile, name: str) -> SymbolTableNode | None
66876699
sym = names.get(name)
66886700
if not sym:
66896701
fullname = module + "." + name
6690-
if fullname in self.modules:
6702+
if fullname in self.modules and self.is_visible_import(module, fullname):
66916703
sym = SymbolTableNode(GDEF, self.modules[fullname])
66926704
elif self.is_incomplete_namespace(module):
66936705
self.record_incomplete_ref()
@@ -6706,6 +6718,40 @@ def get_module_symbol(self, node: MypyFile, name: str) -> SymbolTableNode | None
67066718
sym = None
67076719
return sym
67086720

6721+
def is_visible_import(self, base_id: str, id: str) -> bool:
6722+
if id in self.import_map[self.cur_mod_id]:
6723+
# Fast path: module is imported locally.
6724+
return True
6725+
if base_id not in self.transitive_submodule_imports:
6726+
# This is a performance optimization for a common pattern. If one module
6727+
# in a codebase uses import numpy as np; np.foo.bar, then it is likely that
6728+
# other modules use similar pattern as well. So we pre-compute transitive
6729+
# dependencies for np, to avoid possible duplicate work in the future.
6730+
self.add_transitive_submodule_imports(base_id)
6731+
if self.cur_mod_id not in self.transitive_submodule_imports:
6732+
self.add_transitive_submodule_imports(self.cur_mod_id)
6733+
return id in self.transitive_submodule_imports[self.cur_mod_id]
6734+
6735+
def add_transitive_submodule_imports(self, mod_id: str) -> None:
6736+
if mod_id not in self.import_map:
6737+
return
6738+
todo = self.import_map[mod_id]
6739+
seen = {mod_id}
6740+
result = {mod_id}
6741+
while todo:
6742+
dep = todo.pop()
6743+
if dep in seen:
6744+
continue
6745+
seen.add(dep)
6746+
if "." in dep:
6747+
result.add(dep)
6748+
if dep in self.transitive_submodule_imports:
6749+
result |= self.transitive_submodule_imports[dep]
6750+
continue
6751+
if dep in self.import_map:
6752+
todo |= self.import_map[dep]
6753+
self.transitive_submodule_imports[mod_id] = result
6754+
67096755
def is_missing_module(self, module: str) -> bool:
67106756
return module in self.missing_modules
67116757

test-data/unit/check-incremental.test

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7512,3 +7512,87 @@ tmp/impl.py:31: note: Revealed type is "builtins.object"
75127512
tmp/impl.py:32: note: Revealed type is "Union[builtins.int, builtins.str, lib.Unrelated]"
75137513
tmp/impl.py:33: note: Revealed type is "builtins.object"
75147514
tmp/impl.py:34: note: Revealed type is "builtins.object"
7515+
7516+
[case testIncrementalAccessSubmoduleWithoutExplicitImport]
7517+
import b
7518+
import a
7519+
7520+
[file a.py]
7521+
import pkg
7522+
7523+
pkg.submod.foo()
7524+
7525+
[file a.py.2]
7526+
import pkg
7527+
7528+
pkg.submod.foo()
7529+
x = 1
7530+
7531+
[file b.py]
7532+
import c
7533+
7534+
[file c.py]
7535+
from pkg import submod
7536+
7537+
[file pkg/__init__.pyi]
7538+
[file pkg/submod.pyi]
7539+
def foo() -> None: pass
7540+
[out]
7541+
tmp/a.py:3: error: "object" has no attribute "submod"
7542+
[out2]
7543+
tmp/a.py:3: error: "object" has no attribute "submod"
7544+
7545+
[case testIncrementalAccessSubmoduleWithoutExplicitImportNested]
7546+
import a
7547+
7548+
[file a.py]
7549+
import pandas
7550+
pandas.core.dtypes
7551+
7552+
[file a.py.2]
7553+
import pandas
7554+
pandas.core.dtypes
7555+
# touch
7556+
7557+
[file pandas/__init__.py]
7558+
import pandas.core.api
7559+
7560+
[file pandas/core/__init__.py]
7561+
[file pandas/core/api.py]
7562+
import pandas.core.dtypes.dtypes
7563+
7564+
[file pandas/core/dtypes/__init__.py]
7565+
[file pandas/core/dtypes/dtypes.py]
7566+
X = 0
7567+
[out]
7568+
[out2]
7569+
7570+
[case testIncrementalAccessSubmoduleWithoutExplicitImportNestedFrom]
7571+
import a
7572+
7573+
[file a.py]
7574+
import pandas
7575+
7576+
# Although this actually works at runtime, we do not support this, since
7577+
# this would cause major slowdown for a rare edge case. This test verifies
7578+
# that we fail consistently on cold and warm runs.
7579+
pandas.core.dtypes
7580+
7581+
[file a.py.2]
7582+
import pandas
7583+
pandas.core.dtypes
7584+
7585+
[file pandas/__init__.py]
7586+
import pandas.core.api
7587+
7588+
[file pandas/core/__init__.py]
7589+
[file pandas/core/api.py]
7590+
from pandas.core.dtypes.dtypes import X
7591+
7592+
[file pandas/core/dtypes/__init__.py]
7593+
[file pandas/core/dtypes/dtypes.py]
7594+
X = 0
7595+
[out]
7596+
tmp/a.py:6: error: "object" has no attribute "dtypes"
7597+
[out2]
7598+
tmp/a.py:2: error: "object" has no attribute "dtypes"

0 commit comments

Comments
 (0)