55# LICENSE file in the root directory of this source tree.
66
77# pyre-strict
8+ # pyre-ignore-all-errors[3, 2, 16]
89
10+ from importlib import metadata
11+ from importlib .metadata import EntryPoint
912from typing import Any , Dict , Optional
1013
11- import importlib_metadata as metadata
12- from importlib_metadata import EntryPoint
1314
14-
15- # pyre-ignore-all-errors[3, 2]
1615def load (group : str , name : str , default = None ):
1716 """
1817 Loads the entry point specified by
@@ -30,13 +29,34 @@ def load(group: str, name: str, default=None):
3029 raises an error.
3130 """
3231
33- entrypoints = metadata .entry_points ().select (group = group )
32+ # [note_on_entrypoints]
33+ # return type of importlib.metadata.entry_points() is different between python-3.9 and python-3.10
34+ # https://docs.python.org/3.9/library/importlib.metadata.html#importlib.metadata.entry_points
35+ # https://docs.python.org/3.10/library/importlib.metadata.html#importlib.metadata.entry_points
36+ if hasattr (metadata .entry_points (), "select" ):
37+ # python>=3.10
38+ entrypoints = metadata .entry_points ().select (group = group )
3439
35- if name not in entrypoints .names and default is not None :
36- return default
40+ if name not in entrypoints .names and default is not None :
41+ return default
42+
43+ ep = entrypoints [name ]
44+ return ep .load ()
3745
38- ep = entrypoints [name ]
39- return ep .load ()
46+ else :
47+ # python<3.10 (e.g. 3.9)
48+ # metadata.entry_points() returns dict[str, tuple[EntryPoint]] (not EntryPoints) in python-3.9
49+ entrypoints = metadata .entry_points ().get (group , ())
50+
51+ for ep in entrypoints :
52+ if ep .name == name :
53+ return ep .load ()
54+
55+ # [group].name not found
56+ if default is not None :
57+ return default
58+ else :
59+ raise KeyError (f"entrypoint { group } .{ name } not found" )
4060
4161
4262def _defer_load_ep (ep : EntryPoint ) -> object :
@@ -49,7 +69,6 @@ def run(*args: object, **kwargs: object) -> object:
4969 return run
5070
5171
52- # pyre-ignore-all-errors[3, 2]
5372def load_group (
5473 group : str , default : Optional [Dict [str , Any ]] = None , skip_defaults : bool = False
5574):
@@ -87,7 +106,13 @@ def load_group(
87106
88107 """
89108
90- entrypoints = metadata .entry_points ().select (group = group )
109+ # see [note_on_entrypoints] above
110+ if hasattr (metadata .entry_points (), "select" ):
111+ # python>=3.10
112+ entrypoints = metadata .entry_points ().select (group = group )
113+ else :
114+ # python<3.10 (e.g. 3.9)
115+ entrypoints = metadata .entry_points ().get (group , ())
91116
92117 if len (entrypoints ) == 0 :
93118 if skip_defaults :
0 commit comments