Skip to content

Commit 3650c34

Browse files
Add prefer_stubs configuration (#2437) (#2438)
(cherry picked from commit ee06feb) Co-authored-by: Jacob Walls <[email protected]>
1 parent a7ff092 commit 3650c34

File tree

5 files changed

+31
-14
lines changed

5 files changed

+31
-14
lines changed

ChangeLog

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,11 @@ Release date: TBA
1717

1818
Closes pylint-dev/pylint#9139
1919

20+
* Add ``AstroidManager.prefer_stubs`` attribute to control the astroid 3.2.0 feature that prefers stubs.
21+
22+
Refs pylint-dev/#9626
23+
Refs pylint-dev/#9623
24+
2025

2126
What's New in astroid 3.2.0?
2227
============================

astroid/interpreter/_import/spec.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -161,14 +161,10 @@ def find_module(
161161
pass
162162
submodule_path = sys.path
163163

164-
# We're looping on pyi first because if a pyi exists there's probably a reason
165-
# (i.e. the code is hard or impossible to parse), so we take pyi into account
166-
# But we're not quite ready to do this for numpy, see https://github.com/pylint-dev/astroid/pull/2375
167-
suffixes = (".pyi", ".py", importlib.machinery.BYTECODE_SUFFIXES[0])
168-
numpy_suffixes = (".py", ".pyi", importlib.machinery.BYTECODE_SUFFIXES[0])
164+
suffixes = (".py", ".pyi", importlib.machinery.BYTECODE_SUFFIXES[0])
169165
for entry in submodule_path:
170166
package_directory = os.path.join(entry, modname)
171-
for suffix in numpy_suffixes if "numpy" in entry else suffixes:
167+
for suffix in suffixes:
172168
package_file_name = "__init__" + suffix
173169
file_path = os.path.join(package_directory, package_file_name)
174170
if os.path.isfile(file_path):

astroid/manager.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ class AstroidManager:
6161
"extension_package_whitelist": set(),
6262
"module_denylist": set(),
6363
"_transform": TransformVisitor(),
64+
"prefer_stubs": False,
6465
}
6566

6667
def __init__(self) -> None:
@@ -73,6 +74,7 @@ def __init__(self) -> None:
7374
]
7475
self.module_denylist = AstroidManager.brain["module_denylist"]
7576
self._transform = AstroidManager.brain["_transform"]
77+
self.prefer_stubs = AstroidManager.brain["prefer_stubs"]
7678

7779
@property
7880
def always_load_extensions(self) -> bool:
@@ -111,6 +113,14 @@ def unregister_transform(self):
111113
def builtins_module(self) -> nodes.Module:
112114
return self.astroid_cache["builtins"]
113115

116+
@property
117+
def prefer_stubs(self) -> bool:
118+
return AstroidManager.brain["prefer_stubs"]
119+
120+
@prefer_stubs.setter
121+
def prefer_stubs(self, value: bool) -> None:
122+
AstroidManager.brain["prefer_stubs"] = value
123+
114124
def visit_transforms(self, node: nodes.NodeNG) -> InferenceResult:
115125
"""Visit the transforms and apply them to the given *node*."""
116126
return self._transform.visit(node)
@@ -136,7 +146,9 @@ def ast_from_file(
136146
# Call get_source_file() only after a cache miss,
137147
# since it calls os.path.exists().
138148
try:
139-
filepath = get_source_file(filepath, include_no_ext=True)
149+
filepath = get_source_file(
150+
filepath, include_no_ext=True, prefer_stubs=self.prefer_stubs
151+
)
140152
source = True
141153
except NoSourceFile:
142154
pass

astroid/modutils.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,12 @@
4444

4545

4646
if sys.platform.startswith("win"):
47-
PY_SOURCE_EXTS = ("pyi", "pyw", "py")
47+
PY_SOURCE_EXTS = ("py", "pyw", "pyi")
48+
PY_SOURCE_EXTS_STUBS_FIRST = ("pyi", "pyw", "py")
4849
PY_COMPILED_EXTS = ("dll", "pyd")
4950
else:
50-
PY_SOURCE_EXTS = ("pyi", "py")
51+
PY_SOURCE_EXTS = ("py", "pyi")
52+
PY_SOURCE_EXTS_STUBS_FIRST = ("pyi", "py")
5153
PY_COMPILED_EXTS = ("so",)
5254

5355

@@ -484,7 +486,9 @@ def get_module_files(
484486
return files
485487

486488

487-
def get_source_file(filename: str, include_no_ext: bool = False) -> str:
489+
def get_source_file(
490+
filename: str, include_no_ext: bool = False, prefer_stubs: bool = False
491+
) -> str:
488492
"""Given a python module's file name return the matching source file
489493
name (the filename will be returned identically if it's already an
490494
absolute path to a python source file).
@@ -499,7 +503,7 @@ def get_source_file(filename: str, include_no_ext: bool = False) -> str:
499503
base, orig_ext = os.path.splitext(filename)
500504
if orig_ext == ".pyi" and os.path.exists(f"{base}{orig_ext}"):
501505
return f"{base}{orig_ext}"
502-
for ext in PY_SOURCE_EXTS if "numpy" not in filename else reversed(PY_SOURCE_EXTS):
506+
for ext in PY_SOURCE_EXTS_STUBS_FIRST if prefer_stubs else PY_SOURCE_EXTS:
503507
source_path = f"{base}.{ext}"
504508
if os.path.exists(source_path):
505509
return source_path
@@ -671,8 +675,7 @@ def _has_init(directory: str) -> str | None:
671675
else return None.
672676
"""
673677
mod_or_pack = os.path.join(directory, "__init__")
674-
exts = reversed(PY_SOURCE_EXTS) if "numpy" in directory else PY_SOURCE_EXTS
675-
for ext in (*exts, "pyc", "pyo"):
678+
for ext in (*PY_SOURCE_EXTS, "pyc", "pyo"):
676679
if os.path.exists(mod_or_pack + "." + ext):
677680
return mod_or_pack + "." + ext
678681
return None

tests/test_modutils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -300,7 +300,8 @@ def test_pyi_preferred(self) -> None:
300300
package = resources.find("pyi_data/find_test")
301301
module = os.path.join(package, "__init__.py")
302302
self.assertEqual(
303-
modutils.get_source_file(module), os.path.normpath(module) + "i"
303+
modutils.get_source_file(module, prefer_stubs=True),
304+
os.path.normpath(module) + "i",
304305
)
305306

306307

0 commit comments

Comments
 (0)