@@ -49,69 +49,69 @@ class Datastore(BaseModel, extra="forbid"):
4949
5050 attrs : Attrs = {}
5151
52- @field_validator ("groups" , mode = "before" )
5352 @classmethod
54- def validate_groups (cls , data ):
55- if isinstance (data , dict ):
56- # Get the current adapter from registry
57- try :
58- validated_groups = {}
59-
60- for key , group_data in data .items ():
61- if isinstance (group_data , GroupBase ):
62- # Already a Group instance
63- validated_groups [key ] = group_data
64- elif isinstance (group_data , dict ):
65- # Parse dict using discriminated union
66- validated_groups [key ] = GroupRegistry .adapter .validate_python (
67- group_data
68- )
69- else :
70- raise ValueError (
71- f"Invalid group data for key '{ key } ': { type (group_data )} "
72- )
53+ def _validate_group (cls , key , group ):
54+ if isinstance (group , GroupBase ):
55+ return group
7356
74- data = validated_groups
57+ if isinstance (group , dict ):
58+ return GroupRegistry .adapter .validate_python (group )
7559
76- except ValueError as e :
77- if "No group types registered" in str (e ):
78- raise ValueError (
79- "No group types available. Register group types before creating Datastore."
80- )
81- raise
60+ raise ValueError (f"Key `{ key } ` contains invalid group data." )
8261
83- return data
62+ @field_validator ("groups" , mode = "before" )
63+ @classmethod
64+ def validate_groups (cls , data ):
65+ if GroupRegistry .groups == {}:
66+ raise ValueError (
67+ "No group types available. Register group types before creating Datastore."
68+ )
69+
70+ validated_groups = {k : cls ._validate_group (k , v ) for k , v in data .items ()}
71+ return validated_groups
8472
8573 def _dump_group (self , h5datastore , gkey , group ):
74+ # remove existing group
8675 if gkey in h5datastore .keys ():
8776 del h5datastore [gkey ]
77+
78+ # create group
8879 h5_group = h5datastore .create_group (gkey )
8980
81+ # dump group schema
9082 h5_group .attrs ["_group_schema" ] = json .dumps (
9183 group .model_json_schema (), indent = 2
9284 )
85+
86+ # dump group attributes
9387 for akey , attr in group .attrs .items ():
9488 h5_group .attrs [akey ] = attr
9589
90+ # dump group data
9691 for dkey , dataset in group .__dict__ .items ():
92+ # if group field contain dictionary of Dataset
9793 if isinstance (dataset , dict ):
9894 h5_subgroup = h5_group .create_group (dkey )
9995 for ddkey , ddataset in dataset .items ():
10096 self ._dump_dataset (h5_subgroup , ddkey , ddataset )
101-
102- self ._dump_dataset (h5_group , dkey , dataset )
97+ else :
98+ self ._dump_dataset (h5_group , dkey , dataset )
10399
104100 def _dump_dataset (self , h5group , dkey , dataset ):
105- if isinstance (dataset , Dataset ):
106- if dataset .dtype in "str" :
107- h5_dataset = h5group .create_dataset (
108- dkey , data = dataset .data .astype (np .dtypes .BytesDType )
109- )
110- else :
111- h5_dataset = h5group .create_dataset (dkey , data = dataset .data )
101+ if not isinstance (dataset , Dataset ):
102+ raise ValueError ("Group data field is not a Dataset." )
103+
104+ # dtype str converted to bytes when dumped (h5 compatibility)
105+ if dataset .dtype in "str" :
106+ h5_dataset = h5group .create_dataset (
107+ dkey , data = dataset .data .astype (np .dtypes .BytesDType )
108+ )
109+ else :
110+ h5_dataset = h5group .create_dataset (dkey , data = dataset .data )
112111
113- for akey , attr in dataset .attrs .items ():
114- h5_dataset .attrs [akey ] = attr
112+ # dump dataset attributes
113+ for akey , attr in dataset .attrs .items ():
114+ h5_dataset .attrs [akey ] = attr
115115
116116 def model_dump_hdf5 (self , filepath : pathlib .Path , mode : Literal ["w" , "a" ] = "w" ):
117117 """
@@ -125,12 +125,12 @@ def model_dump_hdf5(self, filepath: pathlib.Path, mode: Literal["w", "a"] = "w")
125125 filepath .parent .mkdir (exist_ok = True , parents = True )
126126
127127 with h5py .File (filepath , mode ) as f :
128- # store the model JSON schema
128+ # dump the datastore signature
129129 f .attrs ["_datastore_signature" ] = self .model_dump_json (indent = 2 )
130130 for akey , attr in self .attrs .items ():
131131 f .attrs [akey ] = attr
132132
133- # store each group
133+ # dump each group
134134 for gkey , group in self .groups .items ():
135135 self ._dump_group (f , gkey , group )
136136
@@ -143,26 +143,36 @@ def model_validate_hdf5(cls, filepath: pathlib.Path):
143143 filepath (pathlib.Path): The path to the HDF5 file where the model data will be read and validated from.
144144 """
145145 with h5py .File (filepath , "r" ) as f :
146+ # Load datastore signature
146147 self = cls .model_validate_json (f .attrs ["_datastore_signature" ])
147148
148- # loop through all groups in the model schema and load HDF5 store
149+ # loop through all groups in the model schema and load the data
149150 for gkey , group in self :
150151 for dkey in group .__class__ .model_fields :
152+ # ignore attrs and class_ fields
151153 if dkey in ("attrs" , "class_" ):
152154 continue
153155
156+ # load Dataset data
154157 if isinstance (group .__dict__ [dkey ], Dataset ):
155158 group .__dict__ [dkey ].data = np .array (f [gkey ][dkey ][()]).astype (
156159 DTypes .get (group .__dict__ [dkey ].dtype ).value
157160 )
161+ continue
158162
163+ # load data for dict of Dataset
159164 if isinstance (group .__dict__ [dkey ], dict ):
160165 for ddkey in group .__dict__ [dkey ]:
161166 group .__dict__ [dkey ][ddkey ].data = np .array (
162167 f [gkey ][dkey ][ddkey ][()]
163168 ).astype (
164169 DTypes .get (group .__dict__ [dkey ][ddkey ].dtype ).value
165170 )
171+ continue
172+
173+ raise TypeError (
174+ "Group data fields must be of type Dataset or dict of Dataset."
175+ )
166176
167177 return self
168178
0 commit comments