Skip to content

Commit 92a5719

Browse files
committed
[test] Added more tests for datasets, datastores and groups
1 parent f422b43 commit 92a5719

File tree

5 files changed

+577
-79
lines changed

5 files changed

+577
-79
lines changed

tests/test_dataset.py

Lines changed: 249 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,249 @@
1+
# Copyright 2024-2025 Open Quantum Design
2+
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# %%
16+
17+
import numpy as np
18+
import pytest
19+
from pydantic import TypeAdapter
20+
21+
from oqd_dataschema.base import CastDataset, Dataset, DTypes, condataset
22+
23+
########################################################################################
24+
25+
26+
class TestDatasetDtype:
    """Checks of how ``Dataset`` treats its ``dtype`` field when data is assigned."""

    @pytest.mark.parametrize(
        ("dtype", "np_dtype"),
        [
            ("bool", np.dtypes.BoolDType),
            ("int16", np.dtypes.Int16DType),
            ("int32", np.dtypes.Int32DType),
            ("int64", np.dtypes.Int64DType),
            ("uint16", np.dtypes.UInt16DType),
            ("uint32", np.dtypes.UInt32DType),
            ("uint64", np.dtypes.UInt64DType),
            ("float16", np.dtypes.Float16DType),
            ("float32", np.dtypes.Float32DType),
            ("float64", np.dtypes.Float64DType),
            ("complex64", np.dtypes.Complex64DType),
            ("complex128", np.dtypes.Complex128DType),
            ("str", np.dtypes.StrDType),
            ("bytes", np.dtypes.BytesDType),
            ("string", np.dtypes.StringDType),
        ],
    )
    def test_dtypes(self, dtype, np_dtype):
        """Assign data cast to ``np_dtype`` to a dataset declared with ``dtype``.

        The test passes when the assignment completes without raising.
        """
        ds = Dataset(dtype=dtype, shape=(100,))

        data = np.random.rand(100).astype(np_dtype)
        ds.data = data

    @pytest.mark.parametrize("dtype", list(DTypes.names()))
    def test_unmatched_dtype_data(self, dtype):
        """Assigning object-dtype data to a typed dataset must raise ``ValueError``."""
        ds = Dataset(dtype=dtype, shape=(100,))

        data = np.random.rand(100).astype("O")

        # pytest.raises fails the test when nothing is raised; a non-strict
        # xfail marker would only report XPASS and let a regression slip by.
        with pytest.raises(ValueError):
            ds.data = data

    @pytest.mark.parametrize("dtype", list(DTypes.names()))
    def test_flexible_dtype(self, dtype):
        """A dataset created with ``dtype=None`` reports the dtype of the assigned data."""
        ds = Dataset(dtype=None, shape=(100,))

        data = np.random.rand(100).astype(DTypes.get(dtype).value)
        ds.data = data

        assert ds.dtype == DTypes(type(ds.data.dtype)).name.lower()

    def test_dtype_mutation(self):
        """After ``dtype`` is changed to "float64", float64 data can be assigned."""
        ds = Dataset(dtype="float32", shape=(100,))

        ds.dtype = "float64"

        # np.random.rand produces float64 values.
        data = np.random.rand(100)
        ds.data = data
77+
78+
79+
class TestDatasetShape:
    """Checks of how ``Dataset`` treats its ``shape`` field when data is assigned."""

    @pytest.mark.parametrize(
        "shape",
        [
            (0,),
            (1,),
            (99,),
            (1, 1),
        ],
    )
    def test_unmatched_shape_data(self, shape):
        """Data whose shape differs from the declared ``(100,)`` must raise ``ValueError``."""
        ds = Dataset(dtype="float64", shape=(100,))

        data = np.random.rand(*shape)

        # pytest.raises fails the test when nothing is raised; a non-strict
        # xfail marker would only report XPASS and let a regression slip by.
        with pytest.raises(ValueError):
            ds.data = data

    @pytest.mark.parametrize(
        ("shape", "data_shape"),
        [
            ((None,), (0,)),
            ((None,), (1,)),
            ((None,), (100,)),
            ((None, 0), (0, 0)),
            ((None, 1), (1, 1)),
            ((None, None), (1, 1)),
            ((None, None), (10, 100)),
            ((None, None, 1), (1, 1, 1)),
        ],
    )
    def test_flexible_shape(self, shape, data_shape):
        """With ``None`` entries in the declared shape, ``shape`` ends up equal to the data's shape."""
        ds = Dataset(dtype="float64", shape=shape)

        data = np.random.rand(*data_shape)
        ds.data = data

        assert ds.shape == ds.data.shape

    def test_shape_mutation(self):
        """After ``shape`` is changed to ``(100,)``, data of that shape can be assigned."""
        ds = Dataset(dtype="float64", shape=(1,))

        ds.shape = (100,)

        data = np.random.rand(100)
        ds.data = data
124+
125+
126+
class TestCastDataset:
    """Checks that raw numpy arrays validate into datasets through ``CastDataset``."""

    @pytest.fixture
    def adapter(self):
        """Provide a pydantic ``TypeAdapter`` built for ``CastDataset``."""
        cast_adapter = TypeAdapter(CastDataset)
        return cast_adapter

    @pytest.mark.parametrize(
        ("data", "dtype", "shape"),
        [
            (np.random.rand(100), "float64", (100,)),
            (np.random.rand(10).astype("str"), "str", (10,)),
            (np.random.rand(1, 10, 100).astype("bytes"), "bytes", (1, 10, 100)),
        ],
    )
    def test_cast(self, adapter, data, shape, dtype):
        """Validate an array and compare the result's shape and dtype with the expected values."""
        validated = adapter.validate_python(data)

        # Shape is checked before dtype, matching the original short-circuit order.
        assert validated.shape == shape
        assert validated.dtype == dtype
143+
144+
145+
class TestConstrainedDataset:
    """Checks of the dtype, dimension and shape constraints produced by ``condataset``.

    Each test builds a pydantic ``TypeAdapter`` for the constrained type and
    validates a numpy array with it. The "violate" tests require validation to
    raise ``ValueError``; the others pass when validation completes.
    """

    @pytest.mark.parametrize(
        ("cds", "data"),
        [
            (condataset(dtype_constraint="float64"), np.random.rand(10)),
            (condataset(dtype_constraint="str"), np.random.rand(10).astype(str)),
            (
                condataset(dtype_constraint=("float16", "float32", "float64")),
                np.random.rand(10),
            ),
            (
                condataset(dtype_constraint=("float16", "float32", "float64")),
                np.random.rand(10).astype("float16"),
            ),
            (
                condataset(dtype_constraint=("float16", "float32", "float64")),
                np.random.rand(10).astype("float32"),
            ),
        ],
    )
    def test_constrained_dataset_dtype(self, cds, data):
        """Data whose dtype is allowed by the dtype constraint validates."""
        adapter = TypeAdapter(cds)

        adapter.validate_python(data)

    @pytest.mark.parametrize(
        ("cds", "data"),
        [
            (condataset(dtype_constraint="float64"), np.random.rand(10).astype(str)),
            (condataset(dtype_constraint="str"), np.random.rand(10)),
            (
                condataset(dtype_constraint=("float16", "float32", "float64")),
                np.random.rand(10).astype(str),
            ),
        ],
    )
    def test_violate_dtype_constraint(self, cds, data):
        """Data whose dtype is outside the dtype constraint must raise ``ValueError``."""
        adapter = TypeAdapter(cds)

        # pytest.raises fails the test when nothing is raised; a non-strict
        # xfail marker would only report XPASS and let a regression slip by.
        with pytest.raises(ValueError):
            adapter.validate_python(data)

    @pytest.mark.parametrize(
        ("cds", "data"),
        [
            (condataset(min_dim=1, max_dim=1), np.random.rand(10)),
            (condataset(min_dim=0, max_dim=1), np.random.rand(10)),
            (condataset(max_dim=2), np.random.rand(10)),
            (condataset(max_dim=3), np.random.rand(10, 10, 10)),
            (condataset(min_dim=2), np.random.rand(10, 10)),
            (condataset(min_dim=2), np.random.rand(10, 10, 10, 10, 10)),
            (condataset(min_dim=2, max_dim=4), np.random.rand(10, 10, 10, 10)),
            (condataset(min_dim=2, max_dim=4), np.random.rand(10, 10, 10)),
            (condataset(min_dim=2, max_dim=4), np.random.rand(10, 10)),
        ],
    )
    def test_constrained_dataset_dimension(self, cds, data):
        """Data whose number of dimensions lies within ``min_dim``/``max_dim`` validates."""
        adapter = TypeAdapter(cds)

        adapter.validate_python(data)

    @pytest.mark.parametrize(
        ("cds", "data"),
        [
            (condataset(min_dim=1, max_dim=1), np.random.rand(10, 10)),
            (condataset(min_dim=2, max_dim=3), np.random.rand(10)),
            (condataset(min_dim=2, max_dim=3), np.random.rand(10, 10, 10, 10)),
        ],
    )
    def test_violate_dimension_constraint(self, cds, data):
        """Data with too few or too many dimensions must raise ``ValueError``."""
        adapter = TypeAdapter(cds)

        # See test_violate_dtype_constraint: an explicit raises-check cannot XPASS.
        with pytest.raises(ValueError):
            adapter.validate_python(data)

    @pytest.mark.parametrize(
        ("cds", "data"),
        [
            (condataset(shape_constraint=(None,)), np.random.rand(10)),
            (condataset(shape_constraint=(10,)), np.random.rand(10)),
            (condataset(shape_constraint=(None, None)), np.random.rand(1, 2)),
            (condataset(shape_constraint=(1, None)), np.random.rand(1, 2)),
            (condataset(shape_constraint=(1, 2)), np.random.rand(1, 2)),
            (condataset(shape_constraint=(1, None, 3)), np.random.rand(1, 10, 3)),
        ],
    )
    def test_constrained_dataset_shape(self, cds, data):
        """Data matching the shape constraint (``None`` entries left open) validates."""
        adapter = TypeAdapter(cds)

        adapter.validate_python(data)

    @pytest.mark.parametrize(
        ("cds", "data"),
        [
            (condataset(shape_constraint=(1,)), np.random.rand(10)),
            (condataset(shape_constraint=(None,)), np.random.rand(10, 10)),
            (condataset(shape_constraint=(None, 1)), np.random.rand(10, 10)),
            (condataset(shape_constraint=(None, 1)), np.random.rand(1, 10)),
        ],
    )
    def test_violate_shape_constraint(self, cds, data):
        """Data that does not match the shape constraint must raise ``ValueError``."""
        adapter = TypeAdapter(cds)

        # See test_violate_dtype_constraint: an explicit raises-check cannot XPASS.
        with pytest.raises(ValueError):
            adapter.validate_python(data)

tests/test_datastore.py

Lines changed: 81 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -13,44 +13,97 @@
1313
# limitations under the License.
1414

1515
# %%
16-
import pathlib
16+
import uuid
17+
from typing import Dict, Optional
1718

1819
import numpy as np
1920
import pytest
2021

21-
from oqd_dataschema.base import Dataset, DTypes
22-
from oqd_dataschema.datastore import Datastore
23-
from oqd_dataschema.groups import (
24-
SinaraRawDataGroup,
25-
)
26-
22+
from oqd_dataschema import Datastore, GroupBase
23+
from oqd_dataschema.base import Dataset
2724

2825
# %%
29-
@pytest.mark.parametrize(
30-
"dtype",
31-
[
32-
"int32",
33-
"int64",
34-
"float32",
35-
"float64",
36-
"complex64",
37-
"complex128",
38-
],
26+
27+
# Dynamically created GroupBase subclass used as the fixture group in these tests.
# Fields: required dataset ``x``, dict of datasets ``y`` (default ``{}``) and
# optional dataset ``z`` (default ``None``).
# NOTE(review): the random hex suffix presumably avoids class-name collisions in
# whatever registry GroupBase keeps — confirm against GroupBase.
_Group = type(
    "_Group_" + uuid.uuid4().hex,
    (GroupBase,),
    dict(
        __annotations__=dict(
            x=Dataset,
            y=Dict[str, Dataset],
            z=Optional[Dataset],
        ),
        y={},
        z=None,
    ),
)
40-
def test_serialize_deserialize(dtype):
41-
data = np.ones([10, 10]).astype(dtype)
42-
dataset = SinaraRawDataGroup(camera_images=Dataset(data=data))
43-
data = Datastore(groups={"test": dataset})
4440

45-
filepath = pathlib.Path("test.h5")
46-
data.model_dump_hdf5(filepath)
4741

48-
data_reload = Datastore.model_validate_hdf5(filepath)
42+
class TestDatastore:
    """HDF5 round-trip checks: dump a ``Datastore`` to a file and load it back.

    Both tests pass when dumping and re-loading complete without raising; the
    re-loaded contents are not compared with the original.
    """

    @pytest.mark.parametrize(
        ("dtype", "np_dtype"),
        [
            ("bool", np.dtypes.BoolDType),
            ("int16", np.dtypes.Int16DType),
            ("int32", np.dtypes.Int32DType),
            ("int64", np.dtypes.Int64DType),
            ("uint16", np.dtypes.UInt16DType),
            ("uint32", np.dtypes.UInt32DType),
            ("uint64", np.dtypes.UInt64DType),
            ("float16", np.dtypes.Float16DType),
            ("float32", np.dtypes.Float32DType),
            ("float64", np.dtypes.Float64DType),
            ("complex64", np.dtypes.Complex64DType),
            ("complex128", np.dtypes.Complex128DType),
            ("str", np.dtypes.StrDType),
            ("bytes", np.dtypes.BytesDType),
            ("string", np.dtypes.StringDType),
        ],
    )
    def test_serialize_deserialize_dtypes(self, dtype, np_dtype, tmp_path):
        """Round-trip a one-element dataset of each numpy dtype through an HDF5 file."""
        payload = np.random.rand(1).astype(np_dtype)
        group = _Group(x=Dataset(data=payload))
        store = Datastore(groups={"g1": group})

        # File lives in pytest's per-test temporary directory.
        target = tmp_path / f"tmp{uuid.uuid4()}.h5"
        store.model_dump_hdf5(target)

        Datastore.model_validate_hdf5(target)

    @pytest.mark.parametrize(
        ("x", "y", "z"),
        [
            (
                Dataset(data=np.random.rand(10)),
                {},
                None,
            ),
            (
                Dataset(data=np.random.rand(10)),
                {"f1": Dataset(data=np.random.rand(10))},
                None,
            ),
            (
                Dataset(data=np.random.rand(10)),
                {"f1": Dataset(data=np.random.rand(10))},
                Dataset(data=np.random.rand(10)),
            ),
            (
                Dataset(data=np.random.rand(10)),
                {
                    "f1": Dataset(data=np.random.rand(10)),
                    "f2": Dataset(data=np.random.rand(10)),
                },
                Dataset(data=np.random.rand(10)),
            ),
        ],
    )
    def test_serialize_deserialize_dataset_types(self, x, y, z, tmp_path):
        """Round-trip groups mixing a plain, a dict-valued and an optional dataset field."""
        group = _Group(x=x, y=y, z=z)
        store = Datastore(groups={"g1": group})

        target = tmp_path / f"tmp{uuid.uuid4()}.h5"
        store.model_dump_hdf5(target)

        Datastore.model_validate_hdf5(target)

0 commit comments

Comments
 (0)