Skip to content

Commit 1d82e3e

Browse files
committed
[fix, feat] Added methods to Datastore to add and update groups. Fixed Datastore model_dump_hdf5 to ignore attrs and class_ when iterating through fields.
1 parent 11abee3 commit 1d82e3e

File tree

3 files changed

+51
-6
lines changed

3 files changed

+51
-6
lines changed

src/oqd_dataschema/base.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,7 @@ def _is_dataset_type(cls, type_):
189189

190190
@_validator_from_condition
191191
def _constrain_dtype(dataset, *, dtype_constraint=None):
192+
"""Constrains the dtype of a dataset"""
192193
if (not isinstance(dtype_constraint, str)) and isinstance(
193194
dtype_constraint, Sequence
194195
):
@@ -204,6 +205,7 @@ def _constrain_dtype(dataset, *, dtype_constraint=None):
204205

205206
@_validator_from_condition
206207
def _constraint_dim(dataset, *, min_dim=None, max_dim=None):
208+
"""Constrains the dimension of a dataset"""
207209
if min_dim is not None and max_dim is not None and min_dim > max_dim:
208210
raise ValueError("Impossible to satisfy dimension constraints on dataset.")
209211

@@ -219,6 +221,7 @@ def _constraint_dim(dataset, *, min_dim=None, max_dim=None):
219221

220222
@_validator_from_condition
221223
def _constraint_shape(dataset, *, shape_constraint=None):
224+
"""Constrains the shape of a dataset"""
222225
if shape_constraint and not _flex_shape_equal(shape_constraint, dataset.shape):
223226
raise ValueError(
224227
f"Expected shape to be {shape_constraint}, but got {dataset.shape}."
@@ -228,6 +231,7 @@ def _constraint_shape(dataset, *, shape_constraint=None):
228231
def condataset(
229232
*, shape_constraint=None, dtype_constraint=None, min_dim=None, max_dim=None
230233
):
234+
"""Implements dtype, dimension and shape constraints on the dataset."""
231235
return Annotated[
232236
CastDataset,
233237
AfterValidator(_constrain_dtype(dtype_constraint=dtype_constraint)),
@@ -312,11 +316,16 @@ def __init_subclass__(cls, **kwargs):
312316

313317

314318
class MetaGroupRegistry(type):
319+
"""
320+
Metaclass for the GroupRegistry
321+
"""
322+
315323
def __new__(cls, clsname, superclasses, attributedict):
316324
attributedict["groups"] = dict()
317325
return super().__new__(cls, clsname, superclasses, attributedict)
318326

319327
def register(cls, group):
328+
"""Registers a group into the GroupRegistry."""
320329
if not issubclass(group, GroupBase):
321330
raise TypeError("You may only register subclasses of GroupBase.")
322331

@@ -347,6 +356,10 @@ def adapter(cls):
347356

348357

349358
class GroupRegistry(metaclass=MetaGroupRegistry):
359+
"""
360+
Represents the GroupRegistry
361+
"""
362+
350363
pass
351364

352365

src/oqd_dataschema/datastore.py

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,11 @@
3737
# %%
3838
class Datastore(BaseModel, extra="forbid"):
3939
"""
40-
Saves the model and its associated data to an HDF5 file.
41-
This method serializes the model's data and attributes into an HDF5 file
42-
at the specified filepath.
40+
Class representing a datastore with restricted HDF5 format.
4341
4442
Attributes:
45-
filepath (pathlib.Path): The path to the HDF5 file where the model data will be saved.
43+
groups (Dict[str,Group]): groups of data.
44+
attrs (Attrs): attributes of the datastore.
4645
"""
4746

4847
groups: Dict[str, Any]
@@ -51,6 +50,7 @@ class Datastore(BaseModel, extra="forbid"):
5150

5251
@classmethod
5352
def _validate_group(cls, key, group):
53+
"""Helper function for validating group to be of type Group registered in the GroupRegistry."""
5454
if isinstance(group, GroupBase):
5555
return group
5656

@@ -62,6 +62,7 @@ def _validate_group(cls, key, group):
6262
@field_validator("groups", mode="before")
6363
@classmethod
6464
def validate_groups(cls, data):
65+
"""Validates groups to be of type Group registered in the GroupRegistry."""
6566
if GroupRegistry.groups == {}:
6667
raise ValueError(
6768
"No group types available. Register group types before creating Datastore."
@@ -71,6 +72,7 @@ def validate_groups(cls, data):
7172
return validated_groups
7273

7374
def _dump_group(self, h5datastore, gkey, group):
75+
"""Helper function for dumping Group."""
7476
# remove existing group
7577
if gkey in h5datastore.keys():
7678
del h5datastore[gkey]
@@ -89,15 +91,20 @@ def _dump_group(self, h5datastore, gkey, group):
8991

9092
# dump group data
9193
for dkey, dataset in group.__dict__.items():
94+
if dkey in ["attr", "class_"]:
95+
continue
96+
9297
# if group field contain dictionary of Dataset
9398
if isinstance(dataset, dict):
9499
h5_subgroup = h5_group.create_group(dkey)
95100
for ddkey, ddataset in dataset.items():
96101
self._dump_dataset(h5_subgroup, ddkey, ddataset)
97-
else:
98-
self._dump_dataset(h5_group, dkey, dataset)
102+
continue
103+
104+
self._dump_dataset(h5_group, dkey, dataset)
99105

100106
def _dump_dataset(self, h5group, dkey, dataset):
107+
"""Helper function for dumping Dataset."""
101108
if not isinstance(dataset, Dataset):
102109
raise ValueError("Group data field is not a Dataset.")
103110

@@ -132,6 +139,9 @@ def model_dump_hdf5(self, filepath: pathlib.Path, mode: Literal["w", "a"] = "w")
132139

133140
# dump each group
134141
for gkey, group in self.groups.items():
142+
if gkey in ["attr", "class_"]:
143+
continue
144+
135145
self._dump_group(f, gkey, group)
136146

137147
@classmethod
@@ -177,7 +187,26 @@ def model_validate_hdf5(cls, filepath: pathlib.Path):
177187
return self
178188

179189
def __getitem__(self, key):
190+
"""Overloads indexing to retrieve elements in groups."""
180191
return self.groups.__getitem__(key)
181192

182193
def __iter__(self):
194+
"""Overloads iter to iterate over elements in groups."""
183195
return self.groups.items().__iter__()
196+
197+
def add(self, **groups):
198+
"""Adds new groups to the datastore."""
199+
for k, v in groups.items():
200+
if k in self.groups.keys():
201+
raise ValueError(
202+
"Key already exist in the datastore, use `update` instead if intending to overwrite past data."
203+
)
204+
self.groups[k] = v
205+
206+
def update(self, **groups):
207+
"""Updates groups in the datastore, overwriting past values."""
208+
for k, v in groups.items():
209+
self.groups[k] = v
210+
211+
212+
# %%

src/oqd_dataschema/utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424

2525
def _flex_shape_equal(shape1, shape2):
26+
"""Helper function for comparing concrete and flex shapes."""
2627
return len(shape1) == len(shape2) and reduce(
2728
lambda x, y: x and y,
2829
map(
@@ -33,6 +34,8 @@ def _flex_shape_equal(shape1, shape2):
3334

3435

3536
def _validator_from_condition(f):
37+
"""Helper decorator for turning a condition into a validation."""
38+
3639
def _wrapped_validator(*args, **kwargs):
3740
def _wrapped_condition(model):
3841
f(model, *args, **kwargs)

0 commit comments

Comments
 (0)