Skip to content

Commit f5a9276

Browse files
committed
[feat] Implemented CastDataset that cast a numpy array automatically to a Dataset
1 parent 850e872 commit f5a9276

File tree

3 files changed

+18
-7
lines changed

3 files changed

+18
-7
lines changed

src/oqd_dataschema/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,4 +33,5 @@
3333
"OQDTestbenchDataGroup",
3434
"SinaraRawDataGroup",
3535
"condataset",
36+
"CastDataset",
3637
]

src/oqd_dataschema/base.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333

3434
########################################################################################
3535

36-
__all__ = ["GroupBase", "Dataset", "GroupRegistry", "condataset"]
36+
__all__ = ["GroupBase", "Dataset", "GroupRegistry", "condataset", "CastDataset"]
3737

3838
########################################################################################
3939

@@ -117,6 +117,12 @@ class Dataset(BaseModel, extra="forbid"):
117117
use_enum_values=False, arbitrary_types_allowed=True, validate_assignment=True
118118
)
119119

120+
@classmethod
121+
def cast(cls, data):
122+
if isinstance(data, np.ndarray):
123+
return cls(data=data)
124+
return data
125+
120126
@model_validator(mode="before")
121127
@classmethod
122128
def validate_and_update(cls, values: dict):
@@ -212,12 +218,16 @@ def condataset(
212218
):
213219
return Annotated[
214220
Dataset,
221+
BeforeValidator(Dataset.cast),
215222
AfterValidator(partial(_constrain_dtype, dtype_constraint=dtype_constraint)),
216223
AfterValidator(partial(_constraint_dim, min_dim=min_dim, max_dim=max_dim)),
217224
AfterValidator(partial(_constraint_shape, shape_constraint=shape_constraint)),
218225
]
219226

220227

228+
CastDataset = Annotated[Dataset, BeforeValidator(Dataset.cast)]
229+
230+
221231
########################################################################################
222232

223233

src/oqd_dataschema/groups.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515

16-
from oqd_dataschema.base import Dataset, GroupBase
16+
from oqd_dataschema.base import CastDataset, GroupBase
1717

1818
########################################################################################
1919

@@ -33,7 +33,7 @@ class SinaraRawDataGroup(GroupBase):
3333
This is a placeholder for demonstration and development.
3434
"""
3535

36-
camera_images: Dataset
36+
camera_images: CastDataset
3737

3838

3939
class MeasurementOutcomesDataGroup(GroupBase):
@@ -42,7 +42,7 @@ class MeasurementOutcomesDataGroup(GroupBase):
4242
This is a placeholder for demonstration and development.
4343
"""
4444

45-
outcomes: Dataset
45+
outcomes: CastDataset
4646

4747

4848
class ExpectationValueDataGroup(GroupBase):
@@ -51,11 +51,11 @@ class ExpectationValueDataGroup(GroupBase):
5151
This is a placeholder for demonstration and development.
5252
"""
5353

54-
expectation_value: Dataset
54+
expectation_value: CastDataset
5555

5656

5757
class OQDTestbenchDataGroup(GroupBase):
5858
""" """
5959

60-
time: Dataset
61-
voltages: Dataset
60+
time: CastDataset
61+
voltages: CastDataset

0 commit comments

Comments
 (0)