Skip to content

Commit e82792d

Browse files
Check if group in registry, union_cache, and forbid extra
1 parent e633f4e commit e82792d

File tree

3 files changed

+82
-27
lines changed

3 files changed

+82
-27
lines changed

examples/custom_group.ipynb

Lines changed: 47 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
"cells": [
33
{
44
"cell_type": "code",
5-
"execution_count": 8,
5+
"execution_count": 15,
66
"metadata": {},
77
"outputs": [],
88
"source": [
@@ -11,7 +11,7 @@
1111
"import numpy as np\n",
1212
"from rich.pretty import pprint\n",
1313
"\n",
14-
"from oqd_dataschema.base import Dataset, Group\n",
14+
"from oqd_dataschema.base import Dataset, Group, GroupRegistry\n",
1515
"from oqd_dataschema.datastore import Datastore\n",
1616
"from oqd_dataschema.groups import (\n",
1717
" SinaraRawDataGroup,\n",
@@ -20,9 +20,18 @@
2020
},
2121
{
2222
"cell_type": "code",
23-
"execution_count": 2,
23+
"execution_count": null,
2424
"metadata": {},
25-
"outputs": [],
25+
"outputs": [
26+
{
27+
"name": "stderr",
28+
"output_type": "stream",
29+
"text": [
30+
"/Users/benjamin/Desktop/1 - Projects/Open Quantum Design/repos/oqd-dataschema/src/oqd_dataschema/base.py:66: UserWarning: Group type 'YourCustomGroup' is already registered. Overwriting <class '__main__.YourCustomGroup'> with <class '__main__.YourCustomGroup'>.\n",
31+
" GroupRegistry.register(cls)\n"
32+
]
33+
}
34+
],
2635
"source": [
2736
"class YourCustomGroup(Group):\n",
2837
" \"\"\"\n",
@@ -34,7 +43,31 @@
3443
},
3544
{
3645
"cell_type": "code",
37-
"execution_count": 5,
46+
"execution_count": 18,
47+
"metadata": {},
48+
"outputs": [
49+
{
50+
"data": {
51+
"text/plain": [
52+
"['SinaraRawDataGroup',\n",
53+
" 'MeasurementOutcomesDataGroup',\n",
54+
" 'ExpectationValueDataGroup',\n",
55+
" 'OQDTestbenchDataGroup',\n",
56+
" 'YourCustomGroup']"
57+
]
58+
},
59+
"execution_count": 18,
60+
"metadata": {},
61+
"output_type": "execute_result"
62+
}
63+
],
64+
"source": [
65+
"GroupRegistry.list_types()"
66+
]
67+
},
68+
{
69+
"cell_type": "code",
70+
"execution_count": 11,
3871
"metadata": {},
3972
"outputs": [],
4073
"source": [
@@ -49,7 +82,7 @@
4982
},
5083
{
5184
"cell_type": "code",
52-
"execution_count": 9,
85+
"execution_count": 13,
5386
"metadata": {},
5487
"outputs": [
5588
{
@@ -160,7 +193,7 @@
160193
},
161194
{
162195
"cell_type": "code",
163-
"execution_count": 11,
196+
"execution_count": 14,
164197
"metadata": {},
165198
"outputs": [
166199
{
@@ -263,6 +296,13 @@
263296
"parse = Datastore.model_validate_hdf5(filepath)\n",
264297
"pprint(parse)"
265298
]
299+
},
300+
{
301+
"cell_type": "code",
302+
"execution_count": null,
303+
"metadata": {},
304+
"outputs": [],
305+
"source": []
266306
}
267307
],
268308
"metadata": {

src/oqd_dataschema/base.py

Lines changed: 34 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
# %%
16-
from typing import Any, Literal, Optional, Type, Union
16+
from typing import Any, Dict, Literal, Optional, Type, Union
1717

1818
import numpy as np
1919
from bidict import bidict
@@ -39,7 +39,7 @@
3939
)
4040

4141

42-
class Group(BaseModel):
42+
class Group(BaseModel, extra="forbid"):
4343
"""
4444
Schema representation for a group object within an HDF5 file.
4545
@@ -75,7 +75,7 @@ def auto_assign_class(cls, data):
7575
return data
7676

7777

78-
class Dataset(BaseModel):
78+
class Dataset(BaseModel, extra="forbid"):
7979
"""
8080
Schema representation for a dataset object to be saved within an HDF5 file.
8181
@@ -147,29 +147,45 @@ def validate_data_matches_shape_dtype(self):
147147
class GroupRegistry:
148148
"""Registry for managing group types dynamically"""
149149

150-
_types: dict[str, Type[Group]] = {}
151-
_union_cache = None
150+
_types: Dict[str, Type[Group]] = {}
152151

153152
@classmethod
154153
def register(cls, group_type: Type[Group]):
155154
"""Register a new group type"""
156-
cls._types[group_type.__name__] = group_type
157-
cls._union_cache = None # Invalidate cache
155+
import warnings
156+
157+
type_name = group_type.__name__
158+
159+
# Check if type is already registered
160+
if type_name in cls._types:
161+
existing_type = cls._types[type_name]
162+
if existing_type is not group_type: # Different class with same name
163+
warnings.warn(
164+
f"Group type '{type_name}' is already registered. "
165+
f"Overwriting {existing_type} with {group_type}.",
166+
UserWarning,
167+
stacklevel=2,
168+
)
169+
170+
cls._types[type_name] = group_type
171+
172+
@classmethod
173+
@property
174+
def union_cache(cls):
175+
"""Get the current Union of all registered types (computed on demand)"""
176+
if not cls._types:
177+
raise ValueError("No group types registered")
178+
179+
type_list = list(cls._types.values())
180+
if len(type_list) == 1:
181+
return type_list[0]
182+
else:
183+
return Union[tuple(type_list)]
158184

159185
@classmethod
160186
def get_union(cls):
161187
"""Get the current Union of all registered types"""
162-
if cls._union_cache is None:
163-
if not cls._types:
164-
raise ValueError("No group types registered")
165-
166-
type_list = list(cls._types.values())
167-
if len(type_list) == 1:
168-
cls._union_cache = type_list[0]
169-
else:
170-
cls._union_cache = Union[tuple(type_list)]
171-
172-
return cls._union_cache
188+
return cls.union_cache
173189

174190
@classmethod
175191
def get_adapter(cls):
@@ -183,7 +199,6 @@ def get_adapter(cls):
183199
def clear(cls):
184200
"""Clear all registered types (useful for testing)"""
185201
cls._types.clear()
186-
cls._union_cache = None
187202

188203
@classmethod
189204
def list_types(cls):

src/oqd_dataschema/datastore.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626

2727

2828
# %%
29-
class Datastore(BaseModel):
29+
class Datastore(BaseModel, extra="forbid"):
3030
"""
3131
Saves the model and its associated data to an HDF5 file.
3232
This method serializes the model's data and attributes into an HDF5 file

0 commit comments

Comments
 (0)