diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md index 18537ca15e2fc..2b1d651d1c027 100644 --- a/src/lightning/fabric/CHANGELOG.md +++ b/src/lightning/fabric/CHANGELOG.md @@ -21,6 +21,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Raise ValueError when seed is `out-of-bounds` or `cannot be cast to int` ([#21029](https://github.com/Lightning-AI/pytorch-lightning/pull/21029)) +- fix: remove extra `name` parameter in accelerator registry decorator ([#20975](https://github.com/Lightning-AI/pytorch-lightning/pull/20975)) + + --- ## [2.5.2] - 2025-3-20 diff --git a/src/lightning/fabric/accelerators/registry.py b/src/lightning/fabric/accelerators/registry.py index 4959a0fb9426a..539b7aa8a01dc 100644 --- a/src/lightning/fabric/accelerators/registry.py +++ b/src/lightning/fabric/accelerators/registry.py @@ -73,14 +73,14 @@ def register( data["description"] = description data["init_params"] = init_params - def do_register(name: str, accelerator: Callable) -> Callable: + def do_register(accelerator: Callable) -> Callable: data["accelerator"] = accelerator data["accelerator_name"] = name self[name] = data return accelerator if accelerator is not None: - return do_register(name, accelerator) + return do_register(accelerator) return do_register diff --git a/tests/tests_fabric/accelerators/test_registry.py b/tests/tests_fabric/accelerators/test_registry.py index 8036a6f45b8a0..b88ecf1db1e57 100644 --- a/tests/tests_fabric/accelerators/test_registry.py +++ b/tests/tests_fabric/accelerators/test_registry.py @@ -16,6 +16,38 @@ import torch from lightning.fabric.accelerators import ACCELERATOR_REGISTRY, Accelerator +from lightning.fabric.accelerators.registry import _AcceleratorRegistry + + +class TestAccelerator(Accelerator): + """Helper accelerator class for testing.""" + + def __init__(self, param1=None, param2=None): + self.param1 = param1 + self.param2 = param2 + super().__init__() + + def setup_device(self, device: torch.device) -> None: + pass + + def teardown(self) -> None: + pass + + @staticmethod + def parse_devices(devices): + return devices + + @staticmethod + def get_parallel_devices(devices): + return ["foo"] * devices + + @staticmethod + def auto_device_count(): + return 3 + + @staticmethod + def is_available(): + return True def test_accelerator_registry_with_new_accelerator(): @@ -71,3 +103,75 @@ def is_available(): def test_available_accelerators_in_registry(): assert ACCELERATOR_REGISTRY.available_accelerators() == {"cpu", "cuda", "mps", "tpu"} + + +def test_registry_as_decorator(): + """Test that the registry can be used as a decorator.""" + test_registry = _AcceleratorRegistry() + + # Test decorator usage + @test_registry.register("test_decorator", description="Test decorator accelerator", param1="value1", param2=42) + class DecoratorAccelerator(TestAccelerator): + pass + + # Verify registration worked + assert "test_decorator" in test_registry + assert test_registry["test_decorator"]["description"] == "Test decorator accelerator" + assert test_registry["test_decorator"]["init_params"] == {"param1": "value1", "param2": 42} + assert test_registry["test_decorator"]["accelerator"] == DecoratorAccelerator + assert test_registry["test_decorator"]["accelerator_name"] == "test_decorator" + + # Test that we can instantiate the accelerator + instance = test_registry.get("test_decorator") + assert isinstance(instance, DecoratorAccelerator) + assert instance.param1 == "value1" + assert instance.param2 == 42 + + +def test_registry_as_static_method(): + """Test that the registry can be used as a static method call.""" + test_registry = _AcceleratorRegistry() + + class StaticMethodAccelerator(TestAccelerator): + pass + + # Test static method usage + result = test_registry.register( + "test_static", + StaticMethodAccelerator, + description="Test static method accelerator", + param1="static_value", + param2=100, + ) + + # Verify registration worked + assert "test_static" in test_registry + assert test_registry["test_static"]["description"] == "Test static method accelerator" + assert test_registry["test_static"]["init_params"] == {"param1": "static_value", "param2": 100} + assert test_registry["test_static"]["accelerator"] == StaticMethodAccelerator + assert test_registry["test_static"]["accelerator_name"] == "test_static" + assert result == StaticMethodAccelerator # Should return the accelerator class + + # Test that we can instantiate the accelerator + instance = test_registry.get("test_static") + assert isinstance(instance, StaticMethodAccelerator) + assert instance.param1 == "static_value" + assert instance.param2 == 100 + + +def test_registry_without_parameters(): + """Test registration without init parameters.""" + test_registry = _AcceleratorRegistry() + + class SimpleAccelerator(TestAccelerator): + def __init__(self): + super().__init__() + + test_registry.register("simple", SimpleAccelerator, description="Simple accelerator") + + assert "simple" in test_registry + assert test_registry["simple"]["description"] == "Simple accelerator" + assert test_registry["simple"]["init_params"] == {} + + instance = test_registry.get("simple") + assert isinstance(instance, SimpleAccelerator)