Skip to content

Commit 1d05766

Browse files
committed
feat: allow registering entry points with prefixes (pytorch#1050)
1 parent ae72fa2 commit 1d05766

File tree

2 files changed

+22
-4
lines changed

2 files changed

+22
-4
lines changed

torchx/util/entrypoints.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ def load_group(
5959
where the ``deferred_load_fn`` (as the name implies) defers the
6060
loading of the entrypoint (e.g. ``entrypoint.load()``) until the
6161
caller explicitly executes the funtion.
62+
If there are entry points with the group matching exactly they are the only ones returned.
63+
Otherwise all entry points that have a group ending with ``group`` are returned with a prefix.
6264
6365
For the following ``entry_point.txt``:
6466
@@ -87,14 +89,23 @@ def load_group(
8789
8890
"""
8991

90-
entrypoints = metadata.entry_points().select(group=group)
92+
entrypoints_prefixed, entrypoints_override = [], []
93+
for ep in metadata.entry_points():
94+
if ep.group == group:
95+
entrypoints_override.append(ep)
96+
elif ep.group.endswith(group):
97+
entrypoints_prefixed.append(ep)
9198

99+
entrypoints = entrypoints_override or entrypoints_prefixed
92100
if len(entrypoints) == 0:
93101
if skip_defaults:
94102
return None
95103
return default
96104

97-
eps = {}
105+
eps: Dict[str, Any] = {}
106+
if not skip_defaults and default:
107+
eps.update(default)
98108
for ep in entrypoints:
99-
eps[ep.name] = _defer_load_ep(ep)
109+
prefix = ep.group.replace(group, "")
110+
eps[prefix + ep.name] = _defer_load_ep(ep)
100111
return eps

torchx/util/test/entrypoints_test.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,6 @@ def barbaz() -> str:
8080
class EntryPointsTest(unittest.TestCase):
8181
@patch(_METADATA_EPS, return_value=_ENTRY_POINTS)
8282
def test_load(self, _: MagicMock) -> None:
83-
print(type(load("entrypoints.test", "foo")))
8483
self.assertEqual("foobar", load("entrypoints.test", "foo")())
8584

8685
with self.assertRaisesRegex(KeyError, "baz"):
@@ -127,6 +126,14 @@ def test_load_group_with_default(self, _: MagicMock) -> None:
127126
)
128127
self.assertIsNone(eps)
129128

129+
@patch(_METADATA_EPS, return_value=_ENTRY_POINTS)
130+
def test_load_group_with_prefix(self, _: MagicMock) -> None:
131+
eps = load_group("grp.test")
132+
assert eps
133+
self.assertEqual(2, len(eps))
134+
self.assertEqual("foobar", eps["ep.foo"]())
135+
self.assertEqual("barbaz", eps["ep.bar"]())
136+
130137
@patch(_METADATA_EPS, return_value=_ENTRY_POINTS)
131138
def test_load_group_missing(self, _: MagicMock) -> None:
132139
with self.assertRaises(AttributeError):

0 commit comments

Comments
 (0)