3737# %%
3838class Datastore (BaseModel , extra = "forbid" ):
3939 """
40- Saves the model and its associated data to an HDF5 file.
41- This method serializes the model's data and attributes into an HDF5 file
42- at the specified filepath.
40+ Class representing a datastore with restricted HDF5 format.
4341
4442 Attributes:
45- filepath (pathlib.Path): The path to the HDF5 file where the model data will be saved.
43+ groups (Dict[str,Group]): groups of data.
44+ attrs (Attrs): attributes of the datastore.
4645 """
4746
4847 groups : Dict [str , Any ]
@@ -51,6 +50,7 @@ class Datastore(BaseModel, extra="forbid"):
5150
5251 @classmethod
5352 def _validate_group (cls , key , group ):
53+ """Helper function for validating group to be of type Group registered in the GroupRegistry."""
5454 if isinstance (group , GroupBase ):
5555 return group
5656
@@ -62,6 +62,7 @@ def _validate_group(cls, key, group):
6262 @field_validator ("groups" , mode = "before" )
6363 @classmethod
6464 def validate_groups (cls , data ):
65+ """Validates groups to be of type Group registered in the GroupRegistry."""
6566 if GroupRegistry .groups == {}:
6667 raise ValueError (
6768 "No group types available. Register group types before creating Datastore."
@@ -71,6 +72,7 @@ def validate_groups(cls, data):
7172 return validated_groups
7273
7374 def _dump_group (self , h5datastore , gkey , group ):
75+ """Helper function for dumping Group."""
7476 # remove existing group
7577 if gkey in h5datastore .keys ():
7678 del h5datastore [gkey ]
@@ -98,6 +100,7 @@ def _dump_group(self, h5datastore, gkey, group):
98100 self ._dump_dataset (h5_group , dkey , dataset )
99101
100102 def _dump_dataset (self , h5group , dkey , dataset ):
103+ """Helper function for dumping Dataset."""
101104 if not isinstance (dataset , Dataset ):
102105 raise ValueError ("Group data field is not a Dataset." )
103106
@@ -177,7 +180,26 @@ def model_validate_hdf5(cls, filepath: pathlib.Path):
177180 return self
178181
179182 def __getitem__ (self , key ):
183+ """Overloads indexing to retrieve elements in groups."""
180184 return self .groups .__getitem__ (key )
181185
182186 def __iter__ (self ):
187+ """Overloads iter to iterate over elements in groups."""
183188 return self .groups .items ().__iter__ ()
189+
190+ def add (self , ** groups ):
191+ """Adds a new groups to the datastore."""
192+ for k , v in groups .items ():
193+ if k in self .groups .keys ():
194+ raise ValueError (
195+ "Key already exist in the datastore, use `update` instead if intending to overwrite past data."
196+ )
197+ self .groups [k ] == v
198+
199+ def update (self , ** groups ):
200+ """Updates groups in the datastore, overwriting past values."""
201+ for k , v in groups .items ():
202+ self .groups [k ] == v
203+
204+
205+ # %%
0 commit comments