diff --git a/src/lightning/fabric/accelerators/registry.py b/src/lightning/fabric/accelerators/registry.py index 17d5233336d50..4959a0fb9426a 100644 --- a/src/lightning/fabric/accelerators/registry.py +++ b/src/lightning/fabric/accelerators/registry.py @@ -107,9 +107,9 @@ def remove(self, name: str) -> None: """Removes the registered accelerator by name.""" self.pop(name) - def available_accelerators(self) -> list[str]: - """Returns a list of registered accelerators.""" - return list(self.keys()) + def available_accelerators(self) -> set[str]: + """Returns a set of registered accelerators.""" + return set(self.keys()) def __str__(self) -> str: return "Registered Accelerators: {}".format(", ".join(self.available_accelerators())) diff --git a/tests/tests_fabric/accelerators/test_registry.py b/tests/tests_fabric/accelerators/test_registry.py index 28bfbb8ffd97c..8036a6f45b8a0 100644 --- a/tests/tests_fabric/accelerators/test_registry.py +++ b/tests/tests_fabric/accelerators/test_registry.py @@ -70,4 +70,4 @@ def is_available(): def test_available_accelerators_in_registry(): - assert ACCELERATOR_REGISTRY.available_accelerators() == ["cpu", "cuda", "mps", "tpu"] + assert ACCELERATOR_REGISTRY.available_accelerators() == {"cpu", "cuda", "mps", "tpu"} diff --git a/tests/tests_pytorch/accelerators/test_registry.py b/tests/tests_pytorch/accelerators/test_registry.py index 1c4358fea9696..8b29c9e937247 100644 --- a/tests/tests_pytorch/accelerators/test_registry.py +++ b/tests/tests_pytorch/accelerators/test_registry.py @@ -16,7 +16,7 @@ def test_available_accelerators_in_registry(): """Tests the accelerators available by default, not including external, third-party accelerators.""" - available = set(AcceleratorRegistry.available_accelerators()) + available = AcceleratorRegistry.available_accelerators() expected = {"cpu", "cuda", "mps", "tpu"} # Note: the registry is global, other tests may register new strategies as a side effect assert expected.issubset(available)