|
16 | 16 | import torch
|
17 | 17 |
|
18 | 18 | from lightning.fabric.accelerators import ACCELERATOR_REGISTRY, Accelerator
|
| 19 | +from lightning.fabric.accelerators.registry import _AcceleratorRegistry |
| 20 | + |
| 21 | + |
| 22 | +class TestAccelerator(Accelerator): |
| 23 | + """Helper accelerator class for testing.""" |
| 24 | + |
| 25 | + def __init__(self, param1=None, param2=None): |
| 26 | + self.param1 = param1 |
| 27 | + self.param2 = param2 |
| 28 | + super().__init__() |
| 29 | + |
| 30 | + def setup_device(self, device: torch.device) -> None: |
| 31 | + pass |
| 32 | + |
| 33 | + def teardown(self) -> None: |
| 34 | + pass |
| 35 | + |
| 36 | + @staticmethod |
| 37 | + def parse_devices(devices): |
| 38 | + return devices |
| 39 | + |
| 40 | + @staticmethod |
| 41 | + def get_parallel_devices(devices): |
| 42 | + return ["foo"] * devices |
| 43 | + |
| 44 | + @staticmethod |
| 45 | + def auto_device_count(): |
| 46 | + return 3 |
| 47 | + |
| 48 | + @staticmethod |
| 49 | + def is_available(): |
| 50 | + return True |
19 | 51 |
|
20 | 52 |
|
21 | 53 | def test_accelerator_registry_with_new_accelerator():
|
@@ -71,3 +103,75 @@ def is_available():
|
71 | 103 |
|
72 | 104 | def test_available_accelerators_in_registry():
|
73 | 105 | assert ACCELERATOR_REGISTRY.available_accelerators() == {"cpu", "cuda", "mps", "tpu"}
|
| 106 | + |
| 107 | + |
| 108 | +def test_registry_as_decorator(): |
| 109 | + """Test that the registry can be used as a decorator.""" |
| 110 | + test_registry = _AcceleratorRegistry() |
| 111 | + |
| 112 | + # Test decorator usage |
| 113 | + @test_registry.register("test_decorator", description="Test decorator accelerator", param1="value1", param2=42) |
| 114 | + class DecoratorAccelerator(TestAccelerator): |
| 115 | + pass |
| 116 | + |
| 117 | + # Verify registration worked |
| 118 | + assert "test_decorator" in test_registry |
| 119 | + assert test_registry["test_decorator"]["description"] == "Test decorator accelerator" |
| 120 | + assert test_registry["test_decorator"]["init_params"] == {"param1": "value1", "param2": 42} |
| 121 | + assert test_registry["test_decorator"]["accelerator"] == DecoratorAccelerator |
| 122 | + assert test_registry["test_decorator"]["accelerator_name"] == "test_decorator" |
| 123 | + |
| 124 | + # Test that we can instantiate the accelerator |
| 125 | + instance = test_registry.get("test_decorator") |
| 126 | + assert isinstance(instance, DecoratorAccelerator) |
| 127 | + assert instance.param1 == "value1" |
| 128 | + assert instance.param2 == 42 |
| 129 | + |
| 130 | + |
| 131 | +def test_registry_as_static_method(): |
| 132 | + """Test that the registry can be used as a static method call.""" |
| 133 | + test_registry = _AcceleratorRegistry() |
| 134 | + |
| 135 | + class StaticMethodAccelerator(TestAccelerator): |
| 136 | + pass |
| 137 | + |
| 138 | + # Test static method usage |
| 139 | + result = test_registry.register( |
| 140 | + "test_static", |
| 141 | + StaticMethodAccelerator, |
| 142 | + description="Test static method accelerator", |
| 143 | + param1="static_value", |
| 144 | + param2=100, |
| 145 | + ) |
| 146 | + |
| 147 | + # Verify registration worked |
| 148 | + assert "test_static" in test_registry |
| 149 | + assert test_registry["test_static"]["description"] == "Test static method accelerator" |
| 150 | + assert test_registry["test_static"]["init_params"] == {"param1": "static_value", "param2": 100} |
| 151 | + assert test_registry["test_static"]["accelerator"] == StaticMethodAccelerator |
| 152 | + assert test_registry["test_static"]["accelerator_name"] == "test_static" |
| 153 | + assert result == StaticMethodAccelerator # Should return the accelerator class |
| 154 | + |
| 155 | + # Test that we can instantiate the accelerator |
| 156 | + instance = test_registry.get("test_static") |
| 157 | + assert isinstance(instance, StaticMethodAccelerator) |
| 158 | + assert instance.param1 == "static_value" |
| 159 | + assert instance.param2 == 100 |
| 160 | + |
| 161 | + |
| 162 | +def test_registry_without_parameters(): |
| 163 | + """Test registration without init parameters.""" |
| 164 | + test_registry = _AcceleratorRegistry() |
| 165 | + |
| 166 | + class SimpleAccelerator(TestAccelerator): |
| 167 | + def __init__(self): |
| 168 | + super().__init__() |
| 169 | + |
| 170 | + test_registry.register("simple", SimpleAccelerator, description="Simple accelerator") |
| 171 | + |
| 172 | + assert "simple" in test_registry |
| 173 | + assert test_registry["simple"]["description"] == "Simple accelerator" |
| 174 | + assert test_registry["simple"]["init_params"] == {} |
| 175 | + |
| 176 | + instance = test_registry.get("simple") |
| 177 | + assert isinstance(instance, SimpleAccelerator) |
0 commit comments