Skip to content

Commit 11abee3

Browse files
committed
[refactor] refactor datastore group validation and add comments
1 parent 92a5719 commit 11abee3

File tree

1 file changed

+52
-42
lines changed

1 file changed

+52
-42
lines changed

src/oqd_dataschema/datastore.py

Lines changed: 52 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -49,69 +49,69 @@ class Datastore(BaseModel, extra="forbid"):
4949

5050
attrs: Attrs = {}
5151

52-
@field_validator("groups", mode="before")
5352
@classmethod
54-
def validate_groups(cls, data):
55-
if isinstance(data, dict):
56-
# Get the current adapter from registry
57-
try:
58-
validated_groups = {}
59-
60-
for key, group_data in data.items():
61-
if isinstance(group_data, GroupBase):
62-
# Already a Group instance
63-
validated_groups[key] = group_data
64-
elif isinstance(group_data, dict):
65-
# Parse dict using discriminated union
66-
validated_groups[key] = GroupRegistry.adapter.validate_python(
67-
group_data
68-
)
69-
else:
70-
raise ValueError(
71-
f"Invalid group data for key '{key}': {type(group_data)}"
72-
)
53+
def _validate_group(cls, key, group):
54+
if isinstance(group, GroupBase):
55+
return group
7356

74-
data = validated_groups
57+
if isinstance(group, dict):
58+
return GroupRegistry.adapter.validate_python(group)
7559

76-
except ValueError as e:
77-
if "No group types registered" in str(e):
78-
raise ValueError(
79-
"No group types available. Register group types before creating Datastore."
80-
)
81-
raise
60+
raise ValueError(f"Key `{key}` contains invalid group data.")
8261

83-
return data
62+
@field_validator("groups", mode="before")
63+
@classmethod
64+
def validate_groups(cls, data):
65+
if GroupRegistry.groups == {}:
66+
raise ValueError(
67+
"No group types available. Register group types before creating Datastore."
68+
)
69+
70+
validated_groups = {k: cls._validate_group(k, v) for k, v in data.items()}
71+
return validated_groups
8472

8573
def _dump_group(self, h5datastore, gkey, group):
74+
# remove existing group
8675
if gkey in h5datastore.keys():
8776
del h5datastore[gkey]
77+
78+
# create group
8879
h5_group = h5datastore.create_group(gkey)
8980

81+
# dump group schema
9082
h5_group.attrs["_group_schema"] = json.dumps(
9183
group.model_json_schema(), indent=2
9284
)
85+
86+
# dump group attributes
9387
for akey, attr in group.attrs.items():
9488
h5_group.attrs[akey] = attr
9589

90+
# dump group data
9691
for dkey, dataset in group.__dict__.items():
92+
# if group field contains a dictionary of Dataset
9793
if isinstance(dataset, dict):
9894
h5_subgroup = h5_group.create_group(dkey)
9995
for ddkey, ddataset in dataset.items():
10096
self._dump_dataset(h5_subgroup, ddkey, ddataset)
101-
102-
self._dump_dataset(h5_group, dkey, dataset)
97+
else:
98+
self._dump_dataset(h5_group, dkey, dataset)
10399

104100
def _dump_dataset(self, h5group, dkey, dataset):
105-
if isinstance(dataset, Dataset):
106-
if dataset.dtype in "str":
107-
h5_dataset = h5group.create_dataset(
108-
dkey, data=dataset.data.astype(np.dtypes.BytesDType)
109-
)
110-
else:
111-
h5_dataset = h5group.create_dataset(dkey, data=dataset.data)
101+
if not isinstance(dataset, Dataset):
102+
raise ValueError("Group data field is not a Dataset.")
103+
104+
# dtype str converted to bytes when dumped (h5 compatibility)
105+
if dataset.dtype in "str":
106+
h5_dataset = h5group.create_dataset(
107+
dkey, data=dataset.data.astype(np.dtypes.BytesDType)
108+
)
109+
else:
110+
h5_dataset = h5group.create_dataset(dkey, data=dataset.data)
112111

113-
for akey, attr in dataset.attrs.items():
114-
h5_dataset.attrs[akey] = attr
112+
# dump dataset attributes
113+
for akey, attr in dataset.attrs.items():
114+
h5_dataset.attrs[akey] = attr
115115

116116
def model_dump_hdf5(self, filepath: pathlib.Path, mode: Literal["w", "a"] = "w"):
117117
"""
@@ -125,12 +125,12 @@ def model_dump_hdf5(self, filepath: pathlib.Path, mode: Literal["w", "a"] = "w")
125125
filepath.parent.mkdir(exist_ok=True, parents=True)
126126

127127
with h5py.File(filepath, mode) as f:
128-
# store the model JSON schema
128+
# dump the datastore signature
129129
f.attrs["_datastore_signature"] = self.model_dump_json(indent=2)
130130
for akey, attr in self.attrs.items():
131131
f.attrs[akey] = attr
132132

133-
# store each group
133+
# dump each group
134134
for gkey, group in self.groups.items():
135135
self._dump_group(f, gkey, group)
136136

@@ -143,26 +143,36 @@ def model_validate_hdf5(cls, filepath: pathlib.Path):
143143
filepath (pathlib.Path): The path to the HDF5 file where the model data will be read and validated from.
144144
"""
145145
with h5py.File(filepath, "r") as f:
146+
# load datastore signature
146147
self = cls.model_validate_json(f.attrs["_datastore_signature"])
147148

148-
# loop through all groups in the model schema and load HDF5 store
149+
# loop through all groups in the model schema and load the data
149150
for gkey, group in self:
150151
for dkey in group.__class__.model_fields:
152+
# ignore attrs and class_ fields
151153
if dkey in ("attrs", "class_"):
152154
continue
153155

156+
# load Dataset data
154157
if isinstance(group.__dict__[dkey], Dataset):
155158
group.__dict__[dkey].data = np.array(f[gkey][dkey][()]).astype(
156159
DTypes.get(group.__dict__[dkey].dtype).value
157160
)
161+
continue
158162

163+
# load data for dict of Dataset
159164
if isinstance(group.__dict__[dkey], dict):
160165
for ddkey in group.__dict__[dkey]:
161166
group.__dict__[dkey][ddkey].data = np.array(
162167
f[gkey][dkey][ddkey][()]
163168
).astype(
164169
DTypes.get(group.__dict__[dkey][ddkey].dtype).value
165170
)
171+
continue
172+
173+
raise TypeError(
174+
"Group data fields must be of type Dataset or dict of Dataset."
175+
)
166176

167177
return self
168178

0 commit comments

Comments
 (0)