Skip to content

Commit 80bbdc1

Browse files
authored
Merge pull request #7 from OpenQuantumDesign/refactor_group_registry
Refactor group registry
2 parents c17c1a5 + e15b5e3 commit 80bbdc1

File tree

10 files changed

+2444
-250
lines changed

10 files changed

+2444
-250
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,4 +172,5 @@ cython_debug/
172172

173173
*DS_Store
174174
*.h5
175-
*.code-workspace
175+
*.code-workspace
176+
.pre-commit-config.yaml

examples/custom_group.ipynb

Lines changed: 17 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
"cells": [
33
{
44
"cell_type": "code",
5-
"execution_count": 15,
5+
"execution_count": 1,
66
"metadata": {},
77
"outputs": [],
88
"source": [
@@ -11,7 +11,7 @@
1111
"import numpy as np\n",
1212
"from rich.pretty import pprint\n",
1313
"\n",
14-
"from oqd_dataschema.base import Dataset, Group, GroupRegistry\n",
14+
"from oqd_dataschema.base import Dataset, GroupBase, GroupRegistry\n",
1515
"from oqd_dataschema.datastore import Datastore\n",
1616
"from oqd_dataschema.groups import (\n",
1717
" SinaraRawDataGroup,\n",
@@ -20,20 +20,11 @@
2020
},
2121
{
2222
"cell_type": "code",
23-
"execution_count": null,
23+
"execution_count": 2,
2424
"metadata": {},
25-
"outputs": [
26-
{
27-
"name": "stderr",
28-
"output_type": "stream",
29-
"text": [
30-
"/Users/benjamin/Desktop/1 - Projects/Open Quantum Design/repos/oqd-dataschema/src/oqd_dataschema/base.py:66: UserWarning: Group type 'YourCustomGroup' is already registered. Overwriting <class '__main__.YourCustomGroup'> with <class '__main__.YourCustomGroup'>.\n",
31-
" GroupRegistry.register(cls)\n"
32-
]
33-
}
34-
],
25+
"outputs": [],
3526
"source": [
36-
"class YourCustomGroup(Group):\n",
27+
"class YourCustomGroup(GroupBase):\n",
3728
" \"\"\"\n",
3829
" Here we define a custom Group, which is automatically added at runtime to the GroupRegistry.\n",
3930
" \"\"\"\n",
@@ -43,31 +34,31 @@
4334
},
4435
{
4536
"cell_type": "code",
46-
"execution_count": 18,
37+
"execution_count": 3,
4738
"metadata": {},
4839
"outputs": [
4940
{
5041
"data": {
5142
"text/plain": [
52-
"['SinaraRawDataGroup',\n",
53-
" 'MeasurementOutcomesDataGroup',\n",
54-
" 'ExpectationValueDataGroup',\n",
55-
" 'OQDTestbenchDataGroup',\n",
56-
" 'YourCustomGroup']"
43+
"{'SinaraRawDataGroup': oqd_dataschema.groups.SinaraRawDataGroup,\n",
44+
" 'MeasurementOutcomesDataGroup': oqd_dataschema.groups.MeasurementOutcomesDataGroup,\n",
45+
" 'ExpectationValueDataGroup': oqd_dataschema.groups.ExpectationValueDataGroup,\n",
46+
" 'OQDTestbenchDataGroup': oqd_dataschema.groups.OQDTestbenchDataGroup,\n",
47+
" 'YourCustomGroup': __main__.YourCustomGroup}"
5748
]
5849
},
59-
"execution_count": 18,
50+
"execution_count": 3,
6051
"metadata": {},
6152
"output_type": "execute_result"
6253
}
6354
],
6455
"source": [
65-
"GroupRegistry.list_types()"
56+
"GroupRegistry.groups"
6657
]
6758
},
6859
{
6960
"cell_type": "code",
70-
"execution_count": 11,
61+
"execution_count": 4,
7162
"metadata": {},
7263
"outputs": [],
7364
"source": [
@@ -82,7 +73,7 @@
8273
},
8374
{
8475
"cell_type": "code",
85-
"execution_count": 13,
76+
"execution_count": 5,
8677
"metadata": {},
8778
"outputs": [
8879
{
@@ -193,7 +184,7 @@
193184
},
194185
{
195186
"cell_type": "code",
196-
"execution_count": 14,
187+
"execution_count": 6,
197188
"metadata": {},
198189
"outputs": [
199190
{
@@ -321,7 +312,7 @@
321312
"name": "python",
322313
"nbconvert_exporter": "python",
323314
"pygments_lexer": "ipython3",
324-
"version": "3.12.1"
315+
"version": "3.13.2"
325316
}
326317
},
327318
"nbformat": 4,

examples/dataschema.ipynb

Lines changed: 156 additions & 108 deletions
Large diffs are not rendered by default.

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,10 @@ fixable = ["ALL"]
5353

5454
[dependency-groups]
5555
dev = [
56+
"jupyter>=1.1.1",
5657
"pre-commit>=4.1.0",
58+
"rich>=14.1.0",
59+
"ruff>=0.13.1",
5760
]
5861

5962

src/oqd_dataschema/__init__.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# Copyright 2024-2025 Open Quantum Design
2+
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from .base import Dataset, GroupBase, GroupRegistry
16+
from .datastore import Datastore
17+
from .groups import (
18+
ExpectationValueDataGroup,
19+
MeasurementOutcomesDataGroup,
20+
OQDTestbenchDataGroup,
21+
SinaraRawDataGroup,
22+
)
23+
24+
########################################################################################
25+
26+
__all__ = [
27+
"Dataset",
28+
"Datastore",
29+
"GroupBase",
30+
"GroupRegistry",
31+
"ExpectationValueDataGroup",
32+
"MeasurementOutcomesDataGroup",
33+
"OQDTestbenchDataGroup",
34+
"SinaraRawDataGroup",
35+
]

src/oqd_dataschema/base.py

Lines changed: 39 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -13,18 +13,27 @@
1313
# limitations under the License.
1414

1515
# %%
16-
from typing import Any, Literal, Optional, Type, Union
16+
import warnings
17+
from typing import Annotated, Any, Literal, Optional, Union
1718

1819
import numpy as np
1920
from bidict import bidict
2021
from pydantic import (
2122
BaseModel,
2223
ConfigDict,
24+
Discriminator,
2325
Field,
2426
TypeAdapter,
2527
model_validator,
2628
)
2729

30+
########################################################################################
31+
32+
__all__ = ["GroupBase", "Dataset", "GroupRegistry"]
33+
34+
########################################################################################
35+
36+
2837
# %%
2938
mapping = bidict(
3039
{
@@ -39,7 +48,7 @@
3948
)
4049

4150

42-
class Group(BaseModel, extra="forbid"):
51+
class GroupBase(BaseModel, extra="forbid"):
4352
"""
4453
Schema representation for a group object within an HDF5 file.
4554
@@ -144,63 +153,40 @@ def validate_data_matches_shape_dtype(self):
144153
return self
145154

146155

147-
class GroupRegistry:
148-
"""Registry for managing group types dynamically"""
156+
class MetaGroupRegistry(type):
157+
def __new__(cls, clsname, superclasses, attributedict):
158+
attributedict["groups"] = dict()
159+
return super().__new__(cls, clsname, superclasses, attributedict)
149160

150-
_types: dict[str, Type[Group]] = {}
151-
_union_cache = None
161+
def register(cls, group):
162+
if not issubclass(group, GroupBase):
163+
raise TypeError("You may only register subclasses of GroupBase.")
152164

153-
@classmethod
154-
def register(cls, group_type: Type[Group]):
155-
"""Register a new group type"""
156-
import warnings
157-
158-
type_name = group_type.__name__
159-
160-
# Check if type is already registered
161-
if type_name in cls._types:
162-
existing_type = cls._types[type_name]
163-
if existing_type is not group_type: # Different class with same name
164-
warnings.warn(
165-
f"Group type '{type_name}' is already registered. "
166-
f"Overwriting {existing_type} with {group_type}.",
167-
UserWarning,
168-
stacklevel=2,
169-
)
170-
171-
cls._types[type_name] = group_type
172-
cls._union_cache = None # Invalidate cache
165+
if group.__name__ in cls.groups.keys():
166+
warnings.warn(
167+
f"Overwriting previously registered `{group.__name__}` group of the same name.",
168+
UserWarning,
169+
stacklevel=2,
170+
)
173171

174-
@classmethod
175-
def get_union(cls):
176-
"""Get the current Union of all registered types"""
177-
if cls._union_cache is None:
178-
if not cls._types:
179-
raise ValueError("No group types registered")
172+
cls.groups[group.__name__] = group
180173

181-
type_list = list(cls._types.values())
182-
if len(type_list) == 1:
183-
cls._union_cache = type_list[0]
184-
else:
185-
cls._union_cache = Union[tuple(type_list)]
174+
def clear(cls):
175+
"""Clear all registered types (useful for testing)"""
176+
cls.groups.clear()
186177

187-
return cls._union_cache
178+
@property
179+
def union(cls):
180+
"""Get the current Union of all registered types"""
181+
return Annotated[
182+
Union[tuple(cls.groups.values())], Discriminator(discriminator="class_")
183+
]
188184

189-
@classmethod
190-
def get_adapter(cls):
185+
@property
186+
def adapter(cls):
191187
"""Get TypeAdapter for current registered types"""
192-
from typing import Annotated
188+
return TypeAdapter(cls.union)
193189

194-
union_type = cls.get_union()
195-
return TypeAdapter(Annotated[union_type, Field(discriminator="class_")])
196190

197-
@classmethod
198-
def clear(cls):
199-
"""Clear all registered types (useful for testing)"""
200-
cls._types.clear()
201-
cls._union_cache = None
202-
203-
@classmethod
204-
def list_types(cls):
205-
"""List all registered type names"""
206-
return list(cls._types.keys())
191+
class GroupRegistry(metaclass=MetaGroupRegistry):
192+
pass

src/oqd_dataschema/datastore.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,13 @@
2222
from pydantic import BaseModel, model_validator
2323
from pydantic.types import TypeVar
2424

25-
from oqd_dataschema.base import Dataset, Group, GroupRegistry
25+
from oqd_dataschema.base import Dataset, GroupBase, GroupRegistry
26+
27+
########################################################################################
28+
29+
__all__ = ["Datastore"]
30+
31+
########################################################################################
2632

2733

2834
# %%
@@ -44,16 +50,17 @@ def validate_groups(cls, data):
4450
if isinstance(data, dict) and "groups" in data:
4551
# Get the current adapter from registry
4652
try:
47-
adapter = GroupRegistry.get_adapter()
4853
validated_groups = {}
4954

5055
for key, group_data in data["groups"].items():
51-
if isinstance(group_data, Group):
56+
if isinstance(group_data, GroupBase):
5257
# Already a Group instance
5358
validated_groups[key] = group_data
5459
elif isinstance(group_data, dict):
5560
# Parse dict using discriminated union
56-
validated_groups[key] = adapter.validate_python(group_data)
61+
validated_groups[key] = GroupRegistry.adapter.validate_python(
62+
group_data
63+
)
5764
else:
5865
raise ValueError(
5966
f"Invalid group data for key '{key}': {type(group_data)}"

src/oqd_dataschema/groups.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,21 @@
1313
# limitations under the License.
1414

1515

16-
from oqd_dataschema.base import Dataset, Group
16+
from oqd_dataschema.base import Dataset, GroupBase
1717

18+
########################################################################################
1819

19-
class SinaraRawDataGroup(Group):
20+
__all__ = [
21+
"SinaraRawDataGroup",
22+
"MeasurementOutcomesDataGroup",
23+
"ExpectationValueDataGroup",
24+
"OQDTestbenchDataGroup",
25+
]
26+
27+
########################################################################################
28+
29+
30+
class SinaraRawDataGroup(GroupBase):
2031
"""
2132
Example `Group` for raw data from the Sinara real-time control system.
2233
This is a placeholder for demonstration and development.
@@ -25,7 +36,7 @@ class SinaraRawDataGroup(Group):
2536
camera_images: Dataset
2637

2738

28-
class MeasurementOutcomesDataGroup(Group):
39+
class MeasurementOutcomesDataGroup(GroupBase):
2940
"""
3041
Example `Group` for processed data classifying the readout of the state.
3142
This is a placeholder for demonstration and development.
@@ -34,7 +45,7 @@ class MeasurementOutcomesDataGroup(Group):
3445
outcomes: Dataset
3546

3647

37-
class ExpectationValueDataGroup(Group):
48+
class ExpectationValueDataGroup(GroupBase):
3849
"""
3950
Example `Group` for processed data calculating the expectation values.
4051
This is a placeholder for demonstration and development.
@@ -43,7 +54,7 @@ class ExpectationValueDataGroup(Group):
4354
expectation_value: Dataset
4455

4556

46-
class OQDTestbenchDataGroup(Group):
57+
class OQDTestbenchDataGroup(GroupBase):
4758
""" """
4859

4960
time: Dataset

tests/test_typeadapt.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
import numpy as np
2020

21-
from oqd_dataschema.base import Dataset, Group
21+
from oqd_dataschema.base import Dataset, GroupBase
2222
from oqd_dataschema.datastore import Datastore
2323
from oqd_dataschema.groups import (
2424
SinaraRawDataGroup,
@@ -27,7 +27,7 @@
2727

2828
# %%
2929
def test_adapt():
30-
class TestNewGroup(Group):
30+
class TestNewGroup(GroupBase):
3131
""" """
3232

3333
array: Dataset

0 commit comments

Comments
 (0)