1919
2020import importlib
2121from collections import defaultdict
22- from typing import Any , Dict , List , Optional , Type
22+ from typing import Any , Dict , List , Optional , Type , Union
2323
2424
2525__all__ = [
@@ -64,6 +64,11 @@ class ImageNetDataset(Dataset):
6464 class Cifar(Dataset):
6565 pass
6666
67+ # register with multiple aliases
68+ @Dataset.register(name=["cifar-10-dataset", "cifar-100-dataset"])
69+ class Cifar(Dataset):
70+ pass
71+
6772 # load as "cifar-dataset"
6873 cifar = Dataset.load_from_registry("cifar-dataset")
6974
@@ -77,12 +82,13 @@ class Cifar(Dataset):
7782 registry_requires_subclass : bool = False
7883
7984 @classmethod
80- def register (cls , name : Optional [ str ] = None ):
85+ def register (cls , name : Union [ List [ str ], str , None ] = None ):
8186 """
8287 Decorator for registering a value (ie class or function) wrapped by this
8388 decorator to the base class (class that .register is called from)
8489
85- :param name: name to register the wrapped value as, defaults to value.__name__
90+ :param name: name or list of names to register the wrapped value as,
91+ defaults to value.__name__
8692 :return: register decorator
8793 """
8894
@@ -93,18 +99,22 @@ def decorator(value: Any):
9399 return decorator
94100
95101 @classmethod
96- def register_value (cls , value : Any , name : Optional [ str ] = None ):
102+ def register_value (cls , value : Any , name : Union [ List [ str ], str , None ] = None ):
97103 """
98104 Registers the given value to the class `.register_value` is called from
99105 :param value: value to register
100- :param name: name to register the wrapped value as, defaults to value.__name__
106+ :param name: name or list of names to register the wrapped value as,
107+ defaults to value.__name__
101108 """
102- register (
103- parent_class = cls ,
104- value = value ,
105- name = name ,
106- require_subclass = cls .registry_requires_subclass ,
107- )
109+ names = name if isinstance (name , list ) else [name ]
110+
111+ for name in names :
112+ register (
113+ parent_class = cls ,
114+ value = value ,
115+ name = name ,
116+ require_subclass = cls .registry_requires_subclass ,
117+ )
108118
109119 @classmethod
110120 def load_from_registry (cls , name : str , ** constructor_kwargs ) -> object :
@@ -148,7 +158,7 @@ def register(
148158):
149159 """
150160 :param parent_class: class to register the name under
151- :param value: value to register
161+ :param value: the value to register
152162 :param name: name to register the wrapped value as, defaults to value.__name__
153163 :param require_subclass: require that value is a subclass of the class this
154164 method is called from
@@ -193,7 +203,7 @@ def get_from_registry(
193203 # look up name in registry
194204 retrieved_value = _REGISTRY [parent_class ].get (name )
195205 if retrieved_value is None :
196- raise ValueError (
206+ raise KeyError (
197207 f"Unable to find { name } registered under type { parent_class } . "
198208 f"Registered values for { parent_class } : "
199209 f"{ registered_names (parent_class )} "
0 commit comments