Skip to content

Commit 453bf96

Browse files
committed
[fix] Saving and loading of optional datasets
1 parent b76239e commit 453bf96

File tree

1 file changed

+13
-4
lines changed

1 file changed

+13
-4
lines changed

src/oqd_dataschema/datastore.py

Lines changed: 13 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -44,7 +44,7 @@ class Datastore(BaseModel, extra="forbid"):
4444
attrs (Attrs): attributes of the datastore.
4545
"""
4646

47-
groups: Dict[str, Any]
47+
groups: Dict[str, Any] = {}
4848

4949
attrs: Attrs = {}
5050

@@ -91,7 +91,7 @@ def _dump_group(self, h5datastore, gkey, group):
9191

9292
# dump group data
9393
for dkey, dataset in group.__dict__.items():
94-
if dkey in ["attr", "class_"]:
94+
if dkey in ["attrs", "class_"]:
9595
continue
9696

9797
# if group field contain dictionary of Dataset
@@ -105,9 +105,15 @@ def _dump_group(self, h5datastore, gkey, group):
105105

106106
def _dump_dataset(self, h5group, dkey, dataset):
107107
"""Helper function for dumping Dataset."""
108-
if not isinstance(dataset, Dataset):
108+
109+
if dataset is not None and not isinstance(dataset, Dataset):
109110
raise ValueError("Group data field is not a Dataset.")
110111

112+
# handle optional dataset
113+
if dataset is None:
114+
h5_dataset = h5group.create_dataset(dkey, data=h5py.Empty("f"))
115+
return
116+
111117
# dtype str converted to bytes when dumped (h5 compatibility)
112118
if dataset.dtype in "str":
113119
h5_dataset = h5group.create_dataset(
@@ -139,7 +145,7 @@ def model_dump_hdf5(self, filepath: pathlib.Path, mode: Literal["w", "a"] = "w")
139145

140146
# dump each group
141147
for gkey, group in self.groups.items():
142-
if gkey in ["attr", "class_"]:
148+
if gkey in ["attrs", "class_"]:
143149
continue
144150

145151
self._dump_group(f, gkey, group)
@@ -163,6 +169,9 @@ def model_validate_hdf5(cls, filepath: pathlib.Path):
163169
if dkey in ("attrs", "class_"):
164170
continue
165171

172+
if group.__dict__[dkey] is None:
173+
continue
174+
166175
# load Dataset data
167176
if isinstance(group.__dict__[dkey], Dataset):
168177
group.__dict__[dkey].data = np.array(f[gkey][dkey][()]).astype(

0 commit comments

Comments
 (0)