3737# %%
3838class Datastore (BaseModel , extra = "forbid" ):
3939 """
40- Saves the model and its associated data to an HDF5 file.
41- This method serializes the model's data and attributes into an HDF5 file
42- at the specified filepath.
40+ Class representing a datastore with restricted HDF5 format.
4341
4442 Attributes:
45- filepath (pathlib.Path): The path to the HDF5 file where the model data will be saved.
43+ groups (Dict[str,Group]): groups of data.
44+ attrs (Attrs): attributes of the datastore.
4645 """
4746
4847 groups : Dict [str , Any ]
@@ -51,6 +50,7 @@ class Datastore(BaseModel, extra="forbid"):
5150
5251 @classmethod
5352 def _validate_group (cls , key , group ):
53+ """Helper function for validating group to be of type Group registered in the GroupRegistry."""
5454 if isinstance (group , GroupBase ):
5555 return group
5656
@@ -62,6 +62,7 @@ def _validate_group(cls, key, group):
6262 @field_validator ("groups" , mode = "before" )
6363 @classmethod
6464 def validate_groups (cls , data ):
65+ """Validates groups to be of type Group registered in the GroupRegistry."""
6566 if GroupRegistry .groups == {}:
6667 raise ValueError (
6768 "No group types available. Register group types before creating Datastore."
@@ -71,6 +72,7 @@ def validate_groups(cls, data):
7172 return validated_groups
7273
7374 def _dump_group (self , h5datastore , gkey , group ):
75+ """Helper function for dumping Group."""
7476 # remove existing group
7577 if gkey in h5datastore .keys ():
7678 del h5datastore [gkey ]
@@ -89,15 +91,20 @@ def _dump_group(self, h5datastore, gkey, group):
8991
9092 # dump group data
9193 for dkey , dataset in group .__dict__ .items ():
94+ if dkey in ["attr" , "class_" ]:
95+ continue
96+
9297 # if group field contain dictionary of Dataset
9398 if isinstance (dataset , dict ):
9499 h5_subgroup = h5_group .create_group (dkey )
95100 for ddkey , ddataset in dataset .items ():
96101 self ._dump_dataset (h5_subgroup , ddkey , ddataset )
97- else :
98- self ._dump_dataset (h5_group , dkey , dataset )
102+ continue
103+
104+ self ._dump_dataset (h5_group , dkey , dataset )
99105
100106 def _dump_dataset (self , h5group , dkey , dataset ):
107+ """Helper function for dumping Dataset."""
101108 if not isinstance (dataset , Dataset ):
102109 raise ValueError ("Group data field is not a Dataset." )
103110
@@ -132,6 +139,9 @@ def model_dump_hdf5(self, filepath: pathlib.Path, mode: Literal["w", "a"] = "w")
132139
133140 # dump each group
134141 for gkey , group in self .groups .items ():
142+ if gkey in ["attr" , "class_" ]:
143+ continue
144+
135145 self ._dump_group (f , gkey , group )
136146
137147 @classmethod
@@ -177,7 +187,26 @@ def model_validate_hdf5(cls, filepath: pathlib.Path):
177187 return self
178188
179189 def __getitem__ (self , key ):
190+ """Overloads indexing to retrieve elements in groups."""
180191 return self .groups .__getitem__ (key )
181192
182193 def __iter__ (self ):
194+ """Overloads iter to iterate over elements in groups."""
183195 return self .groups .items ().__iter__ ()
196+
197+ def add (self , ** groups ):
198+ """Adds a new groups to the datastore."""
199+ for k , v in groups .items ():
200+ if k in self .groups .keys ():
201+ raise ValueError (
202+ "Key already exist in the datastore, use `update` instead if intending to overwrite past data."
203+ )
204+ self .groups [k ] = v
205+
206+ def update (self , ** groups ):
207+ """Updates groups in the datastore, overwriting past values."""
208+ for k , v in groups .items ():
209+ self .groups [k ] = v
210+
211+
212+ # %%
0 commit comments