10 changes: 9 additions & 1 deletion torchx/runner/config.py
@@ -197,7 +197,15 @@ def _configparser() -> configparser.ConfigParser:


 def _get_scheduler(name: str) -> Scheduler:
-    schedulers = get_scheduler_factories()
+    schedulers = {
+        **get_scheduler_factories(),
+        **(
+            get_scheduler_factories(
+                group="torchx.schedulers.orchestrator", skip_defaults=True
+            )
+            or {}
+        ),
+    }
     if name not in schedulers:
         raise ValueError(
             f"`{name}` is not a registered scheduler. Valid scheduler names: {schedulers.keys()}"
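Note on the merge above: the `**`-unpacking builds one registry in which a scheduler registered under the "torchx.schedulers.orchestrator" entry-point group overrides a default scheduler of the same name, and the `or {}` guards against the `None` that `skip_defaults=True` yields when the group has no entry points. A minimal sketch of that precedence rule, using stand-in factory functions rather than real torchx schedulers (the factory names are illustrative only):

# Minimal sketch (not the torchx implementation): dict unpacking merges two
# factory registries, with the second group taking precedence on name clashes.
from typing import Callable, Dict, Optional

SchedulerFactory = Callable[[str], object]  # stand-in for torchx's factory type


def default_factory(session_name: str) -> object:
    return f"default scheduler for {session_name}"


def orchestrator_factory(session_name: str) -> object:
    return f"orchestrator scheduler for {session_name}"


def merge(
    defaults: Dict[str, SchedulerFactory],
    extra: Optional[Dict[str, SchedulerFactory]],
) -> Dict[str, SchedulerFactory]:
    # `extra` may be None when the entry-point group is empty and
    # skip_defaults=True, hence the `or {}` guard.
    return {**defaults, **(extra or {})}


if __name__ == "__main__":
    merged = merge(
        {"local_cwd": default_factory, "kubernetes": default_factory},
        {"kubernetes": orchestrator_factory},  # overrides the default entry
    )
    print(merged["kubernetes"]("test"))  # -> "orchestrator scheduler for test"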
18 changes: 15 additions & 3 deletions torchx/runner/test/config_test.py
@@ -470,10 +470,22 @@ def test_dump_and_load_all_registered_schedulers(self) -> None:
         sfile = StringIO()
         dump(sfile)
-
-        for sched_name, sched in get_scheduler_factories().items():
+        scheduler_factories = {
+            **get_scheduler_factories(),
+            **(
+                get_scheduler_factories(
+                    group="torchx.schedulers.orchestrator", skip_defaults=True
+                )
+                or {}
+            ),
+        }
+
+        for sched_name, sched in scheduler_factories.items():
             sfile.seek(0)  # reset the file pos
             cfg = {}
             load(scheduler=sched_name, f=sfile, cfg=cfg)

             for opt_name, _ in sched("test").run_opts():
-                self.assertTrue(opt_name in cfg)
+                self.assertTrue(
+                    opt_name in cfg,
+                    f"missing {opt_name} in {sched} run opts with cfg {cfg}",
+                )
9 changes: 6 additions & 3 deletions torchx/schedulers/__init__.py
@@ -42,9 +42,11 @@ def run(*args: object, **kwargs: object) -> Scheduler:
     return run


-def get_scheduler_factories() -> Dict[str, SchedulerFactory]:
+def get_scheduler_factories(
+    group: str = "torchx.schedulers", skip_defaults: bool = False
+) -> Dict[str, SchedulerFactory]:
     """
-    get_scheduler_factories returns all the available schedulers names and the
+    get_scheduler_factories returns all the available schedulers names under `group` and the
     method to instantiate them.

     The first scheduler in the dictionary is used as the default scheduler.
@@ -55,8 +57,9 @@ def get_scheduler_factories() -> Dict[str, SchedulerFactory]:
         default_schedulers[scheduler] = _defer_load_scheduler(path)

     return load_group(
-        "torchx.schedulers",
+        group,
         default=default_schedulers,
+        skip_defaults=skip_defaults,
     )

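With `group` now a parameter, an external package can publish schedulers under the "torchx.schedulers.orchestrator" entry-point group and have them picked up by `get_scheduler_factories(group="torchx.schedulers.orchestrator", skip_defaults=True)`. A hedged sketch of what such a registration might look like in a hypothetical package's setup.py (the package, module, and `my_orchestrator` names are invented for illustration):

# setup.py of a hypothetical "my-orchestrator" package (illustrative only).
from setuptools import find_packages, setup

setup(
    name="my-orchestrator",
    version="0.1.0",
    packages=find_packages(),
    entry_points={
        # Entry points in this group are returned by
        # get_scheduler_factories(group="torchx.schedulers.orchestrator",
        #                         skip_defaults=True)
        "torchx.schedulers.orchestrator": [
            "my_orchestrator = my_orchestrator.scheduler:create_scheduler",
        ],
    },
)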
1 change: 1 addition & 0 deletions torchx/schedulers/test/registry_test.py
@@ -22,6 +22,7 @@ def __call__(
         group: str,
         default: Dict[str, Any],
         ignore_missing: Optional[bool] = False,
+        skip_defaults: bool = False,
     ) -> Dict[str, Any]:
         return default

6 changes: 4 additions & 2 deletions torchx/util/entrypoints.py
@@ -51,8 +51,7 @@ def run(*args: object, **kwargs: object) -> object:

 # pyre-ignore-all-errors[3, 2]
 def load_group(
-    group: str,
-    default: Optional[Dict[str, Any]] = None,
+    group: str, default: Optional[Dict[str, Any]] = None, skip_defaults: bool = False
 ):
     """
     Loads all the entry points specified by ``group`` and returns
@@ -72,6 +71,7 @@ def load_group(
     1. ``load_group("foo")["bar"]("baz")`` -> equivalent to calling ``this.is.a_fn("baz")``
     1. ``load_group("food")`` -> ``None``
     1. ``load_group("food", default={"hello": this.is.c_fn})["hello"]("world")`` -> equivalent to calling ``this.is.c_fn("world")``
+    1. ``load_group("food", default={"hello": this.is.c_fn}, skip_defaults=True)`` -> ``None``


     If the entrypoint is a module (versus a function as shown above), then calling the ``deferred_load_fn``
@@ -90,6 +90,8 @@ def load_group(
     entrypoints = metadata.entry_points().select(group=group)

     if len(entrypoints) == 0:
+        if skip_defaults:
+            return None
         return default

     eps = {}
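The change to `load_group` is small but worth spelling out: when the requested group has no registered entry points, passing `skip_defaults=True` returns `None` instead of falling back to `default`, which lets callers distinguish "nothing registered" from "use the built-ins". A rough usage sketch, assuming a group name (`my.custom.group`) under which nothing is registered:

# Rough usage sketch of load_group's skip_defaults behavior (the group name is
# hypothetical and assumed to have no registered entry points).
from torchx.util.entrypoints import load_group


def builtin_factory() -> str:
    return "builtin"


defaults = {"builtin": builtin_factory}

# Falls back to the provided defaults when the group is empty.
assert load_group("my.custom.group", default=defaults) == defaults

# With skip_defaults=True an empty group yields None, so the caller can tell
# "nothing registered" apart from "defaults in use".
assert load_group("my.custom.group", default=defaults, skip_defaults=True) is None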
5 changes: 5 additions & 0 deletions torchx/util/test/entrypoints_test.py
@@ -122,6 +122,11 @@ def test_load_group_with_default(self, _: MagicMock) -> None:
         self.assertEqual("barbaz", eps["foo"]())
         self.assertEqual("foobar", eps["bar"]())

+        eps = load_group(
+            "ep.grp.test.missing", {"foo": barbaz, "bar": foobar}, skip_defaults=True
+        )
+        self.assertIsNone(eps)
+
     @patch(_METADATA_EPS, return_value=_ENTRY_POINTS)
     def test_load_group_missing(self, _: MagicMock) -> None:
         with self.assertRaises(AttributeError):