Skip to content

fix: remove extra parameter in accelerator registry decorator #20975

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/lightning/fabric/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Raise ValueError when seed is `out-of-bounds` or `cannot be cast to int` ([#21029](https://github.com/Lightning-AI/pytorch-lightning/pull/21029))


- fix: remove extra `name` parameter in accelerator registry decorator ([#20975](https://github.com/Lightning-AI/pytorch-lightning/pull/20975))


---

## [2.5.2] - 2025-3-20
Expand Down
4 changes: 2 additions & 2 deletions src/lightning/fabric/accelerators/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,14 +73,14 @@ def register(
data["description"] = description
data["init_params"] = init_params

def do_register(name: str, accelerator: Callable) -> Callable:
def do_register(accelerator: Callable) -> Callable:
data["accelerator"] = accelerator
data["accelerator_name"] = name
self[name] = data
return accelerator

if accelerator is not None:
return do_register(name, accelerator)
return do_register(accelerator)

return do_register

Expand Down
104 changes: 104 additions & 0 deletions tests/tests_fabric/accelerators/test_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,38 @@
import torch

from lightning.fabric.accelerators import ACCELERATOR_REGISTRY, Accelerator
from lightning.fabric.accelerators.registry import _AcceleratorRegistry


class TestAccelerator(Accelerator):
"""Helper accelerator class for testing."""

def __init__(self, param1=None, param2=None):
self.param1 = param1
self.param2 = param2
super().__init__()

def setup_device(self, device: torch.device) -> None:
pass

def teardown(self) -> None:
pass

@staticmethod
def parse_devices(devices):
return devices

@staticmethod
def get_parallel_devices(devices):
return ["foo"] * devices

@staticmethod
def auto_device_count():
return 3

@staticmethod
def is_available():
return True


def test_accelerator_registry_with_new_accelerator():
Expand Down Expand Up @@ -71,3 +103,75 @@ def is_available():

def test_available_accelerators_in_registry():
assert ACCELERATOR_REGISTRY.available_accelerators() == {"cpu", "cuda", "mps", "tpu"}


def test_registry_as_decorator():
"""Test that the registry can be used as a decorator."""
test_registry = _AcceleratorRegistry()

# Test decorator usage
@test_registry.register("test_decorator", description="Test decorator accelerator", param1="value1", param2=42)
class DecoratorAccelerator(TestAccelerator):
pass

# Verify registration worked
assert "test_decorator" in test_registry
assert test_registry["test_decorator"]["description"] == "Test decorator accelerator"
assert test_registry["test_decorator"]["init_params"] == {"param1": "value1", "param2": 42}
assert test_registry["test_decorator"]["accelerator"] == DecoratorAccelerator
assert test_registry["test_decorator"]["accelerator_name"] == "test_decorator"

# Test that we can instantiate the accelerator
instance = test_registry.get("test_decorator")
assert isinstance(instance, DecoratorAccelerator)
assert instance.param1 == "value1"
assert instance.param2 == 42


def test_registry_as_static_method():
"""Test that the registry can be used as a static method call."""
test_registry = _AcceleratorRegistry()

class StaticMethodAccelerator(TestAccelerator):
pass

# Test static method usage
result = test_registry.register(
"test_static",
StaticMethodAccelerator,
description="Test static method accelerator",
param1="static_value",
param2=100,
)

# Verify registration worked
assert "test_static" in test_registry
assert test_registry["test_static"]["description"] == "Test static method accelerator"
assert test_registry["test_static"]["init_params"] == {"param1": "static_value", "param2": 100}
assert test_registry["test_static"]["accelerator"] == StaticMethodAccelerator
assert test_registry["test_static"]["accelerator_name"] == "test_static"
assert result == StaticMethodAccelerator # Should return the accelerator class

# Test that we can instantiate the accelerator
instance = test_registry.get("test_static")
assert isinstance(instance, StaticMethodAccelerator)
assert instance.param1 == "static_value"
assert instance.param2 == 100


def test_registry_without_parameters():
"""Test registration without init parameters."""
test_registry = _AcceleratorRegistry()

class SimpleAccelerator(TestAccelerator):
def __init__(self):
super().__init__()

test_registry.register("simple", SimpleAccelerator, description="Simple accelerator")

assert "simple" in test_registry
assert test_registry["simple"]["description"] == "Simple accelerator"
assert test_registry["simple"]["init_params"] == {}

instance = test_registry.get("simple")
assert isinstance(instance, SimpleAccelerator)
Loading