Skip to content

Commit cfa412b

Browse files
committed
feat: list all registered schedulers (#1009)
1 parent 5957532 commit cfa412b

File tree

2 files changed

+103
-9
lines changed

2 files changed

+103
-9
lines changed

torchx/schedulers/__init__.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -48,15 +48,12 @@ def get_scheduler_factories(
4848
4949
The first scheduler in the dictionary is used as the default scheduler.
5050
"""
51-
52-
if skip_defaults:
53-
default_schedulers = {}
54-
else:
55-
default_schedulers: dict[str, SchedulerFactory] = {}
56-
for scheduler, path in DEFAULT_SCHEDULER_MODULES.items():
57-
default_schedulers[scheduler] = _defer_load_scheduler(path)
58-
59-
return load_group(group, default=default_schedulers)
51+
schedulers = load_group(group, default={})
52+
if not skip_defaults:
53+
for name, path in DEFAULT_SCHEDULER_MODULES.items():
54+
if name not in schedulers:
55+
schedulers[name] = _defer_load_scheduler(path)
56+
return schedulers
6057

6158

6259
def get_default_scheduler_name() -> str:

torchx/schedulers/test/registry_test.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,3 +43,100 @@ def test_get_local_schedulers(self, mock_load_group: MagicMock) -> None:
4343

4444
for scheduler in schedulers.values():
4545
self.assertEqual("test_session", scheduler.session_name)
46+
47+
@patch("torchx.schedulers.load_group")
48+
def test_custom_schedulers_merged(self, mock_load_group: MagicMock) -> None:
49+
mock_scheduler = MagicMock()
50+
mock_load_group.return_value = {"custom": mock_scheduler}
51+
52+
factories = get_scheduler_factories()
53+
54+
self.assertIn("custom", factories)
55+
self.assertEqual(factories["custom"], mock_scheduler)
56+
self.assertIn("local_docker", factories)
57+
58+
@patch("torchx.schedulers.load_group")
59+
def test_custom_scheduler_overrides_default(
60+
self, mock_load_group: MagicMock
61+
) -> None:
62+
mock_scheduler = MagicMock()
63+
mock_load_group.return_value = {"local_docker": mock_scheduler}
64+
65+
factories = get_scheduler_factories()
66+
67+
self.assertEqual(factories["local_docker"], mock_scheduler)
68+
69+
@patch("torchx.schedulers.load_group")
70+
def test_skip_defaults_with_custom_schedulers(
71+
self, mock_load_group: MagicMock
72+
) -> None:
73+
mock_scheduler = MagicMock()
74+
mock_load_group.return_value = {"custom": mock_scheduler}
75+
76+
factories = get_scheduler_factories(skip_defaults=True)
77+
78+
self.assertEqual(factories, {"custom": mock_scheduler})
79+
self.assertNotIn("local_docker", factories)
80+
81+
@patch("torchx.schedulers.load_group")
82+
def test_with_custom_schedulers_skip_defaults_false(
83+
self, mock_load_group: MagicMock
84+
) -> None:
85+
"""with custom schedulers, skip_defaults=False returns both"""
86+
mock_aws = MagicMock()
87+
mock_custom = MagicMock()
88+
mock_load_group.return_value = {"aws_batch": mock_aws, "custom_1": mock_custom}
89+
90+
factories = get_scheduler_factories(skip_defaults=False)
91+
92+
self.assertIn("aws_batch", factories)
93+
self.assertIn("custom_1", factories)
94+
self.assertIn("local_docker", factories)
95+
self.assertIn("slurm", factories)
96+
97+
@patch("torchx.schedulers.load_group")
98+
def test_with_custom_schedulers_skip_defaults_true(
99+
self, mock_load_group: MagicMock
100+
) -> None:
101+
"""with custom schedulers, skip_defaults=True returns only custom"""
102+
mock_aws = MagicMock()
103+
mock_custom = MagicMock()
104+
mock_load_group.return_value = {"aws_batch": mock_aws, "custom_1": mock_custom}
105+
106+
factories = get_scheduler_factories(skip_defaults=True)
107+
108+
self.assertEqual(set(factories.keys()), {"aws_batch", "custom_1"})
109+
110+
@patch("torchx.schedulers.load_group")
111+
def test_no_custom_schedulers_skip_defaults_false(
112+
self, mock_load_group: MagicMock
113+
) -> None:
114+
"""no custom schedulers, skip_defaults=False returns defaults"""
115+
mock_load_group.return_value = {}
116+
117+
factories = get_scheduler_factories(skip_defaults=False)
118+
119+
self.assertIn("local_docker", factories)
120+
self.assertIn("slurm", factories)
121+
122+
@patch("torchx.schedulers.load_group")
123+
def test_no_custom_schedulers_skip_defaults_true(
124+
self, mock_load_group: MagicMock
125+
) -> None:
126+
"""no custom schedulers, skip_defaults=True returns empty"""
127+
mock_load_group.return_value = {}
128+
129+
factories = get_scheduler_factories(skip_defaults=True)
130+
131+
self.assertEqual(factories, {})
132+
133+
@patch("torchx.schedulers.load_group")
134+
def test_custom_scheduler_is_default(self, mock_load_group: MagicMock) -> None:
135+
"""first custom scheduler becomes the default"""
136+
mock_aws = MagicMock()
137+
mock_custom = MagicMock()
138+
mock_load_group.return_value = {"aws_batch": mock_aws, "custom_1": mock_custom}
139+
140+
default_name = get_default_scheduler_name()
141+
142+
self.assertIn(default_name, ["aws_batch", "custom_1"])

0 commit comments

Comments
 (0)