Skip to content

Commit de65b44

Browse files
committed
[feat] Added Attrs with validation preventing special keys. Added attrs to Datastore. Added json schema of groups to be stored in hdf5. Changed Datastore groups validator to a field validator.
1 parent 130de0a commit de65b44

File tree

2 files changed

+38
-10
lines changed

2 files changed

+38
-10
lines changed

src/oqd_dataschema/base.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from bidict import bidict
2121
from pydantic import (
2222
BaseModel,
23+
BeforeValidator,
2324
ConfigDict,
2425
Discriminator,
2526
Field,
@@ -34,6 +35,23 @@
3435
########################################################################################
3536

3637

38+
invalid_attrs = ["_model_signature", "_model_json"]
39+
40+
41+
def _valid_attr_key(value):
42+
if value in invalid_attrs:
43+
raise KeyError
44+
return value
45+
46+
47+
Attrs = Optional[
48+
dict[
49+
Annotated[str, BeforeValidator(_valid_attr_key)],
50+
Union[int, float, str, complex],
51+
]
52+
]
53+
54+
3755
# %%
3856
mapping = bidict(
3957
{
@@ -64,7 +82,7 @@ class GroupBase(BaseModel, extra="forbid"):
6482
```
6583
"""
6684

67-
attrs: Optional[dict[str, Union[int, float, str, complex]]] = {}
85+
attrs: Attrs = {}
6886

6987
def __init_subclass__(cls, **kwargs):
7088
super().__init_subclass__(**kwargs)
@@ -100,7 +118,7 @@ class Dataset(BaseModel, extra="forbid"):
100118
shape: Optional[tuple[int, ...]] = None
101119
data: Optional[Any] = Field(default=None, exclude=True)
102120

103-
attrs: Optional[dict[str, Union[int, float, str, complex]]] = {}
121+
attrs: Attrs = {}
104122

105123
model_config = ConfigDict(arbitrary_types_allowed=True, validate_assignment=True)
106124

src/oqd_dataschema/datastore.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,19 @@
1414

1515
# %%
1616

17+
import json
1718
import pathlib
1819
from typing import Any, Dict, Literal, Optional
1920

2021
import h5py
2122
import numpy as np
22-
from pydantic import BaseModel, model_validator
23+
from pydantic import (
24+
BaseModel,
25+
field_validator,
26+
)
2327
from pydantic.types import TypeVar
2428

25-
from oqd_dataschema.base import Dataset, GroupBase, GroupRegistry
29+
from oqd_dataschema.base import Attrs, Dataset, GroupBase, GroupRegistry
2630

2731
########################################################################################
2832

@@ -44,15 +48,17 @@ class Datastore(BaseModel, extra="forbid"):
4448

4549
groups: Dict[str, Any]
4650

47-
@model_validator(mode="before")
51+
attrs: Attrs = {}
52+
53+
@field_validator("groups", mode="before")
4854
@classmethod
4955
def validate_groups(cls, data):
50-
if isinstance(data, dict) and "groups" in data:
56+
if isinstance(data, dict):
5157
# Get the current adapter from registry
5258
try:
5359
validated_groups = {}
5460

55-
for key, group_data in data["groups"].items():
61+
for key, group_data in data.items():
5662
if isinstance(group_data, GroupBase):
5763
# Already a Group instance
5864
validated_groups[key] = group_data
@@ -66,7 +72,7 @@ def validate_groups(cls, data):
6672
f"Invalid group data for key '{key}': {type(group_data)}"
6773
)
6874

69-
data["groups"] = validated_groups
75+
data = validated_groups
7076

7177
except ValueError as e:
7278
if "No group types registered" in str(e):
@@ -90,13 +96,17 @@ def model_dump_hdf5(self, filepath: pathlib.Path, mode: Literal["w", "a"] = "a")
9096

9197
with h5py.File(filepath, mode) as f:
9298
# store the model JSON schema
93-
f.attrs["model"] = self.model_dump_json()
99+
f.attrs["_model_signature"] = self.model_dump_json()
100+
for akey, attr in self.attrs.items():
101+
f.attrs[akey] = attr
94102

95103
# store each group
96104
for gkey, group in self.groups.items():
97105
if gkey in f.keys():
98106
del f[gkey]
99107
h5_group = f.create_group(gkey)
108+
109+
h5_group.attrs["_model_schema"] = json.dumps(group.model_json_schema())
100110
for akey, attr in group.attrs.items():
101111
h5_group.attrs[akey] = attr
102112

@@ -118,7 +128,7 @@ def model_validate_hdf5(
118128
filepath (pathlib.Path): The path to the HDF5 file where the model data will be read and validated from.
119129
"""
120130
with h5py.File(filepath, "r") as f:
121-
self = cls.model_validate_json(f.attrs["model"])
131+
self = cls.model_validate_json(f.attrs["_model_signature"])
122132

123133
# loop through all groups in the model schema and load HDF5 store
124134
for gkey, group in self.groups.items():

0 commit comments

Comments
 (0)