Skip to content

Commit 57ec248

Browse files
committed
feat: list all registered schedulers (#1009)
1 parent 3cb18e3 commit 57ec248

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

torchx/schedulers/__init__.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -51,16 +51,16 @@ def get_scheduler_factories(
5151
5252
The first scheduler in the dictionary is used as the default scheduler.
5353
"""
54+
valid_schedulers: Dict[str, SchedulerFactory] = {}
55+
if not skip_defaults:
56+
for scheduler_name, path in DEFAULT_SCHEDULER_MODULES.items():
57+
valid_schedulers[scheduler_name] = _defer_load_scheduler(path)
5458

55-
default_schedulers: Dict[str, SchedulerFactory] = {}
56-
for scheduler, path in DEFAULT_SCHEDULER_MODULES.items():
57-
default_schedulers[scheduler] = _defer_load_scheduler(path)
59+
entry_point_schedulers = load_group(group, default=None, skip_defaults=True)
60+
if entry_point_schedulers:
61+
valid_schedulers.update(entry_point_schedulers)
5862

59-
return load_group(
60-
group,
61-
default=default_schedulers,
62-
skip_defaults=skip_defaults,
63-
)
63+
return valid_schedulers
6464

6565

6666
def get_default_scheduler_name() -> str:

0 commit comments

Comments
 (0)