Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit 178cbf5

Browse files
dbogunowiczbfineran
authored andcommitted
Allow passing multiple name for registering a value in RegistryMixin (#385)
1 parent f269c02 commit 178cbf5

File tree

2 files changed

+37
-14
lines changed

2 files changed

+37
-14
lines changed

src/sparsezoo/utils/registry.py

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
import importlib
2121
from collections import defaultdict
22-
from typing import Any, Dict, List, Optional, Type
22+
from typing import Any, Dict, List, Optional, Type, Union
2323

2424

2525
__all__ = [
@@ -64,6 +64,11 @@ class ImageNetDataset(Dataset):
6464
class Cifar(Dataset):
6565
pass
6666
67+
# register with multiple aliases
68+
@Dataset.register(name=["cifar-10-dataset", "cifar-100-dataset"])
69+
class Cifar(Dataset):
70+
pass
71+
6772
# load as "cifar-dataset"
6873
cifar = Dataset.load_from_registry("cifar-dataset")
6974
@@ -77,12 +82,13 @@ class Cifar(Dataset):
7782
registry_requires_subclass: bool = False
7883

7984
@classmethod
80-
def register(cls, name: Optional[str] = None):
85+
def register(cls, name: Union[List[str], str, None] = None):
8186
"""
8287
Decorator for registering a value (ie class or function) wrapped by this
8388
decorator to the base class (class that .register is called from)
8489
85-
:param name: name to register the wrapped value as, defaults to value.__name__
90+
:param name: name or list of names to register the wrapped value as,
91+
defaults to value.__name__
8692
:return: register decorator
8793
"""
8894

@@ -93,18 +99,22 @@ def decorator(value: Any):
9399
return decorator
94100

95101
@classmethod
96-
def register_value(cls, value: Any, name: Optional[str] = None):
102+
def register_value(cls, value: Any, name: Union[List[str], str, None] = None):
97103
"""
98104
Registers the given value to the class `.register_value` is called from
99105
:param value: value to register
100-
:param name: name to register the wrapped value as, defaults to value.__name__
106+
:param name: name or list of names to register the wrapped value as,
107+
defaults to value.__name__
101108
"""
102-
register(
103-
parent_class=cls,
104-
value=value,
105-
name=name,
106-
require_subclass=cls.registry_requires_subclass,
107-
)
109+
names = name if isinstance(name, list) else [name]
110+
111+
for name in names:
112+
register(
113+
parent_class=cls,
114+
value=value,
115+
name=name,
116+
require_subclass=cls.registry_requires_subclass,
117+
)
108118

109119
@classmethod
110120
def load_from_registry(cls, name: str, **constructor_kwargs) -> object:
@@ -148,7 +158,7 @@ def register(
148158
):
149159
"""
150160
:param parent_class: class to register the name under
151-
:param value: value to register
161+
:param value: the value to register
152162
:param name: name to register the wrapped value as, defaults to value.__name__
153163
:param require_subclass: require that value is a subclass of the class this
154164
method is called from
@@ -193,7 +203,7 @@ def get_from_registry(
193203
# look up name in registry
194204
retrieved_value = _REGISTRY[parent_class].get(name)
195205
if retrieved_value is None:
196-
raise ValueError(
206+
raise KeyError(
197207
f"Unable to find {name} registered under type {parent_class}. "
198208
f"Registered values for {parent_class}: "
199209
f"{registered_names(parent_class)}"

tests/sparsezoo/utils/test_registry.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,17 +25,30 @@ class Foo(RegistryMixin):
2525
class Foo1(Foo):
2626
pass
2727

28+
assert {"Foo1"} == set(Foo.registered_names())
29+
2830
@Foo.register(name="name_2")
2931
class Foo2(Foo):
3032
pass
3133

3234
assert {"Foo1", "name_2"} == set(Foo.registered_names())
3335

34-
with pytest.raises(ValueError):
36+
@Foo.register(name=["name_3", "name_4"])
37+
class Foo3(Foo):
38+
pass
39+
40+
assert {"Foo1", "name_2", "name_3", "name_4"} == set(Foo.registered_names())
41+
42+
with pytest.raises(KeyError):
3543
Foo.get_value_from_registry("Foo2")
3644

3745
assert Foo.get_value_from_registry("Foo1") is Foo1
3846
assert isinstance(Foo.load_from_registry("name_2"), Foo2)
47+
assert (
48+
Foo.get_value_from_registry("name_3")
49+
is Foo3
50+
is Foo.get_value_from_registry("name_4")
51+
)
3952

4053

4154
def test_registry_flow_multiple():

0 commit comments

Comments
 (0)