Skip to content

Commit 813643e

Browse files
lgarg26facebook-github-bot
authored andcommitted
allow configurable scheduler load group (#992)
Summary: Allow configurable scheduler load group for clean scheduler splits Reviewed By: jesszzzz Differential Revision: D67290464
1 parent c1a195a commit 813643e

File tree

4 files changed

+28
-9
lines changed

4 files changed

+28
-9
lines changed

torchx/runner/config.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,12 @@ def _configparser() -> configparser.ConfigParser:
197197

198198

199199
def _get_scheduler(name: str) -> Scheduler:
200-
schedulers = get_scheduler_factories()
200+
schedulers = {
201+
**get_scheduler_factories(
202+
group="torchx.schedulers.orchestrator", skip_defaults=True
203+
),
204+
**get_scheduler_factories(),
205+
}
201206
if name not in schedulers:
202207
raise ValueError(
203208
f"`{name}` is not a registered scheduler. Valid scheduler names: {schedulers.keys()}"
@@ -248,6 +253,7 @@ def dump(
248253
try:
249254
sched = _get_scheduler(sched_name)
250255
except ModuleNotFoundError:
256+
print(f"[LG] no schedulule load for {sched_name}")
251257
continue
252258

253259
section = f"{sched_name}"

torchx/runner/test/config_test.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -470,10 +470,19 @@ def test_dump_and_load_all_registered_schedulers(self) -> None:
470470
sfile = StringIO()
471471
dump(sfile)
472472

473-
for sched_name, sched in get_scheduler_factories().items():
473+
scheduler_factories = {
474+
**get_scheduler_factories(),
475+
**get_scheduler_factories(
476+
group="torchx.schedulers.orchestrator", skip_defaults=True
477+
),
478+
}
479+
480+
for sched_name, sched in scheduler_factories.items():
474481
sfile.seek(0) # reset the file pos
475482
cfg = {}
476483
load(scheduler=sched_name, f=sfile, cfg=cfg)
477-
478484
for opt_name, _ in sched("test").run_opts():
479-
self.assertTrue(opt_name in cfg)
485+
self.assertTrue(
486+
opt_name in cfg,
487+
f"missing {opt_name} in {sched} run opts with cfg {cfg}",
488+
)

torchx/schedulers/__init__.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,11 @@ def run(*args: object, **kwargs: object) -> Scheduler:
4242
return run
4343

4444

45-
def get_scheduler_factories() -> Dict[str, SchedulerFactory]:
45+
def get_scheduler_factories(
46+
group: str = "torchx.schedulers", skip_defaults: bool = False
47+
) -> Dict[str, SchedulerFactory]:
4648
"""
47-
get_scheduler_factories returns all the available schedulers names and the
49+
get_scheduler_factories returns all the available schedulers names under `group` and the
4850
method to instantiate them.
4951
5052
The first scheduler in the dictionary is used as the default scheduler.
@@ -55,8 +57,9 @@ def get_scheduler_factories() -> Dict[str, SchedulerFactory]:
5557
default_schedulers[scheduler] = _defer_load_scheduler(path)
5658

5759
return load_group(
58-
"torchx.schedulers",
60+
group,
5961
default=default_schedulers,
62+
skip_defaults=skip_defaults,
6063
)
6164

6265

torchx/util/entrypoints.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,7 @@ def run(*args: object, **kwargs: object) -> object:
5151

5252
# pyre-ignore-all-errors[3, 2]
5353
def load_group(
54-
group: str,
55-
default: Optional[Dict[str, Any]] = None,
54+
group: str, default: Optional[Dict[str, Any]] = None, skip_defaults: bool = False
5655
):
5756
"""
5857
Loads all the entry points specified by ``group`` and returns
@@ -90,6 +89,8 @@ def load_group(
9089
entrypoints = metadata.entry_points().select(group=group)
9190

9291
if len(entrypoints) == 0:
92+
if skip_defaults:
93+
return {}
9394
return default
9495

9596
eps = {}

0 commit comments

Comments
 (0)