Skip to content

Commit 60883a9

Browse files
YgLKpre-commit-ci[bot]deependujhaBorda
authored
fix: remove extra parameter in accelerator registry decorator (#20975)
* fix: remove extra parameter in accelerator registry decorator * tests: add registry decorator tests * chlog --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Deependu <[email protected]> Co-authored-by: Jirka B <[email protected]>
1 parent b7ec502 commit 60883a9

File tree

3 files changed

+109
-2
lines changed

3 files changed

+109
-2
lines changed

src/lightning/fabric/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
2121
- Raise ValueError when seed is `out-of-bounds` or `cannot be cast to int` ([#21029](https://github.com/Lightning-AI/pytorch-lightning/pull/21029))
2222

2323

24+
- fix: remove extra `name` parameter in accelerator registry decorator ([#20975](https://github.com/Lightning-AI/pytorch-lightning/pull/20975))
25+
26+
2427
---
2528

2629
## [2.5.2] - 2025-3-20

src/lightning/fabric/accelerators/registry.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,14 +73,14 @@ def register(
7373
data["description"] = description
7474
data["init_params"] = init_params
7575

76-
def do_register(name: str, accelerator: Callable) -> Callable:
76+
def do_register(accelerator: Callable) -> Callable:
7777
data["accelerator"] = accelerator
7878
data["accelerator_name"] = name
7979
self[name] = data
8080
return accelerator
8181

8282
if accelerator is not None:
83-
return do_register(name, accelerator)
83+
return do_register(accelerator)
8484

8585
return do_register
8686

tests/tests_fabric/accelerators/test_registry.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,38 @@
1616
import torch
1717

1818
from lightning.fabric.accelerators import ACCELERATOR_REGISTRY, Accelerator
19+
from lightning.fabric.accelerators.registry import _AcceleratorRegistry
20+
21+
22+
class TestAccelerator(Accelerator):
23+
"""Helper accelerator class for testing."""
24+
25+
def __init__(self, param1=None, param2=None):
26+
self.param1 = param1
27+
self.param2 = param2
28+
super().__init__()
29+
30+
def setup_device(self, device: torch.device) -> None:
31+
pass
32+
33+
def teardown(self) -> None:
34+
pass
35+
36+
@staticmethod
37+
def parse_devices(devices):
38+
return devices
39+
40+
@staticmethod
41+
def get_parallel_devices(devices):
42+
return ["foo"] * devices
43+
44+
@staticmethod
45+
def auto_device_count():
46+
return 3
47+
48+
@staticmethod
49+
def is_available():
50+
return True
1951

2052

2153
def test_accelerator_registry_with_new_accelerator():
@@ -71,3 +103,75 @@ def is_available():
71103

72104
def test_available_accelerators_in_registry():
73105
assert ACCELERATOR_REGISTRY.available_accelerators() == {"cpu", "cuda", "mps", "tpu"}
106+
107+
108+
def test_registry_as_decorator():
109+
"""Test that the registry can be used as a decorator."""
110+
test_registry = _AcceleratorRegistry()
111+
112+
# Test decorator usage
113+
@test_registry.register("test_decorator", description="Test decorator accelerator", param1="value1", param2=42)
114+
class DecoratorAccelerator(TestAccelerator):
115+
pass
116+
117+
# Verify registration worked
118+
assert "test_decorator" in test_registry
119+
assert test_registry["test_decorator"]["description"] == "Test decorator accelerator"
120+
assert test_registry["test_decorator"]["init_params"] == {"param1": "value1", "param2": 42}
121+
assert test_registry["test_decorator"]["accelerator"] == DecoratorAccelerator
122+
assert test_registry["test_decorator"]["accelerator_name"] == "test_decorator"
123+
124+
# Test that we can instantiate the accelerator
125+
instance = test_registry.get("test_decorator")
126+
assert isinstance(instance, DecoratorAccelerator)
127+
assert instance.param1 == "value1"
128+
assert instance.param2 == 42
129+
130+
131+
def test_registry_as_static_method():
132+
"""Test that the registry can be used as a static method call."""
133+
test_registry = _AcceleratorRegistry()
134+
135+
class StaticMethodAccelerator(TestAccelerator):
136+
pass
137+
138+
# Test static method usage
139+
result = test_registry.register(
140+
"test_static",
141+
StaticMethodAccelerator,
142+
description="Test static method accelerator",
143+
param1="static_value",
144+
param2=100,
145+
)
146+
147+
# Verify registration worked
148+
assert "test_static" in test_registry
149+
assert test_registry["test_static"]["description"] == "Test static method accelerator"
150+
assert test_registry["test_static"]["init_params"] == {"param1": "static_value", "param2": 100}
151+
assert test_registry["test_static"]["accelerator"] == StaticMethodAccelerator
152+
assert test_registry["test_static"]["accelerator_name"] == "test_static"
153+
assert result == StaticMethodAccelerator # Should return the accelerator class
154+
155+
# Test that we can instantiate the accelerator
156+
instance = test_registry.get("test_static")
157+
assert isinstance(instance, StaticMethodAccelerator)
158+
assert instance.param1 == "static_value"
159+
assert instance.param2 == 100
160+
161+
162+
def test_registry_without_parameters():
163+
"""Test registration without init parameters."""
164+
test_registry = _AcceleratorRegistry()
165+
166+
class SimpleAccelerator(TestAccelerator):
167+
def __init__(self):
168+
super().__init__()
169+
170+
test_registry.register("simple", SimpleAccelerator, description="Simple accelerator")
171+
172+
assert "simple" in test_registry
173+
assert test_registry["simple"]["description"] == "Simple accelerator"
174+
assert test_registry["simple"]["init_params"] == {}
175+
176+
instance = test_registry.get("simple")
177+
assert isinstance(instance, SimpleAccelerator)

0 commit comments

Comments
 (0)