Skip to content

Commit 49056c4

Browse files
split importlib into own submodule for better typing
1 parent e24c6a6 commit 49056c4

File tree

3 files changed

+82
-87
lines changed

3 files changed

+82
-87
lines changed

src/pluggy/_importlib.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
from __future__ import annotations
2+
3+
import sys
4+
from typing import Callable, Iterable, Any
5+
6+
if sys.version_info >= (3, 8):
7+
from importlib.metadata import distributions
8+
else:
9+
from importlib_metadata import distributions
10+
11+
12+
class DistFacade:
13+
"""Emulate a pkg_resources Distribution"""
14+
15+
# turn Any to Distribution as soon as the typing details for them fit
16+
def __init__(self, dist: Any) -> None:
17+
self._dist = dist
18+
19+
@property
20+
def project_name(self) -> str:
21+
name: str = self.metadata["name"]
22+
return name
23+
24+
def __getattr__(self, attr: str, default: Any | None = None) -> Any:
25+
return getattr(self._dist, attr, default)
26+
27+
def __dir__(self) -> list[str]:
28+
return sorted(dir(self._dist) + ["_dist", "project_name"])
29+
30+
31+
def iter_entrypoint_loaders(
32+
group: str, name: str | None
33+
) -> Iterable[tuple[DistFacade, str, Callable[[], object]]]:
34+
for dist in list(distributions()):
35+
legacy = DistFacade(dist)
36+
for ep in dist.entry_points:
37+
if ep.group == group and name is None or name == ep.name:
38+
yield legacy, ep.name, ep.load

src/pluggy/_manager.py

Lines changed: 33 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,19 @@
1+
from __future__ import annotations
12
import inspect
2-
import sys
33
import types
44
import warnings
55
from typing import (
66
Any,
77
Callable,
88
cast,
9-
Dict,
109
Iterable,
11-
List,
1210
Mapping,
13-
Optional,
1411
Sequence,
15-
Set,
16-
Tuple,
1712
TYPE_CHECKING,
18-
Union,
1913
)
2014

2115
from . import _tracing
16+
from ._importlib import DistFacade, iter_entrypoint_loaders
2217
from ._result import _Result
2318
from ._callers import _multicall
2419
from ._hooks import (
@@ -33,11 +28,6 @@
3328
_Plugin,
3429
)
3530

36-
if sys.version_info >= (3, 8):
37-
from importlib import metadata as importlib_metadata
38-
else:
39-
import importlib_metadata
40-
4131
if TYPE_CHECKING:
4232
from typing_extensions import Final
4333

@@ -68,24 +58,6 @@ def __init__(self, plugin: _Plugin, message: str) -> None:
6858
self.plugin = plugin
6959

7060

71-
class DistFacade:
72-
"""Emulate a pkg_resources Distribution"""
73-
74-
def __init__(self, dist: importlib_metadata.Distribution) -> None:
75-
self._dist = dist
76-
77-
@property
78-
def project_name(self) -> str:
79-
name: str = self.metadata["name"]
80-
return name
81-
82-
def __getattr__(self, attr: str, default=None):
83-
return getattr(self._dist, attr, default)
84-
85-
def __dir__(self) -> List[str]:
86-
return sorted(dir(self._dist) + ["_dist", "project_name"])
87-
88-
8961
class PluginManager:
9062
"""Core class which manages registration of plugin objects and 1:N hook
9163
calling.
@@ -111,11 +83,11 @@ class PluginManager:
11183
)
11284

11385
def __init__(self, project_name: str) -> None:
114-
self.project_name: "Final" = project_name
115-
self._name2plugin: "Final[Dict[str, _Plugin]]" = {}
116-
self._plugin_distinfo: "Final[List[Tuple[_Plugin, DistFacade]]]" = []
117-
self.trace: "Final" = _tracing.TagTracer().get("pluginmanage")
118-
self.hook: "Final" = _HookRelay()
86+
self.project_name: Final = project_name
87+
self._name2plugin: Final[dict[str, _Plugin]] = {}
88+
self._plugin_distinfo: Final[list[tuple[_Plugin, DistFacade]]] = []
89+
self.trace: Final = _tracing.TagTracer().get("pluginmanage")
90+
self.hook: Final = _HookRelay()
11991
self._inner_hookexec = _multicall
12092

12193
def _hookexec(
@@ -124,12 +96,12 @@ def _hookexec(
12496
methods: Sequence[HookImpl],
12597
kwargs: Mapping[str, object],
12698
firstresult: bool,
127-
) -> Union[object, List[object]]:
99+
) -> object | list[object]:
128100
# called from all hookcaller instances.
129101
# enable_tracing will set its own wrapping function at self._inner_hookexec
130102
return self._inner_hookexec(hook_name, methods, kwargs, firstresult)
131103

132-
def register(self, plugin: _Plugin, name: Optional[str] = None) -> Optional[str]:
104+
def register(self, plugin: _Plugin, name: str | None = None) -> str | None:
133105
"""Register a plugin and return its name.
134106
135107
If a name is not specified, a name is generated using
@@ -167,7 +139,7 @@ def register(self, plugin: _Plugin, name: Optional[str] = None) -> Optional[str]
167139
method: _HookImplFunction[object] = getattr(plugin, name)
168140
hookimpl = HookImpl(plugin, plugin_name, method, hookimpl_opts)
169141
name = hookimpl_opts.get("specname") or name
170-
hook: Optional[_HookCaller] = getattr(self.hook, name, None)
142+
hook: _HookCaller | None = getattr(self.hook, name, None)
171143
if hook is None:
172144
hook = _HookCaller(name, self._hookexec)
173145
setattr(self.hook, name, hook)
@@ -177,25 +149,20 @@ def register(self, plugin: _Plugin, name: Optional[str] = None) -> Optional[str]
177149
hook._add_hookimpl(hookimpl)
178150
return plugin_name
179151

180-
def parse_hookimpl_opts(
181-
self, plugin: _Plugin, name: str
182-
) -> Optional["_HookImplOpts"]:
152+
def parse_hookimpl_opts(self, plugin: _Plugin, name: str) -> _HookImplOpts | None:
183153
method: object = getattr(plugin, name)
184154
if not inspect.isroutine(method):
185155
return None
186156
try:
187-
res: Optional["_HookImplOpts"] = getattr(
157+
res: _HookImplOpts | None = getattr(
188158
method, self.project_name + "_impl", None
189159
)
190160
except Exception:
191161
res = {} # type: ignore[assignment]
192-
if res is not None and not isinstance(res, dict):
193-
# false positive
194-
res = None
195162
return res
196163

197164
def unregister(
198-
self, plugin: Optional[_Plugin] = None, name: Optional[str] = None
165+
self, plugin: _Plugin | None = None, name: str | None = None
199166
) -> _Plugin:
200167
"""Unregister a plugin and all of its hook implementations.
201168
@@ -241,7 +208,7 @@ def add_hookspecs(self, module_or_class: _Namespace) -> None:
241208
for name in dir(module_or_class):
242209
spec_opts = self.parse_hookspec_opts(module_or_class, name)
243210
if spec_opts is not None:
244-
hc: Optional[_HookCaller] = getattr(self.hook, name, None)
211+
hc: _HookCaller | None = getattr(self.hook, name, None)
245212
if hc is None:
246213
hc = _HookCaller(name, self._hookexec, module_or_class, spec_opts)
247214
setattr(self.hook, name, hc)
@@ -259,14 +226,12 @@ def add_hookspecs(self, module_or_class: _Namespace) -> None:
259226

260227
def parse_hookspec_opts(
261228
self, module_or_class: _Namespace, name: str
262-
) -> Optional["_HookSpecOpts"]:
229+
) -> _HookSpecOpts | None:
263230
method: HookSpec = getattr(module_or_class, name)
264-
opts: Optional[_HookSpecOpts] = getattr(
265-
method, self.project_name + "_spec", None
266-
)
231+
opts: _HookSpecOpts | None = getattr(method, self.project_name + "_spec", None)
267232
return opts
268233

269-
def get_plugins(self) -> Set[Any]:
234+
def get_plugins(self) -> set[Any]:
270235
"""Return a set of all registered plugin objects."""
271236
return set(self._name2plugin.values())
272237

@@ -282,18 +247,18 @@ def get_canonical_name(self, plugin: _Plugin) -> str:
282247
To obtain the name of n registered plugin use :meth:`get_name(plugin)
283248
<get_name>` instead.
284249
"""
285-
name: Optional[str] = getattr(plugin, "__name__", None)
250+
name: str | None = getattr(plugin, "__name__", None)
286251
return name or str(id(plugin))
287252

288-
def get_plugin(self, name: str) -> Optional[Any]:
253+
def get_plugin(self, name: str) -> Any | None:
289254
"""Return the plugin registered under the given name, if any."""
290255
return self._name2plugin.get(name)
291256

292257
def has_plugin(self, name: str) -> bool:
293258
"""Return whether a plugin with the given name is registered."""
294259
return self.get_plugin(name) is not None
295260

296-
def get_name(self, plugin: _Plugin) -> Optional[str]:
261+
def get_name(self, plugin: _Plugin) -> str | None:
297262
"""Return the name the plugin is registered under, or ``None`` if
298263
is isn't."""
299264
for name, val in self._name2plugin.items():
@@ -353,9 +318,7 @@ def check_pending(self) -> None:
353318
% (name, hookimpl.plugin),
354319
)
355320

356-
def load_setuptools_entrypoints(
357-
self, group: str, name: Optional[str] = None
358-
) -> int:
321+
def load_setuptools_entrypoints(self, group: str, name: str | None = None) -> int:
359322
"""Load modules from querying the specified setuptools ``group``.
360323
361324
:param str group: Entry point group to load plugins.
@@ -364,32 +327,26 @@ def load_setuptools_entrypoints(
364327
:return: The number of plugins loaded by this call.
365328
"""
366329
count = 0
367-
for dist in list(importlib_metadata.distributions()):
368-
for ep in dist.entry_points:
369-
if (
370-
ep.group != group
371-
or (name is not None and ep.name != name)
372-
# already registered
373-
or self.get_plugin(ep.name)
374-
or self.is_blocked(ep.name)
375-
):
376-
continue
377-
plugin = ep.load()
378-
self.register(plugin, name=ep.name)
379-
self._plugin_distinfo.append((plugin, DistFacade(dist)))
380-
count += 1
330+
for dist, ep_name, loader in iter_entrypoint_loaders(group, name):
331+
if self.get_plugin(ep_name) or self.is_blocked(ep_name):
332+
continue
333+
# already registered
334+
plugin = loader()
335+
self.register(plugin, name=ep_name)
336+
self._plugin_distinfo.append((plugin, dist))
337+
count += 1
381338
return count
382339

383-
def list_plugin_distinfo(self) -> List[Tuple[_Plugin, DistFacade]]:
340+
def list_plugin_distinfo(self) -> list[tuple[_Plugin, DistFacade]]:
384341
"""Return a list of (plugin, distinfo) pairs for all
385342
setuptools-registered plugins."""
386343
return list(self._plugin_distinfo)
387344

388-
def list_name_plugin(self) -> List[Tuple[str, _Plugin]]:
345+
def list_name_plugin(self) -> list[tuple[str, _Plugin]]:
389346
"""Return a list of (name, plugin) pairs for all registered plugins."""
390347
return list(self._name2plugin.items())
391348

392-
def get_hookcallers(self, plugin: _Plugin) -> Optional[List[_HookCaller]]:
349+
def get_hookcallers(self, plugin: _Plugin) -> list[_HookCaller] | None:
393350
"""Get all hook callers for the specified plugin."""
394351
if self.get_name(plugin) is None:
395352
return None
@@ -422,7 +379,7 @@ def traced_hookexec(
422379
hook_impls: Sequence[HookImpl],
423380
caller_kwargs: Mapping[str, object],
424381
firstresult: bool,
425-
) -> Union[object, List[object]]:
382+
) -> object | list[object]:
426383
before(hook_name, hook_impls, caller_kwargs)
427384
outcome = _Result.from_call(
428385
lambda: oldcall(hook_name, hook_impls, caller_kwargs, firstresult)

testing/test_pluginmanager.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
"""
22
``PluginManager`` unit and public API testing.
33
"""
4+
from __future__ import annotations
45
import pytest
5-
from typing import Any, List
6+
from typing import Any, Callable
67

78
from pluggy import (
89
HookCallError,
@@ -11,7 +12,7 @@
1112
PluginManager,
1213
PluginValidationError,
1314
)
14-
from pluggy._manager import importlib_metadata
15+
import pluggy._importlib
1516

1617

1718
hookspec = HookspecMarker("example")
@@ -265,13 +266,12 @@ def test_with_result_memorized(pm: PluginManager, result_callback: bool) -> None
265266
correctly applies the ``result_callback`` function, when provided,
266267
to the result from calling each newly registered hook.
267268
"""
268-
out = []
269+
out: list[int] = []
270+
callback: Callable[[int], None] | None
269271
if not result_callback:
270272
callback = None
271273
else:
272-
273-
def callback(res) -> None:
274-
out.append(res)
274+
callback = out.append
275275

276276
class Hooks:
277277
@hookspec(historic=True)
@@ -381,7 +381,7 @@ def he_method1(self, arg):
381381
class Plugin1:
382382
@hookimpl
383383
def he_method1(self, arg):
384-
0 / 0
384+
return 0 / 0
385385

386386
pm.register(Plugin1())
387387
with pytest.raises(HookCallError):
@@ -553,7 +553,7 @@ class Distribution:
553553
def my_distributions():
554554
return (dist,)
555555

556-
monkeypatch.setattr(importlib_metadata, "distributions", my_distributions)
556+
monkeypatch.setattr(pluggy._importlib, "distributions", my_distributions)
557557
num = pm.load_setuptools_entrypoints("hello")
558558
assert num == 1
559559
plugin = pm.get_plugin("myname")
@@ -564,13 +564,13 @@ def my_distributions():
564564
assert len(ret) == 1
565565
assert len(ret[0]) == 2
566566
assert ret[0][0] == plugin
567-
assert ret[0][1]._dist == dist # type: ignore[comparison-overlap]
567+
assert ret[0][1]._dist == dist
568568
num = pm.load_setuptools_entrypoints("hello")
569569
assert num == 0 # no plugin loaded by this call
570570

571571

572572
def test_add_tracefuncs(he_pm: PluginManager) -> None:
573-
out: List[Any] = []
573+
out: list[Any] = []
574574

575575
class api1:
576576
@hookimpl
@@ -623,7 +623,7 @@ def he_method1(self):
623623
raise ValueError()
624624

625625
he_pm.register(api1())
626-
out: List[Any] = []
626+
out: list[Any] = []
627627
he_pm.trace.root.setwriter(out.append)
628628
undo = he_pm.enable_tracing()
629629
try:

0 commit comments

Comments
 (0)