Skip to content

Commit 0bc7790

Browse files
committed
[fix] serialization of str to hdf5. str are serialized as bytes during dump and casted back to str at validation of hdf5.
1 parent 6d66226 commit 0bc7790

File tree

2 files changed

+11
-4
lines changed

2 files changed

+11
-4
lines changed

src/oqd_dataschema/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def _valid_attr_key(value):
6464
"float64": np.dtypes.Float64DType,
6565
"complex64": np.dtypes.Complex64DType,
6666
"complex128": np.dtypes.Complex128DType,
67-
"string": np.dtypes.StrDType,
67+
"str": np.dtypes.StrDType,
6868
"bytes": np.dtypes.BytesDType,
6969
"bool": np.dtypes.BoolDType,
7070
}

src/oqd_dataschema/datastore.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
)
2727
from pydantic.types import TypeVar
2828

29-
from oqd_dataschema.base import Attrs, Dataset, GroupBase, GroupRegistry
29+
from oqd_dataschema.base import Attrs, Dataset, GroupBase, GroupRegistry, dtype_map
3030

3131
########################################################################################
3232

@@ -113,7 +113,12 @@ def model_dump_hdf5(self, filepath: pathlib.Path, mode: Literal["w", "a"] = "a")
113113
for dkey, dataset in group.__dict__.items():
114114
if not isinstance(dataset, Dataset):
115115
continue
116-
h5_dataset = h5_group.create_dataset(dkey, data=dataset.data)
116+
h5_dataset = h5_group.create_dataset(
117+
dkey,
118+
data=dataset.data.astype(np.dtypes.BytesDType)
119+
if dataset.dtype == "str"
120+
else dataset.data,
121+
)
117122
for akey, attr in dataset.attrs.items():
118123
h5_dataset.attrs[akey] = attr
119124

@@ -135,7 +140,9 @@ def model_validate_hdf5(
135140
for dkey in group.__class__.model_fields:
136141
if dkey in ("attrs", "class_"):
137142
continue
138-
group.__dict__[dkey].data = np.array(f[gkey][dkey][()])
143+
group.__dict__[dkey].data = np.array(f[gkey][dkey][()]).astype(
144+
dtype_map[group.__dict__[dkey].dtype]
145+
)
139146
return self
140147

141148
def __getitem__(self, key):

0 commit comments

Comments
 (0)