Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 6 additions & 9 deletions torchx/schedulers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,12 @@ def get_scheduler_factories(

The first scheduler in the dictionary is used as the default scheduler.
"""

if skip_defaults:
default_schedulers = {}
else:
default_schedulers: dict[str, SchedulerFactory] = {}
for scheduler, path in DEFAULT_SCHEDULER_MODULES.items():
default_schedulers[scheduler] = _defer_load_scheduler(path)

return load_group(group, default=default_schedulers)
schedulers = load_group(group, default={})
if not skip_defaults:
for name, path in DEFAULT_SCHEDULER_MODULES.items():
if name not in schedulers:
schedulers[name] = _defer_load_scheduler(path)
return schedulers
Comment on lines +51 to +56
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for addressing one of the BC incompatibilities. But I think the other one still remains. get_scheduler_factories(skip_defaults=False) now returns BOTH my custom [torchx.schedulers] mappings AND the DEFAULT_SCHEDULER_MODULES which is not ideal for users who expect ONLY my custom schedulers to show. This is important since the default schedulers may not even be available for the users to use.

Unlike named_resources, scheduler_factories is not additive. We should've kept the same behavior in both (either additive or not).

skip_defaults argument in get_scheduler_factories() means "even if you don't find anything in entry-points, do not return schedulers in DEFAULT_SCHEDULER_MODULES. I think the confusion is that DEFAULT_SCHEDULER_MODULES is a badly worded constant. What it "should've been called are BUILTIN_SCHEDULER_MODULES (aka the ones that torchx has "builtin" support for) and skip_defaults should've been called skip_builtins.

If the motivation here is for the CLI to expose custom + builtins such that builtins is always in-sync with any new schedulers added to torchx, we are currently in a mode where we are retiring (not adding) schedulers. The goal being better support/integration rather than wider coverage.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The goal of this PR is to be able to allow registering custom schedulers next to built-in ones. E.g. if NeMo wants to register some new ones or even overwrite some of them - the rest should be available, e.g. local_cwd. It's not practical to ask the user to adding all built-ins manually - we are not even sure in what order these will be resolved, e.g. both NeMo and some other package registering their own. The simplest idea was to allow additive behavior if we are not using the same names. I guess we can preserve the existing "clear+add" approach as well @kiukchung

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about torchx.schedulers.extra or torchx.schedulers.defaults as an alternative entry-point for additive behavior? @kiukchung
This keeps the existing behavior and makes it possible to add on top of builtins.
Please let me know if you want me to rename default to builtins (and where) - I can include this.



def get_default_scheduler_name() -> str:
Expand Down
97 changes: 97 additions & 0 deletions torchx/schedulers/test/registry_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,100 @@ def test_get_local_schedulers(self, mock_load_group: MagicMock) -> None:

for scheduler in schedulers.values():
self.assertEqual("test_session", scheduler.session_name)

@patch("torchx.schedulers.load_group")
def test_custom_schedulers_merged(self, mock_load_group: MagicMock) -> None:
mock_scheduler = MagicMock()
mock_load_group.return_value = {"custom": mock_scheduler}

factories = get_scheduler_factories()

self.assertIn("custom", factories)
self.assertEqual(factories["custom"], mock_scheduler)
self.assertIn("local_docker", factories)

@patch("torchx.schedulers.load_group")
def test_custom_scheduler_overrides_default(
self, mock_load_group: MagicMock
) -> None:
mock_scheduler = MagicMock()
mock_load_group.return_value = {"local_docker": mock_scheduler}

factories = get_scheduler_factories()

self.assertEqual(factories["local_docker"], mock_scheduler)

@patch("torchx.schedulers.load_group")
def test_skip_defaults_with_custom_schedulers(
self, mock_load_group: MagicMock
) -> None:
mock_scheduler = MagicMock()
mock_load_group.return_value = {"custom": mock_scheduler}

factories = get_scheduler_factories(skip_defaults=True)

self.assertEqual(factories, {"custom": mock_scheduler})
self.assertNotIn("local_docker", factories)

@patch("torchx.schedulers.load_group")
def test_with_custom_schedulers_skip_defaults_false(
self, mock_load_group: MagicMock
) -> None:
"""with custom schedulers, skip_defaults=False returns both"""
mock_aws = MagicMock()
mock_custom = MagicMock()
mock_load_group.return_value = {"aws_batch": mock_aws, "custom_1": mock_custom}

factories = get_scheduler_factories(skip_defaults=False)

self.assertIn("aws_batch", factories)
self.assertIn("custom_1", factories)
self.assertIn("local_docker", factories)
self.assertIn("slurm", factories)

@patch("torchx.schedulers.load_group")
def test_with_custom_schedulers_skip_defaults_true(
self, mock_load_group: MagicMock
) -> None:
"""with custom schedulers, skip_defaults=True returns only custom"""
mock_aws = MagicMock()
mock_custom = MagicMock()
mock_load_group.return_value = {"aws_batch": mock_aws, "custom_1": mock_custom}

factories = get_scheduler_factories(skip_defaults=True)

self.assertEqual(set(factories.keys()), {"aws_batch", "custom_1"})

@patch("torchx.schedulers.load_group")
def test_no_custom_schedulers_skip_defaults_false(
self, mock_load_group: MagicMock
) -> None:
"""no custom schedulers, skip_defaults=False returns defaults"""
mock_load_group.return_value = {}

factories = get_scheduler_factories(skip_defaults=False)

self.assertIn("local_docker", factories)
self.assertIn("slurm", factories)

@patch("torchx.schedulers.load_group")
def test_no_custom_schedulers_skip_defaults_true(
self, mock_load_group: MagicMock
) -> None:
"""no custom schedulers, skip_defaults=True returns empty"""
mock_load_group.return_value = {}

factories = get_scheduler_factories(skip_defaults=True)

self.assertEqual(factories, {})

@patch("torchx.schedulers.load_group")
def test_custom_scheduler_is_default(self, mock_load_group: MagicMock) -> None:
"""first custom scheduler becomes the default"""
mock_aws = MagicMock()
mock_custom = MagicMock()
mock_load_group.return_value = {"aws_batch": mock_aws, "custom_1": mock_custom}

default_name = get_default_scheduler_name()

self.assertIn(default_name, ["aws_batch", "custom_1"])
Loading