Skip to content

Commit f422b43

Browse files
committed
[feat] Implemented flex shape support and refactored validators
1 parent a5c7ad3 commit f422b43

File tree

3 files changed

+143
-78
lines changed

3 files changed

+143
-78
lines changed

src/oqd_dataschema/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from .base import Dataset, GroupBase, GroupRegistry, condataset
15+
from .base import CastDataset, Dataset, GroupBase, GroupRegistry, condataset
1616
from .datastore import Datastore
1717
from .groups import (
1818
ExpectationValueDataGroup,
@@ -24,6 +24,7 @@
2424
########################################################################################
2525

2626
__all__ = [
27+
"CastDataset",
2728
"Dataset",
2829
"Datastore",
2930
"GroupBase",

src/oqd_dataschema/base.py

Lines changed: 98 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import typing
1717
import warnings
1818
from enum import Enum
19-
from functools import partial, reduce
19+
from types import NoneType
2020
from typing import Annotated, Any, ClassVar, Literal, Optional, Sequence, Tuple, Union
2121

2222
import numpy as np
@@ -28,9 +28,12 @@
2828
Discriminator,
2929
Field,
3030
TypeAdapter,
31+
field_validator,
3132
model_validator,
3233
)
3334

35+
from .utils import _flex_shape_equal, _validator_from_condition
36+
3437
########################################################################################
3538

3639
__all__ = ["GroupBase", "Dataset", "GroupRegistry", "condataset", "CastDataset"]
@@ -108,7 +111,7 @@ class Dataset(BaseModel, extra="forbid"):
108111
"""
109112

110113
dtype: Optional[Literal[DTypes.names()]] = None # type: ignore
111-
shape: Optional[Tuple[int, ...]] = None
114+
shape: Optional[Tuple[Union[int, None], ...]] = None
112115
data: Optional[Any] = Field(default=None, exclude=True)
113116

114117
attrs: Attrs = {}
@@ -117,51 +120,58 @@ class Dataset(BaseModel, extra="forbid"):
117120
use_enum_values=False, arbitrary_types_allowed=True, validate_assignment=True
118121
)
119122

123+
@field_validator("data", mode="before")
120124
@classmethod
121-
def cast(cls, data):
122-
if isinstance(data, np.ndarray):
123-
return cls(data=data)
124-
return data
125+
def validate_and_update(cls, value):
126+
# check if data exist
127+
if value is None:
128+
return value
125129

126-
@model_validator(mode="before")
127-
@classmethod
128-
def validate_and_update(cls, values: dict):
129-
data = values.get("data")
130-
dtype = values.get("dtype")
131-
shape = values.get("shape")
130+
# check if data is a numpy array
131+
if not isinstance(value, np.ndarray):
132+
raise TypeError("`data` must be a numpy.ndarray.")
132133

133-
if data is None and (dtype is not None and shape is not None):
134-
return values
134+
return value
135135

136-
elif data is not None and (dtype is None and shape is None):
137-
if not isinstance(data, np.ndarray):
138-
raise TypeError("`data` must be a numpy.ndarray.")
136+
@model_validator(mode="after")
137+
def validate_data_matches_shape_dtype(self):
138+
"""Ensure that `data` matches `dtype` and `shape`."""
139139

140-
if type(data.dtype) not in DTypes:
141-
raise TypeError(
142-
f"`data` must be a numpy array of dtype in {tuple(DTypes.names())}."
143-
)
140+
# check if data exist
141+
if self.data is None:
142+
return self
143+
144+
# check if dtype matches data
145+
if (
146+
self.dtype is not None
147+
and type(self.data.dtype) is not DTypes.get(self.dtype).value
148+
):
149+
raise ValueError(
150+
f"Expected data dtype `{self.dtype}`, but got `{self.data.dtype.name}`."
151+
)
144152

145-
values["dtype"] = DTypes(type(data.dtype)).name.lower()
146-
values["shape"] = data.shape
153+
# check if shape matches data
154+
if self.shape is not None and not _flex_shape_equal(
155+
self.data.shape, self.shape
156+
):
157+
raise ValueError(f"Expected shape {self.shape}, but got {self.data.shape}.")
147158

148-
return values
159+
# reassign dtype if it is None
160+
if self.dtype != DTypes(type(self.data.dtype)).name.lower():
161+
self.dtype = DTypes(type(self.data.dtype)).name.lower()
162+
163+
# reassign shape to concrete value if it is None or a flexible shape
164+
if self.shape != self.data.shape:
165+
self.shape = self.data.shape
149166

150-
@model_validator(mode="after")
151-
def validate_data_matches_shape_dtype(self):
152-
"""Ensure that `data` matches `dtype` and `shape`."""
153-
if self.data is not None:
154-
expected_dtype = DTypes.get(self.dtype).value
155-
if type(self.data.dtype) is not expected_dtype:
156-
raise ValueError(
157-
f"Expected data dtype `{self.dtype}`, but got `{self.data.dtype.name}`."
158-
)
159-
if self.data.shape != self.shape:
160-
raise ValueError(
161-
f"Expected shape {self.shape}, but got {self.data.shape}."
162-
)
163167
return self
164168

169+
@classmethod
170+
def cast(cls, data):
171+
if isinstance(data, np.ndarray):
172+
return cls(data=data)
173+
return data
174+
165175
def __getitem__(self, idx):
166176
return self.data[idx]
167177

@@ -172,6 +182,12 @@ def _is_dataset_type(cls, type_):
172182
)
173183

174184

185+
CastDataset = Annotated[Dataset, BeforeValidator(Dataset.cast)]
186+
187+
########################################################################################
188+
189+
190+
@_validator_from_condition
175191
def _constrain_dtype(dataset, *, dtype_constraint=None):
176192
if (not isinstance(dtype_constraint, str)) and isinstance(
177193
dtype_constraint, Sequence
@@ -185,10 +201,12 @@ def _constrain_dtype(dataset, *, dtype_constraint=None):
185201
f"Expected dtype to be of type one of {dtype_constraint}, but got {dataset.dtype}."
186202
)
187203

188-
return dataset
189-
190204

205+
@_validator_from_condition
191206
def _constraint_dim(dataset, *, min_dim=None, max_dim=None):
207+
if min_dim is not None and max_dim is not None and min_dim > max_dim:
208+
raise ValueError("Impossible to satisfy dimension constraints on dataset.")
209+
192210
min_dim = 0 if min_dim is None else min_dim
193211

194212
dims = len(dataset.shape)
@@ -198,42 +216,26 @@ def _constraint_dim(dataset, *, min_dim=None, max_dim=None):
198216
f"Expected {min_dim} <= dimension of shape{f' <= {max_dim}'}, but got shape = {dataset.shape}."
199217
)
200218

201-
return dataset
202-
203219

220+
@_validator_from_condition
204221
def _constraint_shape(dataset, *, shape_constraint=None):
205-
if shape_constraint and (
206-
len(shape_constraint) != len(dataset.shape)
207-
or reduce(
208-
lambda x, y: x or y,
209-
map(
210-
lambda x: x[0] is not None and x[0] != x[1],
211-
zip(shape_constraint, dataset.shape),
212-
),
213-
)
214-
):
222+
if shape_constraint and not _flex_shape_equal(shape_constraint, dataset.shape):
215223
raise ValueError(
216224
f"Expected shape to be {shape_constraint}, but got {dataset.shape}."
217225
)
218226

219-
return dataset
220-
221227

222228
def condataset(
223229
*, shape_constraint=None, dtype_constraint=None, min_dim=None, max_dim=None
224230
):
225231
return Annotated[
226-
Dataset,
227-
BeforeValidator(Dataset.cast),
228-
AfterValidator(partial(_constrain_dtype, dtype_constraint=dtype_constraint)),
229-
AfterValidator(partial(_constraint_dim, min_dim=min_dim, max_dim=max_dim)),
230-
AfterValidator(partial(_constraint_shape, shape_constraint=shape_constraint)),
232+
CastDataset,
233+
AfterValidator(_constrain_dtype(dtype_constraint=dtype_constraint)),
234+
AfterValidator(_constraint_dim(min_dim=min_dim, max_dim=max_dim)),
235+
AfterValidator(_constraint_shape(shape_constraint=shape_constraint)),
231236
]
232237

233238

234-
CastDataset = Annotated[Dataset, BeforeValidator(Dataset.cast)]
235-
236-
237239
########################################################################################
238240

239241

@@ -255,27 +257,46 @@ class GroupBase(BaseModel, extra="forbid"):
255257

256258
attrs: Attrs = {}
257259

260+
@classmethod
261+
def _is_allowed_field_type(cls, v):
262+
is_dataset = Dataset._is_dataset_type(v)
263+
264+
is_annotated_dataset = typing.get_origin(
265+
v
266+
) is Annotated and Dataset._is_dataset_type(v.__origin__)
267+
268+
is_optional_dataset = typing.get_origin(v) is Union and (
269+
(v.__args__[0] == NoneType and Dataset._is_dataset_type(v.__args__[1]))
270+
or (v.__args__[1] == NoneType and Dataset._is_dataset_type(v.__args__[0]))
271+
)
272+
273+
is_dict_dataset = (
274+
typing.get_origin(v) is dict
275+
and v.__args__[0] is str
276+
and Dataset._is_dataset_type(v.__args__[1])
277+
)
278+
279+
return (
280+
is_dataset or is_annotated_dataset or is_optional_dataset or is_dict_dataset
281+
)
282+
283+
@classmethod
284+
def _is_classvar(cls, v):
285+
return v is ClassVar or typing.get_origin(v) is ClassVar
286+
258287
def __init_subclass__(cls, **kwargs):
259288
super().__init_subclass__(**kwargs)
260289

261290
for k, v in cls.__annotations__.items():
262-
if k == "class_":
263-
raise AttributeError("`class_` attribute should not be set manually.")
264-
265-
if k == "attrs" and k is not Attrs:
266-
raise TypeError("`attrs` should be of type `Attrs`")
267-
268-
if (
269-
k not in ["class_", "attrs"]
270-
and v is not ClassVar
271-
and not Dataset._is_dataset_type(v)
272-
and not (typing.get_origin(v) is Annotated and v.__origin__ is Dataset)
273-
and not (
274-
typing.get_origin(v) is dict
275-
and v.__args__[0] is str
276-
and Dataset._is_dataset_type(v.__args__[1])
291+
if k in ["class_", "attrs"]:
292+
raise AttributeError(
293+
"`class_` and `attrs` attribute should not be set manually."
277294
)
278-
):
295+
296+
if cls._is_classvar(v):
297+
continue
298+
299+
if not cls._is_allowed_field_type(v):
279300
raise TypeError(
280301
"All fields of `GroupBase` have to be of type `Dataset`."
281302
)

src/oqd_dataschema/utils.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# Copyright 2024-2025 Open Quantum Design
2+
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from functools import reduce
16+
17+
########################################################################################
18+
19+
__all__ = ["_flex_shape_equal", "_validator_from_condition"]
20+
21+
22+
########################################################################################
23+
24+
25+
def _flex_shape_equal(shape1, shape2):
26+
return len(shape1) == len(shape2) and reduce(
27+
lambda x, y: x and y,
28+
map(
29+
lambda x: x[0] is None or x[1] is None or x[0] == x[1],
30+
zip(shape1, shape2),
31+
),
32+
)
33+
34+
35+
def _validator_from_condition(f):
36+
def _wrapped_validator(*args, **kwargs):
37+
def _wrapped_condition(model):
38+
f(model, *args, **kwargs)
39+
return model
40+
41+
return _wrapped_condition
42+
43+
return _wrapped_validator

0 commit comments

Comments
 (0)