Skip to content

Commit 8320ade

Browse files
committed
Instead of presenting separate contexts for EntryPoints, unify into a single collection that can select on 'name' or 'group' or possibly other attributes. Expose that selection in the 'entry_points' function.
1 parent 71fd4a7 commit 8320ade

File tree

5 files changed

+54
-36
lines changed

5 files changed

+54
-36
lines changed

docs/using.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ a ``.load()`` method to resolve the value. There are also ``.module``,
7777
>>> eps = entry_points()
7878
>>> sorted(eps.groups)
7979
['console_scripts', 'distutils.commands', 'distutils.setup_keywords', 'egg_info.writers', 'setuptools.installation']
80-
>>> scripts = eps['console_scripts']
80+
>>> scripts = eps.select(group='console_scripts')
8181
>>> 'wheel' in scripts.names
8282
True
8383
>>> wheel = scripts['wheel']

importlib_metadata/__init__.py

Lines changed: 31 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -130,10 +130,6 @@ def _from_text(cls, text):
130130
config.read_string(text)
131131
return cls._from_config(config)
132132

133-
@classmethod
134-
def _from_text_for(cls, text, dist):
135-
return (ep._for(dist) for ep in cls._from_text(text))
136-
137133
def _for(self, dist):
138134
self.dist = dist
139135
return self
@@ -155,35 +151,42 @@ def __reduce__(self):
155151
(self.name, self.value, self.group),
156152
)
157153

154+
def matches(self, **params):
155+
attrs = (getattr(self, param) for param in params)
156+
return all(map(operator.eq, params.values(), attrs))
157+
158158

159159
class EntryPoints(tuple):
160160
"""
161-
An immutable collection of EntryPoint objects, retrievable by name.
161+
An immutable collection of selectable EntryPoint objects.
162162
"""
163163

164164
__slots__ = ()
165165

166-
def __getitem__(self, name) -> EntryPoint:
166+
def __getitem__(self, name) -> Union[EntryPoint, 'EntryPoints']:
167167
try:
168-
return next(ep for ep in self if ep.name == name)
169-
except Exception:
168+
match = next(iter(self.select(name=name)))
169+
return match
170+
except StopIteration:
171+
if name in self.groups:
172+
return self._group_getitem(name)
170173
raise KeyError(name)
171174

175+
def _group_getitem(self, name):
176+
"""
177+
For backward compatability, supply .__getitem__ for groups.
178+
"""
179+
msg = "GroupedEntryPoints.__getitem__ is deprecated for groups. Use select."
180+
warnings.warn(msg, DeprecationWarning)
181+
return self.select(group=name)
182+
183+
def select(self, **params):
184+
return EntryPoints(ep for ep in self if ep.matches(**params))
185+
172186
@property
173187
def names(self):
174188
return set(ep.name for ep in self)
175189

176-
177-
class GroupedEntryPoints(tuple):
178-
"""
179-
An immutable collection of EntryPoint objects, retrievable by group.
180-
"""
181-
182-
__slots__ = ()
183-
184-
def __getitem__(self, group) -> EntryPoints:
185-
return EntryPoints(ep for ep in self if ep.group == group)
186-
187190
@property
188191
def groups(self):
189192
return set(ep.group for ep in self)
@@ -193,9 +196,13 @@ def get(self, group, default=None):
193196
For backward compatibility, supply .get
194197
"""
195198
is_flake8 = any('flake8' in str(frame) for frame in inspect.stack())
196-
msg = "GroupedEntryPoints.get is deprecated. Just use __getitem__."
199+
msg = "GroupedEntryPoints.get is deprecated. Use select."
197200
is_flake8 or warnings.warn(msg, DeprecationWarning)
198-
return self[group] or default
201+
return self.select(group=group) or default
202+
203+
@classmethod
204+
def _from_text_for(cls, text, dist):
205+
return cls(ep._for(dist) for ep in EntryPoint._from_text(text))
199206

200207

201208
class PackagePath(pathlib.PurePosixPath):
@@ -353,8 +360,7 @@ def version(self):
353360

354361
@property
355362
def entry_points(self):
356-
eps = EntryPoint._from_text_for(self.read_text('entry_points.txt'), self)
357-
return GroupedEntryPoints(eps)
363+
return EntryPoints._from_text_for(self.read_text('entry_points.txt'), self)
358364

359365
@property
360366
def files(self):
@@ -687,13 +693,13 @@ def version(distribution_name):
687693
return distribution(distribution_name).version
688694

689695

690-
def entry_points():
696+
def entry_points(**params):
691697
"""Return EntryPoint objects for all installed packages.
692698
693699
:return: EntryPoint objects for all installed packages.
694700
"""
695701
eps = itertools.chain.from_iterable(dist.entry_points for dist in distributions())
696-
return GroupedEntryPoints(eps)
702+
return EntryPoints(eps).select(**params)
697703

698704

699705
def files(distribution_name):

tests/test_api.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -67,25 +67,25 @@ def test_read_text(self):
6767
def test_entry_points(self):
6868
eps = entry_points()
6969
assert 'entries' in eps.groups
70-
entries = eps['entries']
70+
entries = eps.select(group='entries')
7171
assert 'main' in entries.names
7272
ep = entries['main']
7373
self.assertEqual(ep.value, 'mod:main')
7474
self.assertEqual(ep.extras, [])
7575

7676
def test_entry_points_distribution(self):
77-
entries = entry_points()['entries']
77+
entries = entry_points(group='entries')
7878
for entry in ("main", "ns:sub"):
7979
ep = entries[entry]
8080
self.assertIn(ep.dist.name, ('distinfo-pkg', 'egginfo-pkg'))
8181
self.assertEqual(ep.dist.version, "1.0.0")
8282

8383
def test_entry_points_missing_name(self):
8484
with self.assertRaises(KeyError):
85-
entry_points()['entries']['missing']
85+
entry_points(group='entries')['missing']
8686

8787
def test_entry_points_missing_group(self):
88-
assert entry_points()['missing'] == ()
88+
assert entry_points(group='missing') == ()
8989

9090
def test_entry_points_dict_construction(self):
9191
"""
@@ -94,16 +94,28 @@ def test_entry_points_dict_construction(self):
9494
Capture this now deprecated use-case.
9595
"""
9696
with warnings.catch_warnings(record=True) as caught:
97-
eps = dict(entry_points()['entries'])
97+
eps = dict(entry_points(group='entries'))
9898

9999
assert 'main' in eps
100-
assert eps['main'] == entry_points()['entries']['main']
100+
assert eps['main'] == entry_points(group='entries')['main']
101101

102102
# check warning
103103
expected = next(iter(caught))
104104
assert expected.category is DeprecationWarning
105105
assert "Construction of dict of EntryPoints is deprecated" in str(expected)
106106

107+
def test_entry_points_groups_getitem(self):
108+
"""
109+
Prior versions of entry_points() returned a dict. Ensure
110+
that callers using '.__getitem__()' are supported but warned to
111+
migrate.
112+
"""
113+
with warnings.catch_warnings(record=True):
114+
entry_points()['entries'] == entry_points(group='entries')
115+
116+
with self.assertRaises(KeyError):
117+
entry_points()['missing']
118+
107119
def test_entry_points_groups_get(self):
108120
"""
109121
Prior versions of entry_points() returned a dict. Ensure
@@ -113,7 +125,7 @@ def test_entry_points_groups_get(self):
113125
with warnings.catch_warnings(record=True):
114126
entry_points().get('missing', 'default') == 'default'
115127
entry_points().get('entries', 'default') == entry_points()['entries']
116-
entry_points().get('missing', ()) == entry_points()['missing']
128+
entry_points().get('missing', ()) == ()
117129

118130
def test_metadata_for_this_package(self):
119131
md = metadata('egginfo-pkg')

tests/test_main.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,11 +58,11 @@ def test_import_nonexistent_module(self):
5858
importlib.import_module('does_not_exist')
5959

6060
def test_resolve(self):
61-
ep = entry_points()['entries']['main']
61+
ep = entry_points(group='entries')['main']
6262
self.assertEqual(ep.load().__name__, "main")
6363

6464
def test_entrypoint_with_colon_in_name(self):
65-
ep = entry_points()['entries']['ns:sub']
65+
ep = entry_points(group='entries')['ns:sub']
6666
self.assertEqual(ep.value, 'mod:main')
6767

6868
def test_resolve_without_attr(self):

tests/test_zip.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def test_zip_version_does_not_match(self):
4545
version('definitely-not-installed')
4646

4747
def test_zip_entry_points(self):
48-
scripts = entry_points()['console_scripts']
48+
scripts = entry_points(group='console_scripts')
4949
entry_point = scripts['example']
5050
self.assertEqual(entry_point.value, 'example:main')
5151
entry_point = scripts['Example']

0 commit comments

Comments
 (0)