From 05924ebb5359e02b0cbfbb174f82db8763a53c1b Mon Sep 17 00:00:00 2001 From: Alexander Zhipa Date: Wed, 23 Apr 2025 11:24:38 -0400 Subject: [PATCH] feat: list all registered schedulers (#1009) --- torchx/schedulers/__init__.py | 15 ++-- torchx/schedulers/test/registry_test.py | 97 +++++++++++++++++++++++++ 2 files changed, 103 insertions(+), 9 deletions(-) diff --git a/torchx/schedulers/__init__.py b/torchx/schedulers/__init__.py index aa773ea54..cc1f4c962 100644 --- a/torchx/schedulers/__init__.py +++ b/torchx/schedulers/__init__.py @@ -48,15 +48,12 @@ def get_scheduler_factories( The first scheduler in the dictionary is used as the default scheduler. """ - - if skip_defaults: - default_schedulers = {} - else: - default_schedulers: dict[str, SchedulerFactory] = {} - for scheduler, path in DEFAULT_SCHEDULER_MODULES.items(): - default_schedulers[scheduler] = _defer_load_scheduler(path) - - return load_group(group, default=default_schedulers) + schedulers = load_group(group, default={}) + if not skip_defaults: + for name, path in DEFAULT_SCHEDULER_MODULES.items(): + if name not in schedulers: + schedulers[name] = _defer_load_scheduler(path) + return schedulers def get_default_scheduler_name() -> str: diff --git a/torchx/schedulers/test/registry_test.py b/torchx/schedulers/test/registry_test.py index e133aafcf..fd397b9b9 100644 --- a/torchx/schedulers/test/registry_test.py +++ b/torchx/schedulers/test/registry_test.py @@ -43,3 +43,100 @@ def test_get_local_schedulers(self, mock_load_group: MagicMock) -> None: for scheduler in schedulers.values(): self.assertEqual("test_session", scheduler.session_name) + + @patch("torchx.schedulers.load_group") + def test_custom_schedulers_merged(self, mock_load_group: MagicMock) -> None: + mock_scheduler = MagicMock() + mock_load_group.return_value = {"custom": mock_scheduler} + + factories = get_scheduler_factories() + + self.assertIn("custom", factories) + self.assertEqual(factories["custom"], mock_scheduler) + self.assertIn("local_docker", factories) + + @patch("torchx.schedulers.load_group") + def test_custom_scheduler_overrides_default( + self, mock_load_group: MagicMock + ) -> None: + mock_scheduler = MagicMock() + mock_load_group.return_value = {"local_docker": mock_scheduler} + + factories = get_scheduler_factories() + + self.assertEqual(factories["local_docker"], mock_scheduler) + + @patch("torchx.schedulers.load_group") + def test_skip_defaults_with_custom_schedulers( + self, mock_load_group: MagicMock + ) -> None: + mock_scheduler = MagicMock() + mock_load_group.return_value = {"custom": mock_scheduler} + + factories = get_scheduler_factories(skip_defaults=True) + + self.assertEqual(factories, {"custom": mock_scheduler}) + self.assertNotIn("local_docker", factories) + + @patch("torchx.schedulers.load_group") + def test_with_custom_schedulers_skip_defaults_false( + self, mock_load_group: MagicMock + ) -> None: + """with custom schedulers, skip_defaults=False returns both""" + mock_aws = MagicMock() + mock_custom = MagicMock() + mock_load_group.return_value = {"aws_batch": mock_aws, "custom_1": mock_custom} + + factories = get_scheduler_factories(skip_defaults=False) + + self.assertIn("aws_batch", factories) + self.assertIn("custom_1", factories) + self.assertIn("local_docker", factories) + self.assertIn("slurm", factories) + + @patch("torchx.schedulers.load_group") + def test_with_custom_schedulers_skip_defaults_true( + self, mock_load_group: MagicMock + ) -> None: + """with custom schedulers, skip_defaults=True returns only custom""" + mock_aws = MagicMock() + mock_custom = MagicMock() + mock_load_group.return_value = {"aws_batch": mock_aws, "custom_1": mock_custom} + + factories = get_scheduler_factories(skip_defaults=True) + + self.assertEqual(set(factories.keys()), {"aws_batch", "custom_1"}) + + @patch("torchx.schedulers.load_group") + def test_no_custom_schedulers_skip_defaults_false( + self, mock_load_group: MagicMock + ) -> None: + """no custom schedulers, skip_defaults=False returns defaults""" + mock_load_group.return_value = {} + + factories = get_scheduler_factories(skip_defaults=False) + + self.assertIn("local_docker", factories) + self.assertIn("slurm", factories) + + @patch("torchx.schedulers.load_group") + def test_no_custom_schedulers_skip_defaults_true( + self, mock_load_group: MagicMock + ) -> None: + """no custom schedulers, skip_defaults=True returns empty""" + mock_load_group.return_value = {} + + factories = get_scheduler_factories(skip_defaults=True) + + self.assertEqual(factories, {}) + + @patch("torchx.schedulers.load_group") + def test_custom_scheduler_is_default(self, mock_load_group: MagicMock) -> None: + """first custom scheduler becomes the default""" + mock_aws = MagicMock() + mock_custom = MagicMock() + mock_load_group.return_value = {"aws_batch": mock_aws, "custom_1": mock_custom} + + default_name = get_default_scheduler_name() + + self.assertIn(default_name, ["aws_batch", "custom_1"])