Skip to content

Commit 7031222

Browse files
authored
(torchx/entrypoints) remove redundant skip_defaults flag from entrypoints.load()
Differential Revision: D83991870 Pull Request resolved: #1140
1 parent 1e3df20 commit 7031222

File tree

3 files changed

+9
-20
lines changed

3 files changed

+9
-20
lines changed

torchx/schedulers/__init__.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -49,15 +49,14 @@ def get_scheduler_factories(
4949
The first scheduler in the dictionary is used as the default scheduler.
5050
"""
5151

52-
default_schedulers: dict[str, SchedulerFactory] = {}
53-
for scheduler, path in DEFAULT_SCHEDULER_MODULES.items():
54-
default_schedulers[scheduler] = _defer_load_scheduler(path)
55-
56-
return load_group(
57-
group,
58-
default=default_schedulers,
59-
skip_defaults=skip_defaults,
60-
)
52+
if skip_defaults:
53+
default_schedulers = {}
54+
else:
55+
default_schedulers: dict[str, SchedulerFactory] = {}
56+
for scheduler, path in DEFAULT_SCHEDULER_MODULES.items():
57+
default_schedulers[scheduler] = _defer_load_scheduler(path)
58+
59+
return load_group(group, default=default_schedulers)
6160

6261

6362
def get_default_scheduler_name() -> str:

torchx/util/entrypoints.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -69,9 +69,7 @@ def run(*args: object, **kwargs: object) -> object:
6969
return run
7070

7171

72-
def load_group(
73-
group: str, default: Optional[Dict[str, Any]] = None, skip_defaults: bool = False
74-
):
72+
def load_group(group: str, default: Optional[Dict[str, Any]] = None):
7573
"""
7674
Loads all the entry points specified by ``group`` and returns
7775
the entry points as a map of ``name (str) -> deferred_load_fn``.
@@ -90,7 +88,6 @@ def load_group(
9088
1. ``load_group("foo")["bar"]("baz")`` -> equivalent to calling ``this.is.a_fn("baz")``
9189
1. ``load_group("food")`` -> ``None``
9290
1. ``load_group("food", default={"hello": this.is.c_fn})["hello"]("world")`` -> equivalent to calling ``this.is.c_fn("world")``
93-
1. ``load_group("food", default={"hello": this.is.c_fn}, skip_defaults=True)`` -> ``None``
9491
9592
9693
If the entrypoint is a module (versus a function as shown above), then calling the ``deferred_load_fn``
@@ -115,8 +112,6 @@ def load_group(
115112
entrypoints = metadata.entry_points().get(group, ())
116113

117114
if len(entrypoints) == 0:
118-
if skip_defaults:
119-
return None
120115
return default
121116

122117
eps = {}

torchx/util/test/entrypoints_test.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -134,11 +134,6 @@ def test_load_group_with_default(self, _: MagicMock) -> None:
134134
self.assertEqual("barbaz", eps["foo"]())
135135
self.assertEqual("foobar", eps["bar"]())
136136

137-
eps = load_group(
138-
"ep.grp.test.missing", {"foo": barbaz, "bar": foobar}, skip_defaults=True
139-
)
140-
self.assertIsNone(eps)
141-
142137
@patch(_METADATA_EPS, return_value=_ENTRY_POINTS)
143138
def test_load_group_missing(self, _: MagicMock) -> None:
144139
with self.assertRaises(AttributeError):

0 commit comments

Comments
 (0)