|
13 | 13 | # limitations under the License. |
14 | 14 |
|
15 | 15 | # %% |
16 | | -from typing import Any, Literal, Optional, Type, Union |
| 16 | +import warnings |
| 17 | +from typing import Annotated, Any, Literal, Optional, Union |
17 | 18 |
|
18 | 19 | import numpy as np |
19 | 20 | from bidict import bidict |
20 | 21 | from pydantic import ( |
21 | 22 | BaseModel, |
22 | 23 | ConfigDict, |
| 24 | + Discriminator, |
23 | 25 | Field, |
24 | 26 | TypeAdapter, |
25 | 27 | model_validator, |
26 | 28 | ) |
27 | 29 |
|
| 30 | +######################################################################################## |
| 31 | + |
| 32 | +__all__ = ["GroupBase", "Dataset", "GroupRegistry"] |
| 33 | + |
| 34 | +######################################################################################## |
| 35 | + |
| 36 | + |
28 | 37 | # %% |
29 | 38 | mapping = bidict( |
30 | 39 | { |
|
39 | 48 | ) |
40 | 49 |
|
41 | 50 |
|
42 | | -class Group(BaseModel, extra="forbid"): |
| 51 | +class GroupBase(BaseModel, extra="forbid"): |
43 | 52 | """ |
44 | 53 | Schema representation for a group object within an HDF5 file. |
45 | 54 |
|
@@ -144,63 +153,40 @@ def validate_data_matches_shape_dtype(self): |
144 | 153 | return self |
145 | 154 |
|
146 | 155 |
|
147 | | -class GroupRegistry: |
148 | | - """Registry for managing group types dynamically""" |
| 156 | +class MetaGroupRegistry(type): |
| 157 | + def __new__(cls, clsname, superclasses, attributedict): |
| 158 | + attributedict["groups"] = dict() |
| 159 | + return super().__new__(cls, clsname, superclasses, attributedict) |
149 | 160 |
|
150 | | - _types: dict[str, Type[Group]] = {} |
151 | | - _union_cache = None |
| 161 | + def register(cls, group): |
| 162 | + if not issubclass(group, GroupBase): |
| 163 | + raise TypeError("You may only register subclasses of GroupBase.") |
152 | 164 |
|
153 | | - @classmethod |
154 | | - def register(cls, group_type: Type[Group]): |
155 | | - """Register a new group type""" |
156 | | - import warnings |
157 | | - |
158 | | - type_name = group_type.__name__ |
159 | | - |
160 | | - # Check if type is already registered |
161 | | - if type_name in cls._types: |
162 | | - existing_type = cls._types[type_name] |
163 | | - if existing_type is not group_type: # Different class with same name |
164 | | - warnings.warn( |
165 | | - f"Group type '{type_name}' is already registered. " |
166 | | - f"Overwriting {existing_type} with {group_type}.", |
167 | | - UserWarning, |
168 | | - stacklevel=2, |
169 | | - ) |
170 | | - |
171 | | - cls._types[type_name] = group_type |
172 | | - cls._union_cache = None # Invalidate cache |
| 165 | + if group.__name__ in cls.groups.keys(): |
| 166 | + warnings.warn( |
| 167 | + f"Overwriting previously registered `{group.__name__}` group of the same name.", |
| 168 | + UserWarning, |
| 169 | + stacklevel=2, |
| 170 | + ) |
173 | 171 |
|
174 | | - @classmethod |
175 | | - def get_union(cls): |
176 | | - """Get the current Union of all registered types""" |
177 | | - if cls._union_cache is None: |
178 | | - if not cls._types: |
179 | | - raise ValueError("No group types registered") |
| 172 | + cls.groups[group.__name__] = group |
180 | 173 |
|
181 | | - type_list = list(cls._types.values()) |
182 | | - if len(type_list) == 1: |
183 | | - cls._union_cache = type_list[0] |
184 | | - else: |
185 | | - cls._union_cache = Union[tuple(type_list)] |
| 174 | + def clear(cls): |
| 175 | + """Clear all registered types (useful for testing)""" |
| 176 | + cls.groups.clear() |
186 | 177 |
|
187 | | - return cls._union_cache |
| 178 | + @property |
| 179 | + def union(cls): |
| 180 | + """Get the current Union of all registered types""" |
| 181 | + return Annotated[ |
| 182 | + Union[tuple(cls.groups.values())], Discriminator(discriminator="class_") |
| 183 | + ] |
188 | 184 |
|
189 | | - @classmethod |
190 | | - def get_adapter(cls): |
| 185 | + @property |
| 186 | + def adapter(cls): |
191 | 187 | """Get TypeAdapter for current registered types""" |
192 | | - from typing import Annotated |
| 188 | + return TypeAdapter(cls.union) |
193 | 189 |
|
194 | | - union_type = cls.get_union() |
195 | | - return TypeAdapter(Annotated[union_type, Field(discriminator="class_")]) |
196 | 190 |
|
197 | | - @classmethod |
198 | | - def clear(cls): |
199 | | - """Clear all registered types (useful for testing)""" |
200 | | - cls._types.clear() |
201 | | - cls._union_cache = None |
202 | | - |
203 | | - @classmethod |
204 | | - def list_types(cls): |
205 | | - """List all registered type names""" |
206 | | - return list(cls._types.keys()) |
| 191 | +class GroupRegistry(metaclass=MetaGroupRegistry): |
| 192 | + pass |
0 commit comments