Skip to content

Commit 6291f23

Browse files
committed
[feat] Added method to Datastore to add and update groups.
1 parent 11abee3 commit 6291f23

File tree

3 files changed

+42
-4
lines changed

3 files changed

+42
-4
lines changed

src/oqd_dataschema/base.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,7 @@ def _is_dataset_type(cls, type_):
189189

190190
@_validator_from_condition
191191
def _constrain_dtype(dataset, *, dtype_constraint=None):
192+
"""Constrains the dtype of a dataset"""
192193
if (not isinstance(dtype_constraint, str)) and isinstance(
193194
dtype_constraint, Sequence
194195
):
@@ -204,6 +205,7 @@ def _constrain_dtype(dataset, *, dtype_constraint=None):
204205

205206
@_validator_from_condition
206207
def _constraint_dim(dataset, *, min_dim=None, max_dim=None):
208+
"""Constrains the dimension of a dataset"""
207209
if min_dim is not None and max_dim is not None and min_dim > max_dim:
208210
raise ValueError("Impossible to satisfy dimension constraints on dataset.")
209211

@@ -219,6 +221,7 @@ def _constraint_dim(dataset, *, min_dim=None, max_dim=None):
219221

220222
@_validator_from_condition
221223
def _constraint_shape(dataset, *, shape_constraint=None):
224+
"""Constrains the shape of a dataset"""
222225
if shape_constraint and not _flex_shape_equal(shape_constraint, dataset.shape):
223226
raise ValueError(
224227
f"Expected shape to be {shape_constraint}, but got {dataset.shape}."
@@ -228,6 +231,7 @@ def _constraint_shape(dataset, *, shape_constraint=None):
228231
def condataset(
229232
*, shape_constraint=None, dtype_constraint=None, min_dim=None, max_dim=None
230233
):
234+
"""Implements dtype, dimension and shape constrains on the dataset."""
231235
return Annotated[
232236
CastDataset,
233237
AfterValidator(_constrain_dtype(dtype_constraint=dtype_constraint)),
@@ -312,11 +316,16 @@ def __init_subclass__(cls, **kwargs):
312316

313317

314318
class MetaGroupRegistry(type):
319+
"""
320+
Metaclass for the GroupRegistry
321+
"""
322+
315323
def __new__(cls, clsname, superclasses, attributedict):
316324
attributedict["groups"] = dict()
317325
return super().__new__(cls, clsname, superclasses, attributedict)
318326

319327
def register(cls, group):
328+
"""Registers a group into the GroupRegistry."""
320329
if not issubclass(group, GroupBase):
321330
raise TypeError("You may only register subclasses of GroupBase.")
322331

@@ -347,6 +356,10 @@ def adapter(cls):
347356

348357

349358
class GroupRegistry(metaclass=MetaGroupRegistry):
359+
"""
360+
Represents the GroupRegistry
361+
"""
362+
350363
pass
351364

352365

src/oqd_dataschema/datastore.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,11 @@
3737
# %%
3838
class Datastore(BaseModel, extra="forbid"):
3939
"""
40-
Saves the model and its associated data to an HDF5 file.
41-
This method serializes the model's data and attributes into an HDF5 file
42-
at the specified filepath.
40+
Class representing a datastore with restricted HDF5 format.
4341
4442
Attributes:
45-
filepath (pathlib.Path): The path to the HDF5 file where the model data will be saved.
43+
groups (Dict[str,Group]): groups of data.
44+
attrs (Attrs): attributes of the datastore.
4645
"""
4746

4847
groups: Dict[str, Any]
@@ -51,6 +50,7 @@ class Datastore(BaseModel, extra="forbid"):
5150

5251
@classmethod
5352
def _validate_group(cls, key, group):
53+
"""Helper function for validating group to be of type Group registered in the GroupRegistry."""
5454
if isinstance(group, GroupBase):
5555
return group
5656

@@ -62,6 +62,7 @@ def _validate_group(cls, key, group):
6262
@field_validator("groups", mode="before")
6363
@classmethod
6464
def validate_groups(cls, data):
65+
"""Validates groups to be of type Group registered in the GroupRegistry."""
6566
if GroupRegistry.groups == {}:
6667
raise ValueError(
6768
"No group types available. Register group types before creating Datastore."
@@ -71,6 +72,7 @@ def validate_groups(cls, data):
7172
return validated_groups
7273

7374
def _dump_group(self, h5datastore, gkey, group):
75+
"""Helper function for dumping Group."""
7476
# remove existing group
7577
if gkey in h5datastore.keys():
7678
del h5datastore[gkey]
@@ -98,6 +100,7 @@ def _dump_group(self, h5datastore, gkey, group):
98100
self._dump_dataset(h5_group, dkey, dataset)
99101

100102
def _dump_dataset(self, h5group, dkey, dataset):
103+
"""Helper function for dumping Dataset."""
101104
if not isinstance(dataset, Dataset):
102105
raise ValueError("Group data field is not a Dataset.")
103106

@@ -177,7 +180,26 @@ def model_validate_hdf5(cls, filepath: pathlib.Path):
177180
return self
178181

179182
def __getitem__(self, key):
183+
"""Overloads indexing to retrieve elements in groups."""
180184
return self.groups.__getitem__(key)
181185

182186
def __iter__(self):
187+
"""Overloads iter to iterate over elements in groups."""
183188
return self.groups.items().__iter__()
189+
190+
def add(self, **groups):
191+
"""Adds a new groups to the datastore."""
192+
for k, v in groups.items():
193+
if k in self.groups.keys():
194+
raise ValueError(
195+
"Key already exist in the datastore, use `update` instead if intending to overwrite past data."
196+
)
197+
self.groups[k] == v
198+
199+
def update(self, **groups):
200+
"""Updates groups in the datastore, overwriting past values."""
201+
for k, v in groups.items():
202+
self.groups[k] == v
203+
204+
205+
# %%

src/oqd_dataschema/utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424

2525
def _flex_shape_equal(shape1, shape2):
26+
"""Helper function for comparing concrete and flex shapes."""
2627
return len(shape1) == len(shape2) and reduce(
2728
lambda x, y: x and y,
2829
map(
@@ -33,6 +34,8 @@ def _flex_shape_equal(shape1, shape2):
3334

3435

3536
def _validator_from_condition(f):
37+
"""Helper decorator for turning a condition into a validation."""
38+
3639
def _wrapped_validator(*args, **kwargs):
3740
def _wrapped_condition(model):
3841
f(model, *args, **kwargs)

0 commit comments

Comments
 (0)