Skip to content

Commit c17c1a5

Browse files
Merge pull request #5 from OpenQuantumDesign/registry
Add registry for dynamic type definition in Datastore
2 parents 849099f + be59ac4 commit c17c1a5

File tree

7 files changed

+522
-31
lines changed

7 files changed

+522
-31
lines changed

examples/custom_group.ipynb

Lines changed: 329 additions & 0 deletions
Large diffs are not rendered by default.

examples/test_h5pydantic.ipynb renamed to examples/dataschema.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -501,7 +501,7 @@
501501
],
502502
"metadata": {
503503
"kernelspec": {
504-
"display_name": ".venv",
504+
"display_name": "oqd-dataschema",
505505
"language": "python",
506506
"name": "python3"
507507
},

src/oqd_dataschema/base.py

Lines changed: 83 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,15 @@
1313
# limitations under the License.
1414

1515
# %%
16-
from typing import Any, Literal, Optional, Union
16+
from typing import Any, Literal, Optional, Type, Union
1717

1818
import numpy as np
1919
from bidict import bidict
2020
from pydantic import (
2121
BaseModel,
2222
ConfigDict,
2323
Field,
24+
TypeAdapter,
2425
model_validator,
2526
)
2627

@@ -38,7 +39,7 @@
3839
)
3940

4041

41-
class Group(BaseModel):
42+
class Group(BaseModel, extra="forbid"):
4243
"""
4344
Schema representation for a group object within an HDF5 file.
4445
@@ -56,8 +57,25 @@ class Group(BaseModel):
5657

5758
attrs: Optional[dict[str, Union[int, float, str, complex]]] = {}
5859

60+
def __init_subclass__(cls, **kwargs):
    """
    Tag every subclass with its own name and add it to the group registry.

    The subclass receives a ``class_`` annotation of ``Literal[<class name>]``
    together with a class attribute of the same name holding that string, and
    is then passed to ``GroupRegistry.register``.

    Args:
        **kwargs: Class keyword arguments, forwarded unchanged to the parent
            ``__init_subclass__``.
    """
    super().__init_subclass__(**kwargs)

    tag = cls.__name__
    cls.__annotations__["class_"] = Literal[tag]
    cls.class_ = tag

    # Make the new group type discoverable through the registry
    GroupRegistry.register(cls)
67+
68+
@model_validator(mode="before")
@classmethod
def auto_assign_class(cls, data):
    """
    Stamp raw dict input with this class's name under the ``class_`` key.

    Args:
        data: Raw validation input.

    Returns:
        ``data`` itself when it is a model instance or is not a dict;
        otherwise a shallow copy of ``data`` with ``class_`` set to
        ``cls.__name__`` (overriding any value supplied by the caller).
    """
    if isinstance(data, BaseModel):
        return data
    if isinstance(data, dict):
        # Build a new dict instead of writing into the caller's object, so
        # validating a payload never mutates it as a side effect.
        return {**data, "class_": cls.__name__}
    return data
76+
77+
78+
class Dataset(BaseModel, extra="forbid"):
6179
"""
6280
Schema representation for a dataset object to be saved within an HDF5 file.
6381
@@ -110,12 +128,6 @@ def validate_and_update(cls, values: dict):
110128

111129
return values
112130

113-
# else:
114-
# assert data.dtype == dtype and data.shape == shape
115-
116-
# else:
117-
# raise ValueError("Must provide either `dtype` and `shape` or `data`.")
118-
119131
@model_validator(mode="after")
120132
def validate_data_matches_shape_dtype(self):
121133
"""Ensure that `data` matches `dtype` and `shape`."""
@@ -130,3 +142,65 @@ def validate_data_matches_shape_dtype(self):
130142
f"Expected shape {self.shape}, but got {self.data.shape}."
131143
)
132144
return self
145+
146+
147+
class GroupRegistry:
    """
    Registry for managing group types dynamically.

    Group types are stored by class name. The union of all registered types
    and the ``TypeAdapter`` built from it are cached and rebuilt only after
    the set of registered types changes.
    """

    _types: dict[str, Type[Group]] = {}
    _union_cache = None
    # ``(union, adapter)`` from the last ``get_adapter`` call. Keyed on the
    # union object so that anything invalidating ``_union_cache`` also
    # invalidates the adapter.
    _adapter_cache = None

    @classmethod
    def register(cls, group_type: Type[Group]):
        """
        Register a new group type under its class name.

        Registering a different class under an already-used name replaces the
        previous entry and emits a ``UserWarning``; re-registering the same
        class is silent.

        Args:
            group_type: The group class to register.
        """
        import warnings

        type_name = group_type.__name__

        existing_type = cls._types.get(type_name)
        if existing_type is not None and existing_type is not group_type:
            # Different class with the same name
            warnings.warn(
                f"Group type '{type_name}' is already registered. "
                f"Overwriting {existing_type} with {group_type}.",
                UserWarning,
                stacklevel=2,
            )

        cls._types[type_name] = group_type
        # Invalidate caches derived from the registered types
        cls._union_cache = None
        cls._adapter_cache = None

    @classmethod
    def get_union(cls):
        """
        Get the current Union of all registered types.

        Returns:
            The single registered type when exactly one is registered,
            otherwise ``Union`` over all registered types.

        Raises:
            ValueError: If no group types are registered.
        """
        if cls._union_cache is None:
            if not cls._types:
                raise ValueError("No group types registered")

            type_list = list(cls._types.values())
            if len(type_list) == 1:
                cls._union_cache = type_list[0]
            else:
                cls._union_cache = Union[tuple(type_list)]

        return cls._union_cache

    @classmethod
    def get_adapter(cls):
        """
        Get a TypeAdapter for the currently registered types.

        The adapter discriminates on the ``class_`` field. It is built once
        per distinct union and reused on later calls, so repeated validation
        does not rebuild it each time.

        Raises:
            ValueError: If no group types are registered.
        """
        from typing import Annotated

        union_type = cls.get_union()

        cached = cls._adapter_cache
        if cached is not None and cached[0] is union_type:
            return cached[1]

        adapter = TypeAdapter(Annotated[union_type, Field(discriminator="class_")])
        cls._adapter_cache = (union_type, adapter)
        return adapter

    @classmethod
    def clear(cls):
        """Clear all registered types and cached objects (useful for testing)."""
        cls._types.clear()
        cls._union_cache = None
        cls._adapter_cache = None

    @classmethod
    def list_types(cls):
        """List all registered type names."""
        return list(cls._types)

src/oqd_dataschema/datastore.py

Lines changed: 52 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -12,26 +12,21 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
# %%
16+
1517
import pathlib
16-
from typing import Union
18+
from typing import Any, Dict, Literal, Optional
1719

1820
import h5py
1921
import numpy as np
20-
from pydantic import BaseModel
21-
22-
from oqd_dataschema.base import Dataset
23-
from oqd_dataschema.groups import (
24-
ExpectationValueDataGroup,
25-
MeasurementOutcomesDataGroup,
26-
SinaraRawDataGroup,
27-
)
22+
from pydantic import BaseModel, model_validator
23+
from pydantic.types import TypeVar
2824

29-
GroupSubtypes = Union[
30-
SinaraRawDataGroup, MeasurementOutcomesDataGroup, ExpectationValueDataGroup
31-
]
25+
from oqd_dataschema.base import Dataset, Group, GroupRegistry
3226

3327

34-
class Datastore(BaseModel):
28+
# %%
29+
class Datastore(BaseModel, extra="forbid"):
3530
"""
3631
Saves the model and its associated data to an HDF5 file.
3732
This method serializes the model's data and attributes into an HDF5 file
@@ -41,9 +36,41 @@ class Datastore(BaseModel):
4136
filepath (pathlib.Path): The path to the HDF5 file where the model data will be saved.
4237
"""
4338

44-
groups: dict[str, GroupSubtypes]
39+
groups: Dict[str, Any]
4540

46-
def model_dump_hdf5(self, filepath: pathlib.Path):
41+
@model_validator(mode="before")
@classmethod
def validate_groups(cls, data):
    """
    Resolve each entry of ``groups`` to a registered ``Group`` instance.

    ``Group`` instances are kept as they are; dicts are parsed through the
    registry's discriminated-union adapter.

    Args:
        data: Raw validation input for the model.

    Returns:
        ``data`` unchanged when it is not a dict, has no ``groups`` key, or
        ``groups`` is not a mapping (field validation then reports the
        problem); otherwise a shallow copy of ``data`` whose ``groups`` entry
        holds the resolved groups.

    Raises:
        ValueError: If no group types are registered, or an entry of
            ``groups`` is neither a ``Group`` nor a dict.
    """
    from collections.abc import Mapping

    if not (isinstance(data, dict) and "groups" in data):
        return data

    # Only the registry lookup is guarded, so errors raised while validating
    # an individual group are never mistaken for an empty registry.
    try:
        adapter = GroupRegistry.get_adapter()
    except ValueError as e:
        if "No group types registered" in str(e):
            raise ValueError(
                "No group types available. Register group types before creating Datastore."
            ) from e
        raise

    raw_groups = data["groups"]
    if not isinstance(raw_groups, Mapping):
        # Not iterable as key/value pairs; let field validation reject it
        return data

    validated_groups = {}
    for key, group_data in raw_groups.items():
        if isinstance(group_data, Group):
            # Already a Group instance
            validated_groups[key] = group_data
        elif isinstance(group_data, dict):
            # Parse dict using discriminated union
            validated_groups[key] = adapter.validate_python(group_data)
        else:
            raise ValueError(
                f"Invalid group data for key '{key}': {type(group_data)}"
            )

    # Return a copy so the caller's input dict is left untouched
    return {**data, "groups": validated_groups}
72+
73+
def model_dump_hdf5(self, filepath: pathlib.Path, mode: Literal["w", "a"] = "a"):
4774
"""
4875
Saves the model and its associated data to an HDF5 file.
4976
This method serializes the model's data and attributes into an HDF5 file
@@ -54,8 +81,11 @@ def model_dump_hdf5(self, filepath: pathlib.Path):
5481
"""
5582
filepath.parent.mkdir(exist_ok=True, parents=True)
5683

57-
with h5py.File(filepath, "a") as f:
84+
with h5py.File(filepath, mode) as f:
85+
# store the serialized model (JSON dump) as a file attribute
5886
f.attrs["model"] = self.model_dump_json()
87+
88+
# store each group
5989
for gkey, group in self.groups.items():
6090
if gkey in f.keys():
6191
del f[gkey]
@@ -71,7 +101,9 @@ def model_dump_hdf5(self, filepath: pathlib.Path):
71101
h5_dataset.attrs[akey] = attr
72102

73103
@classmethod
74-
def model_validate_hdf5(cls, filepath: pathlib.Path):
104+
def model_validate_hdf5(
105+
cls, filepath: pathlib.Path, types: Optional[TypeVar] = None
106+
):
75107
"""
76108
Loads the model from an HDF5 file at the specified filepath.
77109
@@ -80,9 +112,11 @@ def model_validate_hdf5(cls, filepath: pathlib.Path):
80112
"""
81113
with h5py.File(filepath, "r") as f:
82114
self = cls.model_validate_json(f.attrs["model"])
115+
116+
# loop through all groups in the model schema and load HDF5 store
83117
for gkey, group in self.groups.items():
84118
for dkey, val in group.__dict__.items():
85-
if dkey == "attrs":
119+
if dkey in ("attrs", "class_"):
86120
continue
87121
group.__dict__[dkey].data = np.array(f[gkey][dkey][()])
88122
return self

src/oqd_dataschema/groups.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
1516
from oqd_dataschema.base import Dataset, Group
1617

1718

@@ -43,8 +44,7 @@ class ExpectationValueDataGroup(Group):
4344

4445

4546
class OQDTestbenchDataGroup(Group):
46-
"""
47-
"""
47+
""" """
4848

4949
time: Dataset
50-
voltages: Dataset
50+
voltages: Dataset

tests/test_datastore.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,3 +48,6 @@ def test_serialize_deserialize(dtype):
4848
data_reload = Datastore.model_validate_hdf5(filepath)
4949

5050
assert data_reload.groups["test"].camera_images.data.dtype == mapping[dtype]
51+
52+
53+
# %%

tests/test_typeadapt.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
# Copyright 2024-2025 Open Quantum Design
2+
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
# %%
17+
import pathlib
18+
19+
import numpy as np
20+
21+
from oqd_dataschema.base import Dataset, Group
22+
from oqd_dataschema.datastore import Datastore
23+
from oqd_dataschema.groups import (
24+
SinaraRawDataGroup,
25+
)
26+
27+
28+
# %%
29+
def test_adapt():
    """
    Round-trip a Datastore that mixes a locally defined group type with a
    predefined one, and check that each group is restored as its own type.
    """
    import tempfile

    class TestNewGroup(Group):
        """ """

        array: Dataset

    data = np.ones([10, 10]).astype("int64")
    group1 = TestNewGroup(array=Dataset(data=data))

    data = np.ones([10, 10]).astype("int32")
    group2 = SinaraRawDataGroup(camera_images=Dataset(data=data))

    datastore = Datastore(
        groups={
            "group1": group1,
            "group2": group2,
        }
    )

    # Write into a throwaway directory so the test leaves no file behind in
    # the current working directory.
    with tempfile.TemporaryDirectory() as tmpdir:
        filepath = pathlib.Path(tmpdir) / "test.h5"
        datastore.model_dump_hdf5(filepath, mode="w")
        reloaded = Datastore.model_validate_hdf5(filepath)

    assert isinstance(reloaded.groups["group1"], TestNewGroup)
    assert isinstance(reloaded.groups["group2"], SinaraRawDataGroup)
    assert reloaded.groups["group1"].array.data.dtype == np.dtype("int64")
    assert reloaded.groups["group2"].camera_images.data.dtype == np.dtype("int32")

0 commit comments

Comments
 (0)