From de65b449dcc58e5beb149efc3a889147e87d3bb1 Mon Sep 17 00:00:00 2001 From: yhteoh Date: Tue, 23 Sep 2025 12:45:38 -0400 Subject: [PATCH 01/48] [feat] Added Attrs with validation preventing special keys. Added attrs to Datastore. Added json schema of groups to be stored in hdf5. Changed Datastore groups validator to a field validator. --- src/oqd_dataschema/base.py | 22 ++++++++++++++++++++-- src/oqd_dataschema/datastore.py | 26 ++++++++++++++++++-------- 2 files changed, 38 insertions(+), 10 deletions(-) diff --git a/src/oqd_dataschema/base.py b/src/oqd_dataschema/base.py index 1139849..3cf7ec2 100644 --- a/src/oqd_dataschema/base.py +++ b/src/oqd_dataschema/base.py @@ -20,6 +20,7 @@ from bidict import bidict from pydantic import ( BaseModel, + BeforeValidator, ConfigDict, Discriminator, Field, @@ -34,6 +35,23 @@ ######################################################################################## +invalid_attrs = ["_model_signature", "_model_json"] + + +def _valid_attr_key(value): + if value in invalid_attrs: + raise KeyError + return value + + +Attrs = Optional[ + dict[ + Annotated[str, BeforeValidator(_valid_attr_key)], + Union[int, float, str, complex], + ] +] + + # %% mapping = bidict( { @@ -64,7 +82,7 @@ class GroupBase(BaseModel, extra="forbid"): ``` """ - attrs: Optional[dict[str, Union[int, float, str, complex]]] = {} + attrs: Attrs = {} def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) @@ -100,7 +118,7 @@ class Dataset(BaseModel, extra="forbid"): shape: Optional[tuple[int, ...]] = None data: Optional[Any] = Field(default=None, exclude=True) - attrs: Optional[dict[str, Union[int, float, str, complex]]] = {} + attrs: Attrs = {} model_config = ConfigDict(arbitrary_types_allowed=True, validate_assignment=True) diff --git a/src/oqd_dataschema/datastore.py b/src/oqd_dataschema/datastore.py index cf52c0c..d8a35d3 100644 --- a/src/oqd_dataschema/datastore.py +++ b/src/oqd_dataschema/datastore.py @@ -14,15 +14,19 @@ # %% +import json import pathlib from typing import Any, Dict, Literal, Optional import h5py import numpy as np -from pydantic import BaseModel, model_validator +from pydantic import ( + BaseModel, + field_validator, +) from pydantic.types import TypeVar -from oqd_dataschema.base import Dataset, GroupBase, GroupRegistry +from oqd_dataschema.base import Attrs, Dataset, GroupBase, GroupRegistry ######################################################################################## @@ -44,15 +48,17 @@ class Datastore(BaseModel, extra="forbid"): groups: Dict[str, Any] - @model_validator(mode="before") + attrs: Attrs = {} + + @field_validator("groups", mode="before") @classmethod def validate_groups(cls, data): - if isinstance(data, dict) and "groups" in data: + if isinstance(data, dict): # Get the current adapter from registry try: validated_groups = {} - for key, group_data in data["groups"].items(): + for key, group_data in data.items(): if isinstance(group_data, GroupBase): # Already a Group instance validated_groups[key] = group_data @@ -66,7 +72,7 @@ def validate_groups(cls, data): f"Invalid group data for key '{key}': {type(group_data)}" ) - data["groups"] = validated_groups + data = validated_groups except ValueError as e: if "No group types registered" in str(e): @@ -90,13 +96,17 @@ def model_dump_hdf5(self, filepath: pathlib.Path, mode: Literal["w", "a"] = "a") with h5py.File(filepath, mode) as f: # store the model JSON schema - f.attrs["model"] = self.model_dump_json() + f.attrs["_model_signature"] = self.model_dump_json() + for akey, attr 
in self.attrs.items(): + f.attrs[akey] = attr # store each group for gkey, group in self.groups.items(): if gkey in f.keys(): del f[gkey] h5_group = f.create_group(gkey) + + h5_group.attrs["_model_schema"] = json.dumps(group.model_json_schema()) for akey, attr in group.attrs.items(): h5_group.attrs[akey] = attr @@ -118,7 +128,7 @@ def model_validate_hdf5( filepath (pathlib.Path): The path to the HDF5 file where the model data will be read and validated from. """ with h5py.File(filepath, "r") as f: - self = cls.model_validate_json(f.attrs["model"]) + self = cls.model_validate_json(f.attrs["_model_signature"]) # loop through all groups in the model schema and load HDF5 store for gkey, group in self.groups.items(): From 47f6bce582ede8efbcb48cb91afb40749f035aab Mon Sep 17 00:00:00 2001 From: yhteoh Date: Tue, 23 Sep 2025 13:21:05 -0400 Subject: [PATCH 02/48] [feat] Added indexing into datastore indexes into groups --- src/oqd_dataschema/datastore.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/oqd_dataschema/datastore.py b/src/oqd_dataschema/datastore.py index d8a35d3..e2e77ec 100644 --- a/src/oqd_dataschema/datastore.py +++ b/src/oqd_dataschema/datastore.py @@ -131,9 +131,15 @@ def model_validate_hdf5( self = cls.model_validate_json(f.attrs["_model_signature"]) # loop through all groups in the model schema and load HDF5 store - for gkey, group in self.groups.items(): - for dkey, val in group.__dict__.items(): + for gkey, group in self: + for dkey in group.__class__.model_fields: if dkey in ("attrs", "class_"): continue group.__dict__[dkey].data = np.array(f[gkey][dkey][()]) return self + + def __getitem__(self, key): + return self.groups.__getitem__(key) + + def __iter__(self): + return self.groups.items().__iter__() From 0a5445c0959e6360085927dd593b7b6826a9f22b Mon Sep 17 00:00:00 2001 From: yhteoh Date: Tue, 23 Sep 2025 13:21:32 -0400 Subject: [PATCH 03/48] [feat] Strict field type requirements for subclasses of GroupBase --- src/oqd_dataschema/base.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/src/oqd_dataschema/base.py b/src/oqd_dataschema/base.py index 3cf7ec2..556356a 100644 --- a/src/oqd_dataschema/base.py +++ b/src/oqd_dataschema/base.py @@ -14,7 +14,7 @@ # %% import warnings -from typing import Annotated, Any, Literal, Optional, Union +from typing import Annotated, Any, ClassVar, Literal, Optional, Union import numpy as np from bidict import bidict @@ -86,6 +86,19 @@ class GroupBase(BaseModel, extra="forbid"): def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) + + for k, v in cls.__annotations__.items(): + if k == "class_": + raise AttributeError("`class_` attribute should not be set manually.") + + if k == "attrs" and k is not Attrs: + raise TypeError("`attrs` should be of type `Attrs`") + + if k not in ["class_", "attrs"] and v not in [Dataset, ClassVar]: + raise TypeError( + "All fields of `GroupBase` have to be of type `Dataset`." 
+ ) + cls.__annotations__["class_"] = Literal[cls.__name__] setattr(cls, "class_", cls.__name__) From 6d66226ba52d93f0ec62947f98bd67580c17d36d Mon Sep 17 00:00:00 2001 From: yhteoh Date: Wed, 24 Sep 2025 09:14:47 -0400 Subject: [PATCH 04/48] [feat] Added str, bytes and bool support for Datasets --- src/oqd_dataschema/base.py | 33 +++++++++++++++++++-------------- 1 file changed, 19 insertions(+), 14 deletions(-) diff --git a/src/oqd_dataschema/base.py b/src/oqd_dataschema/base.py index 556356a..c38d65f 100644 --- a/src/oqd_dataschema/base.py +++ b/src/oqd_dataschema/base.py @@ -41,6 +41,7 @@ def _valid_attr_key(value): if value in invalid_attrs: raise KeyError + return value @@ -53,15 +54,19 @@ def _valid_attr_key(value): # %% -mapping = bidict( +dtype_map = bidict( { - "int32": np.dtype("int32"), - "int64": np.dtype("int64"), - "float32": np.dtype("float32"), - "float64": np.dtype("float64"), - "complex64": np.dtype("complex64"), - "complex128": np.dtype("complex128"), - # 'string': np.type + "int16": np.dtypes.Int16DType, + "int32": np.dtypes.Int32DType, + "int64": np.dtypes.Int64DType, + "float16": np.dtypes.Float16DType, + "float32": np.dtypes.Float32DType, + "float64": np.dtypes.Float64DType, + "complex64": np.dtypes.Complex64DType, + "complex128": np.dtypes.Complex128DType, + "string": np.dtypes.StrDType, + "bytes": np.dtypes.BytesDType, + "bool": np.dtypes.BoolDType, } ) @@ -127,7 +132,7 @@ class Dataset(BaseModel, extra="forbid"): ``` """ - dtype: Optional[Literal[tuple(mapping.keys())]] = None + dtype: Optional[Literal[tuple(dtype_map.keys())]] = None shape: Optional[tuple[int, ...]] = None data: Optional[Any] = Field(default=None, exclude=True) @@ -149,12 +154,12 @@ def validate_and_update(cls, values: dict): if not isinstance(data, np.ndarray): raise TypeError("`data` must be a numpy.ndarray.") - if data.dtype not in mapping.values(): + if type(data.dtype) not in dtype_map.values(): raise TypeError( - f"`data` must be a numpy array of dtype in {tuple(mapping.keys())}." + f"`data` must be a numpy array of dtype in {tuple(dtype_map.keys())}." ) - values["dtype"] = mapping.inverse[data.dtype] + values["dtype"] = dtype_map.inverse[type(data.dtype)] values["shape"] = data.shape return values @@ -163,8 +168,8 @@ def validate_and_update(cls, values: dict): def validate_data_matches_shape_dtype(self): """Ensure that `data` matches `dtype` and `shape`.""" if self.data is not None: - expected_dtype = mapping[self.dtype] - if self.data.dtype != expected_dtype: + expected_dtype = dtype_map[self.dtype] + if type(self.data.dtype) is not expected_dtype: raise ValueError( f"Expected data dtype `{self.dtype}`, but got `{self.data.dtype.name}`." ) From 0bc77901eba526fa40401cff66fd38656413d0ff Mon Sep 17 00:00:00 2001 From: yhteoh Date: Wed, 24 Sep 2025 09:36:05 -0400 Subject: [PATCH 05/48] [fix] serialization of str to hdf5. str are serialized as bytes during dump and casted back to str at validation of hdf5. 
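
The round trip in isolation, using the same casts this patch adds (file and
dataset names are illustrative):

    import h5py
    import numpy as np

    words = np.array(["alpha", "beta"])  # StrDType, which h5py cannot store

    # dump: cast unicode to fixed-width bytes before writing
    with h5py.File("roundtrip.h5", "w") as f:
        f.create_dataset("words", data=words.astype(np.dtypes.BytesDType))

    # load: cast back to str, as model_validate_hdf5 now does via dtype_map
    with h5py.File("roundtrip.h5", "r") as f:
        restored = np.array(f["words"][()]).astype(np.dtypes.StrDType)

    assert (restored == words).all()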
--- src/oqd_dataschema/base.py | 2 +- src/oqd_dataschema/datastore.py | 13 ++++++++++--- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/src/oqd_dataschema/base.py b/src/oqd_dataschema/base.py index c38d65f..25eef21 100644 --- a/src/oqd_dataschema/base.py +++ b/src/oqd_dataschema/base.py @@ -64,7 +64,7 @@ def _valid_attr_key(value): "float64": np.dtypes.Float64DType, "complex64": np.dtypes.Complex64DType, "complex128": np.dtypes.Complex128DType, - "string": np.dtypes.StrDType, + "str": np.dtypes.StrDType, "bytes": np.dtypes.BytesDType, "bool": np.dtypes.BoolDType, } diff --git a/src/oqd_dataschema/datastore.py b/src/oqd_dataschema/datastore.py index e2e77ec..55e7b43 100644 --- a/src/oqd_dataschema/datastore.py +++ b/src/oqd_dataschema/datastore.py @@ -26,7 +26,7 @@ ) from pydantic.types import TypeVar -from oqd_dataschema.base import Attrs, Dataset, GroupBase, GroupRegistry +from oqd_dataschema.base import Attrs, Dataset, GroupBase, GroupRegistry, dtype_map ######################################################################################## @@ -113,7 +113,12 @@ def model_dump_hdf5(self, filepath: pathlib.Path, mode: Literal["w", "a"] = "a") for dkey, dataset in group.__dict__.items(): if not isinstance(dataset, Dataset): continue - h5_dataset = h5_group.create_dataset(dkey, data=dataset.data) + h5_dataset = h5_group.create_dataset( + dkey, + data=dataset.data.astype(np.dtypes.BytesDType) + if dataset.dtype == "str" + else dataset.data, + ) for akey, attr in dataset.attrs.items(): h5_dataset.attrs[akey] = attr @@ -135,7 +140,9 @@ def model_validate_hdf5( for dkey in group.__class__.model_fields: if dkey in ("attrs", "class_"): continue - group.__dict__[dkey].data = np.array(f[gkey][dkey][()]) + group.__dict__[dkey].data = np.array(f[gkey][dkey][()]).astype( + dtype_map[group.__dict__[dkey].dtype] + ) return self def __getitem__(self, key): From 850e8721d61a1232174e69d5d3fda7c7549c0299 Mon Sep 17 00:00:00 2001 From: yhteoh Date: Wed, 24 Sep 2025 13:46:07 -0400 Subject: [PATCH 06/48] [feat, refactor] refactor dtype map to use Enum DTypes. Implemented constrained dataset. Datastore.model_dump_hdf5 default mode changed to "w". 
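
Sketch of the constrained dataset in use (`ImageGroup` and `frames` are
illustrative names; the constraint keywords are the ones defined in
`condataset`):

    import numpy as np
    from oqd_dataschema.base import Dataset, GroupBase, condataset

    class ImageGroup(GroupBase):
        # any number of rows, exactly 10 columns, floating-point only
        frames: condataset(
            shape_constraint=(None, 10),
            dtype_constraint=["float32", "float64"],
        )  # type: ignore

    ImageGroup(frames=Dataset(data=np.zeros((3, 10))))  # ok
    ImageGroup(frames=Dataset(data=np.zeros((3, 4))))   # raises: shape mismatch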
--- examples/custom_group.ipynb | 23 ++-- pyproject.toml | 13 +- src/oqd_dataschema/__init__.py | 3 +- src/oqd_dataschema/base.py | 225 ++++++++++++++++++++++---------- src/oqd_dataschema/datastore.py | 27 ++-- tests/test_datastore.py | 7 +- uv.lock | 59 ++++----- 7 files changed, 214 insertions(+), 143 deletions(-) diff --git a/examples/custom_group.ipynb b/examples/custom_group.ipynb index 7109567..d29cce2 100644 --- a/examples/custom_group.ipynb +++ b/examples/custom_group.ipynb @@ -11,7 +11,7 @@ "import numpy as np\n", "from rich.pretty import pprint\n", "\n", - "from oqd_dataschema.base import Dataset, GroupBase, GroupRegistry\n", + "from oqd_dataschema.base import Dataset, GroupBase, GroupRegistry, condataset\n", "from oqd_dataschema.datastore import Datastore\n", "from oqd_dataschema.groups import (\n", " SinaraRawDataGroup,\n", @@ -29,7 +29,7 @@ " Here we define a custom Group, which is automatically added at runtime to the GroupRegistry.\n", " \"\"\"\n", "\n", - " array: Dataset" + " array: condataset(shape_constraint=(None, 10)) # type: ignore" ] }, { @@ -119,7 +119,8 @@ "│ │ │ ),\n", "│ │ │ class_='YourCustomGroup'\n", "│ │ )\n", - "}\n", + "},\n", + "attrs={}\n", ")\n", "\n" ], @@ -164,7 +165,8 @@ "\u001b[2;32m│ │ │ \u001b[0m\u001b[1m)\u001b[0m,\n", "\u001b[2;32m│ │ │ \u001b[0m\u001b[33mclass_\u001b[0m=\u001b[32m'YourCustomGroup'\u001b[0m\n", "\u001b[2;32m│ │ \u001b[0m\u001b[1m)\u001b[0m\n", - "\u001b[2;32m│ \u001b[0m\u001b[1m}\u001b[0m\n", + "\u001b[2;32m│ \u001b[0m\u001b[1m}\u001b[0m,\n", + "\u001b[2;32m│ \u001b[0m\u001b[33mattrs\u001b[0m=\u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n", "\u001b[1m)\u001b[0m\n" ] }, @@ -230,7 +232,8 @@ "│ │ │ ),\n", "│ │ │ class_='YourCustomGroup'\n", "│ │ )\n", - "}\n", + "},\n", + "attrs={}\n", ")\n", "\n" ], @@ -275,7 +278,8 @@ "\u001b[2;32m│ │ │ \u001b[0m\u001b[1m)\u001b[0m,\n", "\u001b[2;32m│ │ │ \u001b[0m\u001b[33mclass_\u001b[0m=\u001b[32m'YourCustomGroup'\u001b[0m\n", "\u001b[2;32m│ │ \u001b[0m\u001b[1m)\u001b[0m\n", - "\u001b[2;32m│ \u001b[0m\u001b[1m}\u001b[0m\n", + "\u001b[2;32m│ \u001b[0m\u001b[1m}\u001b[0m,\n", + "\u001b[2;32m│ \u001b[0m\u001b[33mattrs\u001b[0m=\u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n", "\u001b[1m)\u001b[0m\n" ] }, @@ -287,13 +291,6 @@ "parse = Datastore.model_validate_hdf5(filepath)\n", "pprint(parse)" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { diff --git a/pyproject.toml b/pyproject.toml index 47a7ccc..7ccc3bd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,11 +26,7 @@ classifiers = [ "Programming Language :: Python :: 3.12", ] -dependencies = [ - "bidict>=0.23.1", - "h5py>=3.13.0", - "pydantic>=2.10.6", -] +dependencies = ["h5py>=3.14.0", "pydantic>=2.10.6"] [project.optional-dependencies] docs = [ @@ -52,12 +48,7 @@ select = ["E4", "E7", "E9", "F", "I"] fixable = ["ALL"] [dependency-groups] -dev = [ - "jupyter>=1.1.1", - "pre-commit>=4.1.0", - "rich>=14.1.0", - "ruff>=0.13.1", -] +dev = ["jupyter>=1.1.1", "pre-commit>=4.1.0", "rich>=14.1.0", "ruff>=0.13.1"] [project.urls] diff --git a/src/oqd_dataschema/__init__.py b/src/oqd_dataschema/__init__.py index 38c732a..3238ec0 100644 --- a/src/oqd_dataschema/__init__.py +++ b/src/oqd_dataschema/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from .base import Dataset, GroupBase, GroupRegistry +from .base import Dataset, GroupBase, GroupRegistry, condataset from .datastore import Datastore from .groups import ( ExpectationValueDataGroup, @@ -32,4 +32,5 @@ "MeasurementOutcomesDataGroup", "OQDTestbenchDataGroup", "SinaraRawDataGroup", + "condataset", ] diff --git a/src/oqd_dataschema/base.py b/src/oqd_dataschema/base.py index 25eef21..59c3ae8 100644 --- a/src/oqd_dataschema/base.py +++ b/src/oqd_dataschema/base.py @@ -13,12 +13,15 @@ # limitations under the License. # %% +import typing import warnings -from typing import Annotated, Any, ClassVar, Literal, Optional, Union +from enum import Enum +from functools import partial, reduce +from typing import Annotated, Any, ClassVar, Literal, Optional, Sequence, Tuple, Union import numpy as np -from bidict import bidict from pydantic import ( + AfterValidator, BaseModel, BeforeValidator, ConfigDict, @@ -30,11 +33,39 @@ ######################################################################################## -__all__ = ["GroupBase", "Dataset", "GroupRegistry"] +__all__ = ["GroupBase", "Dataset", "GroupRegistry", "condataset"] ######################################################################################## +class DTypes(Enum): + BOOL = np.dtypes.BoolDType + INT16 = np.dtypes.Int16DType + INT32 = np.dtypes.Int32DType + INT64 = np.dtypes.Int64DType + UINT16 = np.dtypes.UInt16DType + UINT32 = np.dtypes.UInt32DType + UINT64 = np.dtypes.UInt64DType + FLOAT16 = np.dtypes.Float16DType + FLOAT32 = np.dtypes.Float32DType + FLOAT64 = np.dtypes.Float64DType + COMPLEX64 = np.dtypes.Complex64DType + COMPLEX128 = np.dtypes.Complex128DType + STR = np.dtypes.StrDType + BYTES = np.dtypes.BytesDType + STRING = np.dtypes.StringDType + + @classmethod + def get(cls, name): + return cls[name.upper()] + + @classmethod + def names(cls): + return tuple((dtype.name.lower() for dtype in cls)) + + +######################################################################################## + invalid_attrs = ["_model_signature", "_model_json"] @@ -52,63 +83,7 @@ def _valid_attr_key(value): ] ] - -# %% -dtype_map = bidict( - { - "int16": np.dtypes.Int16DType, - "int32": np.dtypes.Int32DType, - "int64": np.dtypes.Int64DType, - "float16": np.dtypes.Float16DType, - "float32": np.dtypes.Float32DType, - "float64": np.dtypes.Float64DType, - "complex64": np.dtypes.Complex64DType, - "complex128": np.dtypes.Complex128DType, - "str": np.dtypes.StrDType, - "bytes": np.dtypes.BytesDType, - "bool": np.dtypes.BoolDType, - } -) - - -class GroupBase(BaseModel, extra="forbid"): - """ - Schema representation for a group object within an HDF5 file. - - Each grouping of data should be defined as a subclass of `Group`, and specify the datasets that it will contain. - This base object only has attributes, `attrs`, which are associated to the HDF5 group. - - Attributes: - attrs: A dictionary of attributes to append to the dataset. - - Example: - ``` - group = Group(attrs={'version': 2, 'date': '2025-01-01'}) - ``` - """ - - attrs: Attrs = {} - - def __init_subclass__(cls, **kwargs): - super().__init_subclass__(**kwargs) - - for k, v in cls.__annotations__.items(): - if k == "class_": - raise AttributeError("`class_` attribute should not be set manually.") - - if k == "attrs" and k is not Attrs: - raise TypeError("`attrs` should be of type `Attrs`") - - if k not in ["class_", "attrs"] and v not in [Dataset, ClassVar]: - raise TypeError( - "All fields of `GroupBase` have to be of type `Dataset`." 
- ) - - cls.__annotations__["class_"] = Literal[cls.__name__] - setattr(cls, "class_", cls.__name__) - - # Auto-register new group types - GroupRegistry.register(cls) +######################################################################################## class Dataset(BaseModel, extra="forbid"): @@ -132,13 +107,15 @@ class Dataset(BaseModel, extra="forbid"): ``` """ - dtype: Optional[Literal[tuple(dtype_map.keys())]] = None - shape: Optional[tuple[int, ...]] = None + dtype: Optional[Literal[DTypes.names()]] = None # type: ignore + shape: Optional[Tuple[int, ...]] = None data: Optional[Any] = Field(default=None, exclude=True) attrs: Attrs = {} - model_config = ConfigDict(arbitrary_types_allowed=True, validate_assignment=True) + model_config = ConfigDict( + use_enum_values=False, arbitrary_types_allowed=True, validate_assignment=True + ) @model_validator(mode="before") @classmethod @@ -154,12 +131,12 @@ def validate_and_update(cls, values: dict): if not isinstance(data, np.ndarray): raise TypeError("`data` must be a numpy.ndarray.") - if type(data.dtype) not in dtype_map.values(): + if type(data.dtype) not in DTypes: raise TypeError( - f"`data` must be a numpy array of dtype in {tuple(dtype_map.keys())}." + f"`data` must be a numpy array of dtype in {tuple(DTypes.names())}." ) - values["dtype"] = dtype_map.inverse[type(data.dtype)] + values["dtype"] = DTypes(type(data.dtype)).name.lower() values["shape"] = data.shape return values @@ -168,7 +145,7 @@ def validate_and_update(cls, values: dict): def validate_data_matches_shape_dtype(self): """Ensure that `data` matches `dtype` and `shape`.""" if self.data is not None: - expected_dtype = dtype_map[self.dtype] + expected_dtype = DTypes.get(self.dtype).value if type(self.data.dtype) is not expected_dtype: raise ValueError( f"Expected data dtype `{self.dtype}`, but got `{self.data.dtype.name}`." @@ -179,6 +156,117 @@ def validate_data_matches_shape_dtype(self): ) return self + def __getitem__(self, idx): + return self.data[idx] + + +def _constrain_dtype(dataset, *, dtype_constraint=None): + if (not isinstance(dtype_constraint, str)) and isinstance( + dtype_constraint, Sequence + ): + dtype_constraint = set(dtype_constraint) + elif isinstance(dtype_constraint, str): + dtype_constraint = {dtype_constraint} + + if dtype_constraint and dataset.dtype not in dtype_constraint: + raise ValueError( + f"Expected dtype to be of type one of {dtype_constraint}, but got {dataset.dtype}." + ) + + return dataset + + +def _constraint_dim(dataset, *, min_dim=None, max_dim=None): + min_dim = 0 if min_dim is None else min_dim + + dims = len(dataset.shape) + + if dims < min_dim or (max_dim is not None and dims > max_dim): + raise ValueError( + f"Expected {min_dim} <= dimension of shape{f' <= {max_dim}'}, but got shape = {dataset.shape}." + ) + + return dataset + + +def _constraint_shape(dataset, *, shape_constraint=None): + if shape_constraint and ( + len(shape_constraint) != len(dataset.shape) + or reduce( + lambda x, y: x or y, + map( + lambda x: x[0] is not None and x[0] != x[1], + zip(shape_constraint, dataset.shape), + ), + ) + ): + raise ValueError( + f"Expected shape to be {shape_constraint}, but got {dataset.shape}." 
+ ) + + return dataset + + +def condataset( + *, shape_constraint=None, dtype_constraint=None, min_dim=None, max_dim=None +): + return Annotated[ + Dataset, + AfterValidator(partial(_constrain_dtype, dtype_constraint=dtype_constraint)), + AfterValidator(partial(_constraint_dim, min_dim=min_dim, max_dim=max_dim)), + AfterValidator(partial(_constraint_shape, shape_constraint=shape_constraint)), + ] + + +######################################################################################## + + +class GroupBase(BaseModel, extra="forbid"): + """ + Schema representation for a group object within an HDF5 file. + + Each grouping of data should be defined as a subclass of `Group`, and specify the datasets that it will contain. + This base object only has attributes, `attrs`, which are associated to the HDF5 group. + + Attributes: + attrs: A dictionary of attributes to append to the dataset. + + Example: + ``` + group = Group(attrs={'version': 2, 'date': '2025-01-01'}) + ``` + """ + + attrs: Attrs = {} + + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + + for k, v in cls.__annotations__.items(): + if k == "class_": + raise AttributeError("`class_` attribute should not be set manually.") + + if k == "attrs" and k is not Attrs: + raise TypeError("`attrs` should be of type `Attrs`") + + if ( + k not in ["class_", "attrs"] + and v not in [Dataset, ClassVar] + and not (typing.get_origin(v) == Annotated and v.__origin__ is Dataset) + ): + raise TypeError( + "All fields of `GroupBase` have to be of type `Dataset`." + ) + + cls.__annotations__["class_"] = Literal[cls.__name__] + setattr(cls, "class_", cls.__name__) + + # Auto-register new group types + GroupRegistry.register(cls) + + +######################################################################################## + class MetaGroupRegistry(type): def __new__(cls, clsname, superclasses, attributedict): @@ -217,3 +305,6 @@ def adapter(cls): class GroupRegistry(metaclass=MetaGroupRegistry): pass + + +# %% diff --git a/src/oqd_dataschema/datastore.py b/src/oqd_dataschema/datastore.py index 55e7b43..6638223 100644 --- a/src/oqd_dataschema/datastore.py +++ b/src/oqd_dataschema/datastore.py @@ -16,7 +16,7 @@ import json import pathlib -from typing import Any, Dict, Literal, Optional +from typing import Any, Dict, Literal import h5py import numpy as np @@ -24,9 +24,8 @@ BaseModel, field_validator, ) -from pydantic.types import TypeVar -from oqd_dataschema.base import Attrs, Dataset, GroupBase, GroupRegistry, dtype_map +from oqd_dataschema.base import Attrs, Dataset, DTypes, GroupBase, GroupRegistry ######################################################################################## @@ -83,7 +82,7 @@ def validate_groups(cls, data): return data - def model_dump_hdf5(self, filepath: pathlib.Path, mode: Literal["w", "a"] = "a"): + def model_dump_hdf5(self, filepath: pathlib.Path, mode: Literal["w", "a"] = "w"): """ Saves the model and its associated data to an HDF5 file. 
This method serializes the model's data and attributes into an HDF5 file @@ -113,19 +112,19 @@ def model_dump_hdf5(self, filepath: pathlib.Path, mode: Literal["w", "a"] = "a") for dkey, dataset in group.__dict__.items(): if not isinstance(dataset, Dataset): continue - h5_dataset = h5_group.create_dataset( - dkey, - data=dataset.data.astype(np.dtypes.BytesDType) - if dataset.dtype == "str" - else dataset.data, - ) + + if dataset.dtype in "str": + h5_dataset = h5_group.create_dataset( + dkey, data=dataset.data.astype(np.dtypes.BytesDType) + ) + else: + h5_dataset = h5_group.create_dataset(dkey, data=dataset.data) + for akey, attr in dataset.attrs.items(): h5_dataset.attrs[akey] = attr @classmethod - def model_validate_hdf5( - cls, filepath: pathlib.Path, types: Optional[TypeVar] = None - ): + def model_validate_hdf5(cls, filepath: pathlib.Path): """ Loads the model from an HDF5 file at the specified filepath. @@ -141,7 +140,7 @@ def model_validate_hdf5( if dkey in ("attrs", "class_"): continue group.__dict__[dkey].data = np.array(f[gkey][dkey][()]).astype( - dtype_map[group.__dict__[dkey].dtype] + DTypes.get(group.__dict__[dkey].dtype).value ) return self diff --git a/tests/test_datastore.py b/tests/test_datastore.py index 6970b07..384d3fe 100644 --- a/tests/test_datastore.py +++ b/tests/test_datastore.py @@ -18,7 +18,7 @@ import numpy as np import pytest -from oqd_dataschema.base import Dataset, mapping +from oqd_dataschema.base import Dataset, DTypes from oqd_dataschema.datastore import Datastore from oqd_dataschema.groups import ( SinaraRawDataGroup, @@ -47,7 +47,10 @@ def test_serialize_deserialize(dtype): data_reload = Datastore.model_validate_hdf5(filepath) - assert data_reload.groups["test"].camera_images.data.dtype == mapping[dtype] + assert ( + type(data_reload.groups["test"].camera_images.data.dtype) + is DTypes.get(dtype).value + ) # %% diff --git a/uv.lock b/uv.lock index 22accdc..f651324 100644 --- a/uv.lock +++ b/uv.lock @@ -166,15 +166,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/04/eb/f4151e0c7377a6e08a38108609ba5cede57986802757848688aeedd1b9e8/beautifulsoup4-4.13.5-py3-none-any.whl", hash = "sha256:642085eaa22233aceadff9c69651bc51e8bf3f874fb6d7104ece2beb24b47c4a", size = 105113 }, ] -[[package]] -name = "bidict" -version = "0.23.1" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/9a/6e/026678aa5a830e07cd9498a05d3e7e650a4f56a42f267a53d22bcda1bdc9/bidict-0.23.1.tar.gz", hash = "sha256:03069d763bc387bbd20e7d49914e75fc4132a41937fa3405417e1a5a2d006d71", size = 29093 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/99/37/e8730c3587a65eb5645d4aba2d27aae48e8003614d6aaf15dda67f702f1f/bidict-0.23.1-py3-none-any.whl", hash = "sha256:5dae8d4d79b552a71cbabc7deb25dfe8ce710b17ff41711e13010ead2abfc3e5", size = 32764 }, -] - [[package]] name = "bleach" version = "6.2.0" @@ -525,33 +516,33 @@ wheels = [ [[package]] name = "h5py" -version = "3.13.0" +version = "3.14.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "numpy" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/03/2e/a22d6a8bfa6f8be33e7febd985680fba531562795f0a9077ed1eb047bfb0/h5py-3.13.0.tar.gz", hash = "sha256:1870e46518720023da85d0895a1960ff2ce398c5671eac3b1a41ec696b7105c3", size = 414876 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/02/8a/bc76588ff1a254e939ce48f30655a8f79fac614ca8bd1eda1a79fa276671/h5py-3.13.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = 
"sha256:5540daee2b236d9569c950b417f13fd112d51d78b4c43012de05774908dff3f5", size = 3413286 }, - { url = "https://files.pythonhosted.org/packages/19/bd/9f249ecc6c517b2796330b0aab7d2351a108fdbd00d4bb847c0877b5533e/h5py-3.13.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:10894c55d46df502d82a7a4ed38f9c3fdbcb93efb42e25d275193e093071fade", size = 2915673 }, - { url = "https://files.pythonhosted.org/packages/72/71/0dd079208d7d3c3988cebc0776c2de58b4d51d8eeb6eab871330133dfee6/h5py-3.13.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fb267ce4b83f9c42560e9ff4d30f60f7ae492eacf9c7ede849edf8c1b860e16b", size = 4283822 }, - { url = "https://files.pythonhosted.org/packages/d8/fa/0b6a59a1043c53d5d287effa02303bd248905ee82b25143c7caad8b340ad/h5py-3.13.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2cf6a231a07c14acd504a945a6e9ec115e0007f675bde5e0de30a4dc8d86a31", size = 4548100 }, - { url = "https://files.pythonhosted.org/packages/12/42/ad555a7ff7836c943fe97009405566dc77bcd2a17816227c10bd067a3ee1/h5py-3.13.0-cp310-cp310-win_amd64.whl", hash = "sha256:851ae3a8563d87a5a0dc49c2e2529c75b8842582ccaefbf84297d2cfceeacd61", size = 2950547 }, - { url = "https://files.pythonhosted.org/packages/86/2b/50b15fdefb577d073b49699e6ea6a0a77a3a1016c2b67e2149fc50124a10/h5py-3.13.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:8a8e38ef4ceb969f832cc230c0cf808c613cc47e31e768fd7b1106c55afa1cb8", size = 3422922 }, - { url = "https://files.pythonhosted.org/packages/94/59/36d87a559cab9c59b59088d52e86008d27a9602ce3afc9d3b51823014bf3/h5py-3.13.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f35640e81b03c02a88b8bf99fb6a9d3023cc52f7c627694db2f379e0028f2868", size = 2921619 }, - { url = "https://files.pythonhosted.org/packages/37/ef/6f80b19682c0b0835bbee7b253bec9c16af9004f2fd6427b1dd858100273/h5py-3.13.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:337af114616f3656da0c83b68fcf53ecd9ce9989a700b0883a6e7c483c3235d4", size = 4259366 }, - { url = "https://files.pythonhosted.org/packages/03/71/c99f662d4832c8835453cf3476f95daa28372023bda4aa1fca9e97c24f09/h5py-3.13.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:782ff0ac39f455f21fd1c8ebc007328f65f43d56718a89327eec76677ebf238a", size = 4509058 }, - { url = "https://files.pythonhosted.org/packages/56/89/e3ff23e07131ff73a72a349be9639e4de84e163af89c1c218b939459a98a/h5py-3.13.0-cp311-cp311-win_amd64.whl", hash = "sha256:22ffe2a25770a2d67213a1b94f58006c14dce06933a42d2aaa0318c5868d1508", size = 2966428 }, - { url = "https://files.pythonhosted.org/packages/d8/20/438f6366ba4ded80eadb38f8927f5e2cd6d2e087179552f20ae3dbcd5d5b/h5py-3.13.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:477c58307b6b9a2509c59c57811afb9f598aedede24a67da808262dfa0ee37b4", size = 3384442 }, - { url = "https://files.pythonhosted.org/packages/10/13/cc1cb7231399617d9951233eb12fddd396ff5d4f7f057ee5d2b1ca0ee7e7/h5py-3.13.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:57c4c74f627c616f02b7aec608a8c706fe08cb5b0ba7c08555a4eb1dde20805a", size = 2917567 }, - { url = "https://files.pythonhosted.org/packages/9e/d9/aed99e1c858dc698489f916eeb7c07513bc864885d28ab3689d572ba0ea0/h5py-3.13.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:357e6dc20b101a805ccfd0024731fbaf6e8718c18c09baf3b5e4e9d198d13fca", size = 4669544 }, - { url = 
"https://files.pythonhosted.org/packages/a7/da/3c137006ff5f0433f0fb076b1ebe4a7bf7b5ee1e8811b5486af98b500dd5/h5py-3.13.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d6f13f9b5ce549448c01e4dfe08ea8d1772e6078799af2c1c8d09e941230a90d", size = 4932139 }, - { url = "https://files.pythonhosted.org/packages/25/61/d897952629cae131c19d4c41b2521e7dd6382f2d7177c87615c2e6dced1a/h5py-3.13.0-cp312-cp312-win_amd64.whl", hash = "sha256:21daf38171753899b5905f3d82c99b0b1ec2cbbe282a037cad431feb620e62ec", size = 2954179 }, - { url = "https://files.pythonhosted.org/packages/60/43/f276f27921919a9144074320ce4ca40882fc67b3cfee81c3f5c7df083e97/h5py-3.13.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:e520ec76de00943dd017c8ea3f354fa1d2f542eac994811943a8faedf2a7d5cb", size = 3358040 }, - { url = "https://files.pythonhosted.org/packages/1b/86/ad4a4cf781b08d4572be8bbdd8f108bb97b266a14835c640dc43dafc0729/h5py-3.13.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:e79d8368cd9295045956bfb436656bea3f915beaa11d342e9f79f129f5178763", size = 2892766 }, - { url = "https://files.pythonhosted.org/packages/69/84/4c6367d6b58deaf0fa84999ec819e7578eee96cea6cbd613640d0625ed5e/h5py-3.13.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:56dd172d862e850823c4af02dc4ddbc308f042b85472ffdaca67f1598dff4a57", size = 4664255 }, - { url = "https://files.pythonhosted.org/packages/fd/41/bc2df86b72965775f6d621e0ee269a5f3ac23e8f870abf519de9c7d93b4d/h5py-3.13.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:be949b46b7388074c5acae017fbbe3e5ba303fd9daaa52157fdfef30bbdacadd", size = 4927580 }, - { url = "https://files.pythonhosted.org/packages/97/34/165b87ea55184770a0c1fcdb7e017199974ad2e271451fd045cfe35f3add/h5py-3.13.0-cp313-cp313-win_amd64.whl", hash = "sha256:4f97ecde7ac6513b21cd95efdfc38dc6d19f96f6ca6f2a30550e94e551458e0a", size = 2940890 }, +sdist = { url = "https://files.pythonhosted.org/packages/5d/57/dfb3c5c3f1bf5f5ef2e59a22dec4ff1f3d7408b55bfcefcfb0ea69ef21c6/h5py-3.14.0.tar.gz", hash = "sha256:2372116b2e0d5d3e5e705b7f663f7c8d96fa79a4052d250484ef91d24d6a08f4", size = 424323 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/52/89/06cbb421e01dea2e338b3154326523c05d9698f89a01f9d9b65e1ec3fb18/h5py-3.14.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:24df6b2622f426857bda88683b16630014588a0e4155cba44e872eb011c4eaed", size = 3332522 }, + { url = "https://files.pythonhosted.org/packages/c3/e7/6c860b002329e408348735bfd0459e7b12f712c83d357abeef3ef404eaa9/h5py-3.14.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:6ff2389961ee5872de697054dd5a033b04284afc3fb52dc51d94561ece2c10c6", size = 2831051 }, + { url = "https://files.pythonhosted.org/packages/fa/cd/3dd38cdb7cc9266dc4d85f27f0261680cb62f553f1523167ad7454e32b11/h5py-3.14.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:016e89d3be4c44f8d5e115fab60548e518ecd9efe9fa5c5324505a90773e6f03", size = 4324677 }, + { url = "https://files.pythonhosted.org/packages/b1/45/e1a754dc7cd465ba35e438e28557119221ac89b20aaebef48282654e3dc7/h5py-3.14.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1223b902ef0b5d90bcc8a4778218d6d6cd0f5561861611eda59fa6c52b922f4d", size = 4557272 }, + { url = "https://files.pythonhosted.org/packages/5c/06/f9506c1531645829d302c420851b78bb717af808dde11212c113585fae42/h5py-3.14.0-cp310-cp310-win_amd64.whl", hash = "sha256:852b81f71df4bb9e27d407b43071d1da330d6a7094a588efa50ef02553fa7ce4", size = 2866734 }, + { url 
= "https://files.pythonhosted.org/packages/61/1b/ad24a8ce846cf0519695c10491e99969d9d203b9632c4fcd5004b1641c2e/h5py-3.14.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:f30dbc58f2a0efeec6c8836c97f6c94afd769023f44e2bb0ed7b17a16ec46088", size = 3352382 }, + { url = "https://files.pythonhosted.org/packages/36/5b/a066e459ca48b47cc73a5c668e9924d9619da9e3c500d9fb9c29c03858ec/h5py-3.14.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:543877d7f3d8f8a9828ed5df6a0b78ca3d8846244b9702e99ed0d53610b583a8", size = 2852492 }, + { url = "https://files.pythonhosted.org/packages/08/0c/5e6aaf221557314bc15ba0e0da92e40b24af97ab162076c8ae009320a42b/h5py-3.14.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8c497600c0496548810047257e36360ff551df8b59156d3a4181072eed47d8ad", size = 4298002 }, + { url = "https://files.pythonhosted.org/packages/21/d4/d461649cafd5137088fb7f8e78fdc6621bb0c4ff2c090a389f68e8edc136/h5py-3.14.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:723a40ee6505bd354bfd26385f2dae7bbfa87655f4e61bab175a49d72ebfc06b", size = 4516618 }, + { url = "https://files.pythonhosted.org/packages/db/0c/6c3f879a0f8e891625817637fad902da6e764e36919ed091dc77529004ac/h5py-3.14.0-cp311-cp311-win_amd64.whl", hash = "sha256:d2744b520440a996f2dae97f901caa8a953afc055db4673a993f2d87d7f38713", size = 2874888 }, + { url = "https://files.pythonhosted.org/packages/3e/77/8f651053c1843391e38a189ccf50df7e261ef8cd8bfd8baba0cbe694f7c3/h5py-3.14.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:e0045115d83272090b0717c555a31398c2c089b87d212ceba800d3dc5d952e23", size = 3312740 }, + { url = "https://files.pythonhosted.org/packages/ff/10/20436a6cf419b31124e59fefc78d74cb061ccb22213226a583928a65d715/h5py-3.14.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6da62509b7e1d71a7d110478aa25d245dd32c8d9a1daee9d2a42dba8717b047a", size = 2829207 }, + { url = "https://files.pythonhosted.org/packages/3f/19/c8bfe8543bfdd7ccfafd46d8cfd96fce53d6c33e9c7921f375530ee1d39a/h5py-3.14.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:554ef0ced3571366d4d383427c00c966c360e178b5fb5ee5bb31a435c424db0c", size = 4708455 }, + { url = "https://files.pythonhosted.org/packages/86/f9/f00de11c82c88bfc1ef22633557bfba9e271e0cb3189ad704183fc4a2644/h5py-3.14.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0cbd41f4e3761f150aa5b662df991868ca533872c95467216f2bec5fcad84882", size = 4929422 }, + { url = "https://files.pythonhosted.org/packages/7a/6d/6426d5d456f593c94b96fa942a9b3988ce4d65ebaf57d7273e452a7222e8/h5py-3.14.0-cp312-cp312-win_amd64.whl", hash = "sha256:bf4897d67e613ecf5bdfbdab39a1158a64df105827da70ea1d90243d796d367f", size = 2862845 }, + { url = "https://files.pythonhosted.org/packages/6c/c2/7efe82d09ca10afd77cd7c286e42342d520c049a8c43650194928bcc635c/h5py-3.14.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:aa4b7bbce683379b7bf80aaba68e17e23396100336a8d500206520052be2f812", size = 3289245 }, + { url = "https://files.pythonhosted.org/packages/4f/31/f570fab1239b0d9441024b92b6ad03bb414ffa69101a985e4c83d37608bd/h5py-3.14.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:ef9603a501a04fcd0ba28dd8f0995303d26a77a980a1f9474b3417543d4c6174", size = 2807335 }, + { url = "https://files.pythonhosted.org/packages/0d/ce/3a21d87896bc7e3e9255e0ad5583ae31ae9e6b4b00e0bcb2a67e2b6acdbc/h5py-3.14.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e8cbaf6910fa3983c46172666b0b8da7b7bd90d764399ca983236f2400436eeb", 
size = 4700675 }, + { url = "https://files.pythonhosted.org/packages/e7/ec/86f59025306dcc6deee5fda54d980d077075b8d9889aac80f158bd585f1b/h5py-3.14.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d90e6445ab7c146d7f7981b11895d70bc1dd91278a4f9f9028bc0c95e4a53f13", size = 4921632 }, + { url = "https://files.pythonhosted.org/packages/3f/6d/0084ed0b78d4fd3e7530c32491f2884140d9b06365dac8a08de726421d4a/h5py-3.14.0-cp313-cp313-win_amd64.whl", hash = "sha256:ae18e3de237a7a830adb76aaa68ad438d85fe6e19e0d99944a3ce46b772c69b3", size = 2852929 }, ] [[package]] @@ -1436,7 +1427,6 @@ name = "oqd-dataschema" version = "0.1.0" source = { editable = "." } dependencies = [ - { name = "bidict" }, { name = "h5py" }, { name = "pydantic" }, ] @@ -1463,8 +1453,7 @@ dev = [ [package.metadata] requires-dist = [ - { name = "bidict", specifier = ">=0.23.1" }, - { name = "h5py", specifier = ">=3.13.0" }, + { name = "h5py", specifier = ">=3.14.0" }, { name = "mdx-truly-sane-lists", marker = "extra == 'docs'" }, { name = "mkdocs-material", marker = "extra == 'docs'" }, { name = "mkdocstrings", marker = "extra == 'docs'" }, From f5a92760927a2f2a5dac2dc42b58fd7ec5d41110 Mon Sep 17 00:00:00 2001 From: yhteoh Date: Wed, 24 Sep 2025 14:02:33 -0400 Subject: [PATCH 07/48] [feat] Implemented CastDataset that cast a numpy array automatically to a Dataset --- src/oqd_dataschema/__init__.py | 1 + src/oqd_dataschema/base.py | 12 +++++++++++- src/oqd_dataschema/groups.py | 12 ++++++------ 3 files changed, 18 insertions(+), 7 deletions(-) diff --git a/src/oqd_dataschema/__init__.py b/src/oqd_dataschema/__init__.py index 3238ec0..934d47d 100644 --- a/src/oqd_dataschema/__init__.py +++ b/src/oqd_dataschema/__init__.py @@ -33,4 +33,5 @@ "OQDTestbenchDataGroup", "SinaraRawDataGroup", "condataset", + "CastDataset", ] diff --git a/src/oqd_dataschema/base.py b/src/oqd_dataschema/base.py index 59c3ae8..e67a61c 100644 --- a/src/oqd_dataschema/base.py +++ b/src/oqd_dataschema/base.py @@ -33,7 +33,7 @@ ######################################################################################## -__all__ = ["GroupBase", "Dataset", "GroupRegistry", "condataset"] +__all__ = ["GroupBase", "Dataset", "GroupRegistry", "condataset", "CastDataset"] ######################################################################################## @@ -117,6 +117,12 @@ class Dataset(BaseModel, extra="forbid"): use_enum_values=False, arbitrary_types_allowed=True, validate_assignment=True ) + @classmethod + def cast(cls, data): + if isinstance(data, np.ndarray): + return cls(data=data) + return data + @model_validator(mode="before") @classmethod def validate_and_update(cls, values: dict): @@ -212,12 +218,16 @@ def condataset( ): return Annotated[ Dataset, + BeforeValidator(Dataset.cast), AfterValidator(partial(_constrain_dtype, dtype_constraint=dtype_constraint)), AfterValidator(partial(_constraint_dim, min_dim=min_dim, max_dim=max_dim)), AfterValidator(partial(_constraint_shape, shape_constraint=shape_constraint)), ] +CastDataset = Annotated[Dataset, BeforeValidator(Dataset.cast)] + + ######################################################################################## diff --git a/src/oqd_dataschema/groups.py b/src/oqd_dataschema/groups.py index 88ecd2f..b244f62 100644 --- a/src/oqd_dataschema/groups.py +++ b/src/oqd_dataschema/groups.py @@ -13,7 +13,7 @@ # limitations under the License. 
-from oqd_dataschema.base import Dataset, GroupBase +from oqd_dataschema.base import CastDataset, GroupBase ######################################################################################## @@ -33,7 +33,7 @@ class SinaraRawDataGroup(GroupBase): This is a placeholder for demonstration and development. """ - camera_images: Dataset + camera_images: CastDataset class MeasurementOutcomesDataGroup(GroupBase): @@ -42,7 +42,7 @@ class MeasurementOutcomesDataGroup(GroupBase): This is a placeholder for demonstration and development. """ - outcomes: Dataset + outcomes: CastDataset class ExpectationValueDataGroup(GroupBase): @@ -51,11 +51,11 @@ class ExpectationValueDataGroup(GroupBase): This is a placeholder for demonstration and development. """ - expectation_value: Dataset + expectation_value: CastDataset class OQDTestbenchDataGroup(GroupBase): """ """ - time: Dataset - voltages: Dataset + time: CastDataset + voltages: CastDataset From ca8920cee6152d25e074b33383a79cc35c54ba6d Mon Sep 17 00:00:00 2001 From: yhteoh Date: Wed, 24 Sep 2025 14:03:42 -0400 Subject: [PATCH 08/48] [rename] renamed protected attrs in hdf5 --- src/oqd_dataschema/base.py | 2 +- src/oqd_dataschema/datastore.py | 8 +++++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/oqd_dataschema/base.py b/src/oqd_dataschema/base.py index e67a61c..d18b689 100644 --- a/src/oqd_dataschema/base.py +++ b/src/oqd_dataschema/base.py @@ -66,7 +66,7 @@ def names(cls): ######################################################################################## -invalid_attrs = ["_model_signature", "_model_json"] +invalid_attrs = ["_datastore_signature", "_group_json"] def _valid_attr_key(value): diff --git a/src/oqd_dataschema/datastore.py b/src/oqd_dataschema/datastore.py index 6638223..202736e 100644 --- a/src/oqd_dataschema/datastore.py +++ b/src/oqd_dataschema/datastore.py @@ -95,7 +95,7 @@ def model_dump_hdf5(self, filepath: pathlib.Path, mode: Literal["w", "a"] = "w") with h5py.File(filepath, mode) as f: # store the model JSON schema - f.attrs["_model_signature"] = self.model_dump_json() + f.attrs["_datastore_signature"] = self.model_dump_json(indent=2) for akey, attr in self.attrs.items(): f.attrs[akey] = attr @@ -105,7 +105,9 @@ def model_dump_hdf5(self, filepath: pathlib.Path, mode: Literal["w", "a"] = "w") del f[gkey] h5_group = f.create_group(gkey) - h5_group.attrs["_model_schema"] = json.dumps(group.model_json_schema()) + h5_group.attrs["_group_schema"] = json.dumps( + group.model_json_schema(), indent=2 + ) for akey, attr in group.attrs.items(): h5_group.attrs[akey] = attr @@ -132,7 +134,7 @@ def model_validate_hdf5(cls, filepath: pathlib.Path): filepath (pathlib.Path): The path to the HDF5 file where the model data will be read and validated from. 
""" with h5py.File(filepath, "r") as f: - self = cls.model_validate_json(f.attrs["_model_signature"]) + self = cls.model_validate_json(f.attrs["_datastore_signature"]) # loop through all groups in the model schema and load HDF5 store for gkey, group in self: From a5c7ad36a04fb149c599f63e14c3a97c8f0abc8a Mon Sep 17 00:00:00 2001 From: yhteoh Date: Wed, 24 Sep 2025 14:43:01 -0400 Subject: [PATCH 09/48] [feat] Allow dictionary of datasets as a field in a Group --- examples/custom_group.ipynb | 145 ++++++++++++++++++++++++++++++++ src/oqd_dataschema/base.py | 16 +++- src/oqd_dataschema/datastore.py | 72 ++++++++++------ 3 files changed, 205 insertions(+), 28 deletions(-) diff --git a/examples/custom_group.ipynb b/examples/custom_group.ipynb index d29cce2..c632665 100644 --- a/examples/custom_group.ipynb +++ b/examples/custom_group.ipynb @@ -291,6 +291,151 @@ "parse = Datastore.model_validate_hdf5(filepath)\n", "pprint(parse)" ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "from typing import Dict\n", + "\n", + "from oqd_dataschema.base import CastDataset\n", + "\n", + "\n", + "class A(GroupBase):\n", + " data: Dict[str, CastDataset]" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
Datastore(\n",
+       "groups={\n",
+       "│   │   'A': A(\n",
+       "│   │   │   attrs={},\n",
+       "│   │   │   data={\n",
+       "│   │   │   │   'x': Dataset(\n",
+       "│   │   │   │   │   dtype='float64',\n",
+       "│   │   │   │   │   shape=(10,),\n",
+       "│   │   │   │   │   data=array([0.90326782, 0.17363226, 0.13827196, 0.8917397 , 0.68175954,\n",
+       "0.47647195, 0.88443397, 0.75703312, 0.74991232, 0.68161151]),\n",
+       "│   │   │   │   │   attrs={'type': 'mytype'}\n",
+       "│   │   │   │   )\n",
+       "│   │   │   },\n",
+       "│   │   │   class_='A'\n",
+       "│   │   )\n",
+       "},\n",
+       "attrs={}\n",
+       ")\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;35mDatastore\u001b[0m\u001b[1m(\u001b[0m\n", + "\u001b[2;32m│ \u001b[0m\u001b[33mgroups\u001b[0m=\u001b[1m{\u001b[0m\n", + "\u001b[2;32m│ │ \u001b[0m\u001b[32m'A'\u001b[0m: \u001b[1;35mA\u001b[0m\u001b[1m(\u001b[0m\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[33mattrs\u001b[0m=\u001b[1m{\u001b[0m\u001b[1m}\u001b[0m,\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[33mdata\u001b[0m=\u001b[1m{\u001b[0m\n", + "\u001b[2;32m│ │ │ │ \u001b[0m\u001b[32m'x'\u001b[0m: \u001b[1;35mDataset\u001b[0m\u001b[1m(\u001b[0m\n", + "\u001b[2;32m│ │ │ │ │ \u001b[0m\u001b[33mdtype\u001b[0m=\u001b[32m'float64'\u001b[0m,\n", + "\u001b[2;32m│ │ │ │ │ \u001b[0m\u001b[33mshape\u001b[0m=\u001b[1m(\u001b[0m\u001b[1;36m10\u001b[0m,\u001b[1m)\u001b[0m,\n", + "\u001b[2;32m│ │ │ │ │ \u001b[0m\u001b[33mdata\u001b[0m=\u001b[1;35marray\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m0.90326782\u001b[0m, \u001b[1;36m0.17363226\u001b[0m, \u001b[1;36m0.13827196\u001b[0m, \u001b[1;36m0.8917397\u001b[0m , \u001b[1;36m0.68175954\u001b[0m,\n", + "\u001b[2;32m│ \u001b[0m\u001b[1;36m0.47647195\u001b[0m, \u001b[1;36m0.88443397\u001b[0m, \u001b[1;36m0.75703312\u001b[0m, \u001b[1;36m0.74991232\u001b[0m, \u001b[1;36m0.68161151\u001b[0m\u001b[1m]\u001b[0m\u001b[1m)\u001b[0m,\n", + "\u001b[2;32m│ │ │ │ │ \u001b[0m\u001b[33mattrs\u001b[0m=\u001b[1m{\u001b[0m\u001b[32m'type'\u001b[0m: \u001b[32m'mytype'\u001b[0m\u001b[1m}\u001b[0m\n", + "\u001b[2;32m│ │ │ │ \u001b[0m\u001b[1m)\u001b[0m\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[1m}\u001b[0m,\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[33mclass_\u001b[0m=\u001b[32m'A'\u001b[0m\n", + "\u001b[2;32m│ │ \u001b[0m\u001b[1m)\u001b[0m\n", + "\u001b[2;32m│ \u001b[0m\u001b[1m}\u001b[0m,\n", + "\u001b[2;32m│ \u001b[0m\u001b[33mattrs\u001b[0m=\u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n", + "\u001b[1m)\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "filepath = pathlib.Path(\"test.h5\")\n", + "\n", + "datastore = Datastore(\n", + " groups={\n", + " \"A\": A(data={\"x\": Dataset(data=np.random.rand(10), attrs={\"type\": \"mytype\"})})\n", + " }\n", + ")\n", + "pprint(datastore)\n", + "datastore.model_dump_hdf5(filepath)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
Datastore(\n",
+       "groups={\n",
+       "│   │   'A': A(\n",
+       "│   │   │   attrs={},\n",
+       "│   │   │   data={\n",
+       "│   │   │   │   'x': Dataset(\n",
+       "│   │   │   │   │   dtype='float64',\n",
+       "│   │   │   │   │   shape=(10,),\n",
+       "│   │   │   │   │   data=array([0.90326782, 0.17363226, 0.13827196, 0.8917397 , 0.68175954,\n",
+       "0.47647195, 0.88443397, 0.75703312, 0.74991232, 0.68161151]),\n",
+       "│   │   │   │   │   attrs={'type': 'mytype'}\n",
+       "│   │   │   │   )\n",
+       "│   │   │   },\n",
+       "│   │   │   class_='A'\n",
+       "│   │   )\n",
+       "},\n",
+       "attrs={}\n",
+       ")\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;35mDatastore\u001b[0m\u001b[1m(\u001b[0m\n", + "\u001b[2;32m│ \u001b[0m\u001b[33mgroups\u001b[0m=\u001b[1m{\u001b[0m\n", + "\u001b[2;32m│ │ \u001b[0m\u001b[32m'A'\u001b[0m: \u001b[1;35mA\u001b[0m\u001b[1m(\u001b[0m\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[33mattrs\u001b[0m=\u001b[1m{\u001b[0m\u001b[1m}\u001b[0m,\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[33mdata\u001b[0m=\u001b[1m{\u001b[0m\n", + "\u001b[2;32m│ │ │ │ \u001b[0m\u001b[32m'x'\u001b[0m: \u001b[1;35mDataset\u001b[0m\u001b[1m(\u001b[0m\n", + "\u001b[2;32m│ │ │ │ │ \u001b[0m\u001b[33mdtype\u001b[0m=\u001b[32m'float64'\u001b[0m,\n", + "\u001b[2;32m│ │ │ │ │ \u001b[0m\u001b[33mshape\u001b[0m=\u001b[1m(\u001b[0m\u001b[1;36m10\u001b[0m,\u001b[1m)\u001b[0m,\n", + "\u001b[2;32m│ │ │ │ │ \u001b[0m\u001b[33mdata\u001b[0m=\u001b[1;35marray\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1;36m0.90326782\u001b[0m, \u001b[1;36m0.17363226\u001b[0m, \u001b[1;36m0.13827196\u001b[0m, \u001b[1;36m0.8917397\u001b[0m , \u001b[1;36m0.68175954\u001b[0m,\n", + "\u001b[2;32m│ \u001b[0m\u001b[1;36m0.47647195\u001b[0m, \u001b[1;36m0.88443397\u001b[0m, \u001b[1;36m0.75703312\u001b[0m, \u001b[1;36m0.74991232\u001b[0m, \u001b[1;36m0.68161151\u001b[0m\u001b[1m]\u001b[0m\u001b[1m)\u001b[0m,\n", + "\u001b[2;32m│ │ │ │ │ \u001b[0m\u001b[33mattrs\u001b[0m=\u001b[1m{\u001b[0m\u001b[32m'type'\u001b[0m: \u001b[32m'mytype'\u001b[0m\u001b[1m}\u001b[0m\n", + "\u001b[2;32m│ │ │ │ \u001b[0m\u001b[1m)\u001b[0m\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[1m}\u001b[0m,\n", + "\u001b[2;32m│ │ │ \u001b[0m\u001b[33mclass_\u001b[0m=\u001b[32m'A'\u001b[0m\n", + "\u001b[2;32m│ │ \u001b[0m\u001b[1m)\u001b[0m\n", + "\u001b[2;32m│ \u001b[0m\u001b[1m}\u001b[0m,\n", + "\u001b[2;32m│ \u001b[0m\u001b[33mattrs\u001b[0m=\u001b[1m{\u001b[0m\u001b[1m}\u001b[0m\n", + "\u001b[1m)\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "parse = Datastore.model_validate_hdf5(filepath)\n", + "pprint(parse)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/src/oqd_dataschema/base.py b/src/oqd_dataschema/base.py index d18b689..63e32a0 100644 --- a/src/oqd_dataschema/base.py +++ b/src/oqd_dataschema/base.py @@ -165,6 +165,12 @@ def validate_data_matches_shape_dtype(self): def __getitem__(self, idx): return self.data[idx] + @classmethod + def _is_dataset_type(cls, type_): + return type_ == cls or ( + typing.get_origin(type_) is Annotated and type_.__origin__ is cls + ) + def _constrain_dtype(dataset, *, dtype_constraint=None): if (not isinstance(dtype_constraint, str)) and isinstance( @@ -261,8 +267,14 @@ def __init_subclass__(cls, **kwargs): if ( k not in ["class_", "attrs"] - and v not in [Dataset, ClassVar] - and not (typing.get_origin(v) == Annotated and v.__origin__ is Dataset) + and v is not ClassVar + and not Dataset._is_dataset_type(v) + and not (typing.get_origin(v) is Annotated and v.__origin__ is Dataset) + and not ( + typing.get_origin(v) is dict + and v.__args__[0] is str + and Dataset._is_dataset_type(v.__args__[1]) + ) ): raise TypeError( "All fields of `GroupBase` have to be of type `Dataset`." 
diff --git a/src/oqd_dataschema/datastore.py b/src/oqd_dataschema/datastore.py index 202736e..78670e8 100644 --- a/src/oqd_dataschema/datastore.py +++ b/src/oqd_dataschema/datastore.py @@ -82,6 +82,37 @@ def validate_groups(cls, data): return data + def _dump_group(self, h5datastore, gkey, group): + if gkey in h5datastore.keys(): + del h5datastore[gkey] + h5_group = h5datastore.create_group(gkey) + + h5_group.attrs["_group_schema"] = json.dumps( + group.model_json_schema(), indent=2 + ) + for akey, attr in group.attrs.items(): + h5_group.attrs[akey] = attr + + for dkey, dataset in group.__dict__.items(): + if isinstance(dataset, dict): + h5_subgroup = h5_group.create_group(dkey) + for ddkey, ddataset in dataset.items(): + self._dump_dataset(h5_subgroup, ddkey, ddataset) + + self._dump_dataset(h5_group, dkey, dataset) + + def _dump_dataset(self, h5group, dkey, dataset): + if isinstance(dataset, Dataset): + if dataset.dtype in "str": + h5_dataset = h5group.create_dataset( + dkey, data=dataset.data.astype(np.dtypes.BytesDType) + ) + else: + h5_dataset = h5group.create_dataset(dkey, data=dataset.data) + + for akey, attr in dataset.attrs.items(): + h5_dataset.attrs[akey] = attr + def model_dump_hdf5(self, filepath: pathlib.Path, mode: Literal["w", "a"] = "w"): """ Saves the model and its associated data to an HDF5 file. @@ -101,29 +132,7 @@ def model_dump_hdf5(self, filepath: pathlib.Path, mode: Literal["w", "a"] = "w") # store each group for gkey, group in self.groups.items(): - if gkey in f.keys(): - del f[gkey] - h5_group = f.create_group(gkey) - - h5_group.attrs["_group_schema"] = json.dumps( - group.model_json_schema(), indent=2 - ) - for akey, attr in group.attrs.items(): - h5_group.attrs[akey] = attr - - for dkey, dataset in group.__dict__.items(): - if not isinstance(dataset, Dataset): - continue - - if dataset.dtype in "str": - h5_dataset = h5_group.create_dataset( - dkey, data=dataset.data.astype(np.dtypes.BytesDType) - ) - else: - h5_dataset = h5_group.create_dataset(dkey, data=dataset.data) - - for akey, attr in dataset.attrs.items(): - h5_dataset.attrs[akey] = attr + self._dump_group(f, gkey, group) @classmethod def model_validate_hdf5(cls, filepath: pathlib.Path): @@ -141,9 +150,20 @@ def model_validate_hdf5(cls, filepath: pathlib.Path): for dkey in group.__class__.model_fields: if dkey in ("attrs", "class_"): continue - group.__dict__[dkey].data = np.array(f[gkey][dkey][()]).astype( - DTypes.get(group.__dict__[dkey].dtype).value - ) + + if isinstance(group.__dict__[dkey], Dataset): + group.__dict__[dkey].data = np.array(f[gkey][dkey][()]).astype( + DTypes.get(group.__dict__[dkey].dtype).value + ) + + if isinstance(group.__dict__[dkey], dict): + for ddkey in group.__dict__[dkey]: + group.__dict__[dkey][ddkey].data = np.array( + f[gkey][dkey][ddkey][()] + ).astype( + DTypes.get(group.__dict__[dkey][ddkey].dtype).value + ) + return self def __getitem__(self, key): From f422b438749bc333143a78fd645b021d27edf7f6 Mon Sep 17 00:00:00 2001 From: yhteoh Date: Thu, 25 Sep 2025 19:24:52 -0400 Subject: [PATCH 10/48] [feat] Implemented flex shape support and refactored validators --- src/oqd_dataschema/__init__.py | 3 +- src/oqd_dataschema/base.py | 175 ++++++++++++++++++--------------- src/oqd_dataschema/utils.py | 43 ++++++++ 3 files changed, 143 insertions(+), 78 deletions(-) create mode 100644 src/oqd_dataschema/utils.py diff --git a/src/oqd_dataschema/__init__.py b/src/oqd_dataschema/__init__.py index 934d47d..74c42db 100644 --- a/src/oqd_dataschema/__init__.py +++ 
b/src/oqd_dataschema/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .base import Dataset, GroupBase, GroupRegistry, condataset +from .base import CastDataset, Dataset, GroupBase, GroupRegistry, condataset from .datastore import Datastore from .groups import ( ExpectationValueDataGroup, @@ -24,6 +24,7 @@ ######################################################################################## __all__ = [ + "CastDataset", "Dataset", "Datastore", "GroupBase", diff --git a/src/oqd_dataschema/base.py b/src/oqd_dataschema/base.py index 63e32a0..7274a13 100644 --- a/src/oqd_dataschema/base.py +++ b/src/oqd_dataschema/base.py @@ -16,7 +16,7 @@ import typing import warnings from enum import Enum -from functools import partial, reduce +from types import NoneType from typing import Annotated, Any, ClassVar, Literal, Optional, Sequence, Tuple, Union import numpy as np @@ -28,9 +28,12 @@ Discriminator, Field, TypeAdapter, + field_validator, model_validator, ) +from .utils import _flex_shape_equal, _validator_from_condition + ######################################################################################## __all__ = ["GroupBase", "Dataset", "GroupRegistry", "condataset", "CastDataset"] @@ -108,7 +111,7 @@ class Dataset(BaseModel, extra="forbid"): """ dtype: Optional[Literal[DTypes.names()]] = None # type: ignore - shape: Optional[Tuple[int, ...]] = None + shape: Optional[Tuple[Union[int, None], ...]] = None data: Optional[Any] = Field(default=None, exclude=True) attrs: Attrs = {} @@ -117,51 +120,58 @@ class Dataset(BaseModel, extra="forbid"): use_enum_values=False, arbitrary_types_allowed=True, validate_assignment=True ) + @field_validator("data", mode="before") @classmethod - def cast(cls, data): - if isinstance(data, np.ndarray): - return cls(data=data) - return data + def validate_and_update(cls, value): + # check if data exist + if value is None: + return value - @model_validator(mode="before") - @classmethod - def validate_and_update(cls, values: dict): - data = values.get("data") - dtype = values.get("dtype") - shape = values.get("shape") + # check if data is a numpy array + if not isinstance(value, np.ndarray): + raise TypeError("`data` must be a numpy.ndarray.") - if data is None and (dtype is not None and shape is not None): - return values + return value - elif data is not None and (dtype is None and shape is None): - if not isinstance(data, np.ndarray): - raise TypeError("`data` must be a numpy.ndarray.") + @model_validator(mode="after") + def validate_data_matches_shape_dtype(self): + """Ensure that `data` matches `dtype` and `shape`.""" - if type(data.dtype) not in DTypes: - raise TypeError( - f"`data` must be a numpy array of dtype in {tuple(DTypes.names())}." - ) + # check if data exist + if self.data is None: + return self + + # check if dtype matches data + if ( + self.dtype is not None + and type(self.data.dtype) is not DTypes.get(self.dtype).value + ): + raise ValueError( + f"Expected data dtype `{self.dtype}`, but got `{self.data.dtype.name}`." 
+ ) - values["dtype"] = DTypes(type(data.dtype)).name.lower() - values["shape"] = data.shape + # check if shape mataches data + if self.shape is not None and not _flex_shape_equal( + self.data.shape, self.shape + ): + raise ValueError(f"Expected shape {self.shape}, but got {self.data.shape}.") - return values + # reassign dtype if it is None + if self.dtype != DTypes(type(self.data.dtype)).name.lower(): + self.dtype = DTypes(type(self.data.dtype)).name.lower() + + # resassign shape to concrete value if it is None or a flexible shape + if self.shape != self.data.shape: + self.shape = self.data.shape - @model_validator(mode="after") - def validate_data_matches_shape_dtype(self): - """Ensure that `data` matches `dtype` and `shape`.""" - if self.data is not None: - expected_dtype = DTypes.get(self.dtype).value - if type(self.data.dtype) is not expected_dtype: - raise ValueError( - f"Expected data dtype `{self.dtype}`, but got `{self.data.dtype.name}`." - ) - if self.data.shape != self.shape: - raise ValueError( - f"Expected shape {self.shape}, but got {self.data.shape}." - ) return self + @classmethod + def cast(cls, data): + if isinstance(data, np.ndarray): + return cls(data=data) + return data + def __getitem__(self, idx): return self.data[idx] @@ -172,6 +182,12 @@ def _is_dataset_type(cls, type_): ) +CastDataset = Annotated[Dataset, BeforeValidator(Dataset.cast)] + +######################################################################################## + + +@_validator_from_condition def _constrain_dtype(dataset, *, dtype_constraint=None): if (not isinstance(dtype_constraint, str)) and isinstance( dtype_constraint, Sequence @@ -185,10 +201,12 @@ def _constrain_dtype(dataset, *, dtype_constraint=None): f"Expected dtype to be of type one of {dtype_constraint}, but got {dataset.dtype}." ) - return dataset - +@_validator_from_condition def _constraint_dim(dataset, *, min_dim=None, max_dim=None): + if min_dim is not None and max_dim is not None and min_dim > max_dim: + raise ValueError("Impossible to satisfy dimension constraints on dataset.") + min_dim = 0 if min_dim is None else min_dim dims = len(dataset.shape) @@ -198,42 +216,26 @@ def _constraint_dim(dataset, *, min_dim=None, max_dim=None): f"Expected {min_dim} <= dimension of shape{f' <= {max_dim}'}, but got shape = {dataset.shape}." ) - return dataset - +@_validator_from_condition def _constraint_shape(dataset, *, shape_constraint=None): - if shape_constraint and ( - len(shape_constraint) != len(dataset.shape) - or reduce( - lambda x, y: x or y, - map( - lambda x: x[0] is not None and x[0] != x[1], - zip(shape_constraint, dataset.shape), - ), - ) - ): + if shape_constraint and not _flex_shape_equal(shape_constraint, dataset.shape): raise ValueError( f"Expected shape to be {shape_constraint}, but got {dataset.shape}." 
) - return dataset - def condataset( *, shape_constraint=None, dtype_constraint=None, min_dim=None, max_dim=None ): return Annotated[ - Dataset, - BeforeValidator(Dataset.cast), - AfterValidator(partial(_constrain_dtype, dtype_constraint=dtype_constraint)), - AfterValidator(partial(_constraint_dim, min_dim=min_dim, max_dim=max_dim)), - AfterValidator(partial(_constraint_shape, shape_constraint=shape_constraint)), + CastDataset, + AfterValidator(_constrain_dtype(dtype_constraint=dtype_constraint)), + AfterValidator(_constraint_dim(min_dim=min_dim, max_dim=max_dim)), + AfterValidator(_constraint_shape(shape_constraint=shape_constraint)), ] -CastDataset = Annotated[Dataset, BeforeValidator(Dataset.cast)] - - ######################################################################################## @@ -255,27 +257,46 @@ class GroupBase(BaseModel, extra="forbid"): attrs: Attrs = {} + @classmethod + def _is_allowed_field_type(cls, v): + is_dataset = Dataset._is_dataset_type(v) + + is_annotated_dataset = typing.get_origin( + v + ) is Annotated and Dataset._is_dataset_type(v.__origin__) + + is_optional_dataset = typing.get_origin(v) is Union and ( + (v.__args__[0] == NoneType and Dataset._is_dataset_type(v.__args__[1])) + or (v.__args__[1] == NoneType and Dataset._is_dataset_type(v.__args__[0])) + ) + + is_dict_dataset = ( + typing.get_origin(v) is dict + and v.__args__[0] is str + and Dataset._is_dataset_type(v.__args__[1]) + ) + + return ( + is_dataset or is_annotated_dataset or is_optional_dataset or is_dict_dataset + ) + + @classmethod + def _is_classvar(cls, v): + return v is ClassVar or typing.get_origin(v) is ClassVar + def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) for k, v in cls.__annotations__.items(): - if k == "class_": - raise AttributeError("`class_` attribute should not be set manually.") - - if k == "attrs" and k is not Attrs: - raise TypeError("`attrs` should be of type `Attrs`") - - if ( - k not in ["class_", "attrs"] - and v is not ClassVar - and not Dataset._is_dataset_type(v) - and not (typing.get_origin(v) is Annotated and v.__origin__ is Dataset) - and not ( - typing.get_origin(v) is dict - and v.__args__[0] is str - and Dataset._is_dataset_type(v.__args__[1]) + if k in ["class_", "attrs"]: + raise AttributeError( + "`class_` and `attrs` attribute should not be set manually." ) - ): + + if cls._is_classvar(v): + continue + + if not cls._is_allowed_field_type(v): raise TypeError( "All fields of `GroupBase` have to be of type `Dataset`." ) diff --git a/src/oqd_dataschema/utils.py b/src/oqd_dataschema/utils.py new file mode 100644 index 0000000..90832c5 --- /dev/null +++ b/src/oqd_dataschema/utils.py @@ -0,0 +1,43 @@ +# Copyright 2024-2025 Open Quantum Design + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
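+
+# NOTE: a minimal sketch of the flexible-shape semantics implemented below,
+# where a `None` entry acts as a wildcard dimension (values illustrative):
+#
+#     _flex_shape_equal((None, 3), (10, 3))  # True: `None` matches any size
+#     _flex_shape_equal((None, 3), (10, 4))  # False: 3 != 4
+#     _flex_shape_equal((None,), (10, 3))    # False: number of dimensions differs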
+
+from functools import reduce
+
+########################################################################################
+
+__all__ = ["_flex_shape_equal", "_validator_from_condition"]
+
+
+########################################################################################
+
+
+def _flex_shape_equal(shape1, shape2):
+    return len(shape1) == len(shape2) and reduce(
+        lambda x, y: x and y,
+        map(
+            lambda x: x[0] is None or x[1] is None or x[0] == x[1],
+            zip(shape1, shape2),
+        ),
+    )
+
+
+def _validator_from_condition(f):
+    def _wrapped_validator(*args, **kwargs):
+        def _wrapped_condition(model):
+            f(model, *args, **kwargs)
+            return model
+
+        return _wrapped_condition
+
+    return _wrapped_validator

From 92a57198aec22180be5d7799d0a47e3687557b09 Mon Sep 17 00:00:00 2001
From: yhteoh
Date: Fri, 26 Sep 2025 14:52:21 -0400
Subject: [PATCH 11/48] [test] Added more tests of datasets, datastores and groups

---
 tests/test_dataset.py       | 249 ++++++++++++++++++++++++++++++++++++
 tests/test_datastore.py     | 109 ++++++++++++----
 tests/test_group.py         | 170 ++++++++++++++++++++++++
 tests/test_groupregistry.py |  77 +++++++++++
 tests/test_typeadapt.py     |  51 --------
 5 files changed, 577 insertions(+), 79 deletions(-)
 create mode 100644 tests/test_dataset.py
 create mode 100644 tests/test_group.py
 create mode 100644 tests/test_groupregistry.py
 delete mode 100644 tests/test_typeadapt.py

diff --git a/tests/test_dataset.py b/tests/test_dataset.py
new file mode 100644
index 0000000..a1a09a0
--- /dev/null
+++ b/tests/test_dataset.py
@@ -0,0 +1,249 @@
+# Copyright 2024-2025 Open Quantum Design
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +# %% + +import numpy as np +import pytest +from pydantic import TypeAdapter + +from oqd_dataschema.base import CastDataset, Dataset, DTypes, condataset + +######################################################################################## + + +class TestDatasetDtype: + @pytest.mark.parametrize( + ("dtype", "np_dtype"), + [ + ("bool", np.dtypes.BoolDType), + ("int16", np.dtypes.Int16DType), + ("int32", np.dtypes.Int32DType), + ("int64", np.dtypes.Int64DType), + ("uint16", np.dtypes.UInt16DType), + ("uint32", np.dtypes.UInt32DType), + ("uint64", np.dtypes.UInt64DType), + ("float16", np.dtypes.Float16DType), + ("float32", np.dtypes.Float32DType), + ("float64", np.dtypes.Float64DType), + ("complex64", np.dtypes.Complex64DType), + ("complex128", np.dtypes.Complex128DType), + ("str", np.dtypes.StrDType), + ("bytes", np.dtypes.BytesDType), + ("string", np.dtypes.StringDType), + ], + ) + def test_dtypes(self, dtype, np_dtype): + ds = Dataset(dtype=dtype, shape=(100,)) + + data = np.random.rand(100).astype(np_dtype) + ds.data = data + + @pytest.mark.xfail(raises=ValueError) + @pytest.mark.parametrize("dtype", list(DTypes.names())) + def test_unmatched_dtype_data(self, dtype): + ds = Dataset(dtype=dtype, shape=(100,)) + + data = np.random.rand(100).astype("O") + ds.data = data + + @pytest.mark.parametrize("dtype", list(DTypes.names())) + def test_flexible_dtype(self, dtype): + ds = Dataset(dtype=None, shape=(100,)) + + data = np.random.rand(100).astype(DTypes.get(dtype).value) + ds.data = data + + assert ds.dtype == DTypes(type(ds.data.dtype)).name.lower() + + def test_dtype_mutation(self): + ds = Dataset(dtype="float32", shape=(100,)) + + ds.dtype = "float64" + + data = np.random.rand(100) + ds.data = data + + +class TestDatasetShape: + @pytest.mark.xfail(raises=ValueError) + @pytest.mark.parametrize( + "shape", + [ + (0,), + (1,), + (99,), + (1, 1), + ], + ) + def test_unmatched_shape_data(self, shape): + ds = Dataset(dtype="float64", shape=(100,)) + + data = np.random.rand(*shape) + ds.data = data + + @pytest.mark.parametrize( + ("shape", "data_shape"), + [ + ((None,), (0,)), + ((None,), (1,)), + ((None,), (100,)), + ((None, 0), (0, 0)), + ((None, 1), (1, 1)), + ((None, None), (1, 1)), + ((None, None), (10, 100)), + ((None, None, 1), (1, 1, 1)), + ], + ) + def test_flexible_shape(self, shape, data_shape): + ds = Dataset(dtype="float64", shape=shape) + + data = np.random.rand(*data_shape) + ds.data = data + + assert ds.shape == ds.data.shape + + def test_shape_mutation(self): + ds = Dataset(dtype="float64", shape=(1,)) + + ds.shape = (100,) + + data = np.random.rand(100) + ds.data = data + + +class TestCastDataset: + @pytest.fixture + def adapter(self): + return TypeAdapter(CastDataset) + + @pytest.mark.parametrize( + ("data", "dtype", "shape"), + [ + (np.random.rand(100), "float64", (100,)), + (np.random.rand(10).astype("str"), "str", (10,)), + (np.random.rand(1, 10, 100).astype("bytes"), "bytes", (1, 10, 100)), + ], + ) + def test_cast(self, adapter, data, shape, dtype): + ds = adapter.validate_python(data) + + assert ds.shape == shape and ds.dtype == dtype + + +class TestConstrainedDataset: + @pytest.mark.parametrize( + ("cds", "data"), + [ + (condataset(dtype_constraint="float64"), np.random.rand(10)), + (condataset(dtype_constraint="str"), np.random.rand(10).astype(str)), + ( + condataset(dtype_constraint=("float16", "float32", "float64")), + np.random.rand(10), + ), + ( + condataset(dtype_constraint=("float16", "float32", "float64")), + np.random.rand(10).astype("float16"), + ), + ( 
+ condataset(dtype_constraint=("float16", "float32", "float64")), + np.random.rand(10).astype("float32"), + ), + ], + ) + def test_constrained_dataset_dtype(self, cds, data): + adapter = TypeAdapter(cds) + + adapter.validate_python(data) + + @pytest.mark.xfail(raises=ValueError) + @pytest.mark.parametrize( + ("cds", "data"), + [ + (condataset(dtype_constraint="float64"), np.random.rand(10).astype(str)), + (condataset(dtype_constraint="str"), np.random.rand(10)), + ( + condataset(dtype_constraint=("float16", "float32", "float64")), + np.random.rand(10).astype(str), + ), + ], + ) + def test_violate_dtype_constraint(self, cds, data): + adapter = TypeAdapter(cds) + + adapter.validate_python(data) + + @pytest.mark.parametrize( + ("cds", "data"), + [ + (condataset(min_dim=1, max_dim=1), np.random.rand(10)), + (condataset(min_dim=0, max_dim=1), np.random.rand(10)), + (condataset(max_dim=2), np.random.rand(10)), + (condataset(max_dim=3), np.random.rand(10, 10, 10)), + (condataset(min_dim=2), np.random.rand(10, 10)), + (condataset(min_dim=2), np.random.rand(10, 10, 10, 10, 10)), + (condataset(min_dim=2, max_dim=4), np.random.rand(10, 10, 10, 10)), + (condataset(min_dim=2, max_dim=4), np.random.rand(10, 10, 10)), + (condataset(min_dim=2, max_dim=4), np.random.rand(10, 10)), + ], + ) + def test_constrained_dataset_dimension(self, cds, data): + adapter = TypeAdapter(cds) + + adapter.validate_python(data) + + @pytest.mark.xfail(raises=ValueError) + @pytest.mark.parametrize( + ("cds", "data"), + [ + (condataset(min_dim=1, max_dim=1), np.random.rand(10, 10)), + (condataset(min_dim=2, max_dim=3), np.random.rand(10)), + (condataset(min_dim=2, max_dim=3), np.random.rand(10, 10, 10, 10)), + ], + ) + def test_violate_dimension_constraint(self, cds, data): + adapter = TypeAdapter(cds) + + adapter.validate_python(data) + + @pytest.mark.parametrize( + ("cds", "data"), + [ + (condataset(shape_constraint=(None,)), np.random.rand(10)), + (condataset(shape_constraint=(10,)), np.random.rand(10)), + (condataset(shape_constraint=(None, None)), np.random.rand(1, 2)), + (condataset(shape_constraint=(1, None)), np.random.rand(1, 2)), + (condataset(shape_constraint=(1, 2)), np.random.rand(1, 2)), + (condataset(shape_constraint=(1, None, 3)), np.random.rand(1, 10, 3)), + ], + ) + def test_constrained_dataset_shape(self, cds, data): + adapter = TypeAdapter(cds) + + adapter.validate_python(data) + + @pytest.mark.xfail(raises=ValueError) + @pytest.mark.parametrize( + ("cds", "data"), + [ + (condataset(shape_constraint=(1,)), np.random.rand(10)), + (condataset(shape_constraint=(None,)), np.random.rand(10, 10)), + (condataset(shape_constraint=(None, 1)), np.random.rand(10, 10)), + (condataset(shape_constraint=(None, 1)), np.random.rand(1, 10)), + ], + ) + def test_violate_shape_constraint(self, cds, data): + adapter = TypeAdapter(cds) + + adapter.validate_python(data) diff --git a/tests/test_datastore.py b/tests/test_datastore.py index 384d3fe..c39f49e 100644 --- a/tests/test_datastore.py +++ b/tests/test_datastore.py @@ -13,44 +13,97 @@ # limitations under the License. 
# %% -import pathlib +import uuid +from typing import Dict, Optional import numpy as np import pytest -from oqd_dataschema.base import Dataset, DTypes -from oqd_dataschema.datastore import Datastore -from oqd_dataschema.groups import ( - SinaraRawDataGroup, -) - +from oqd_dataschema import Datastore, GroupBase +from oqd_dataschema.base import Dataset # %% -@pytest.mark.parametrize( - "dtype", - [ - "int32", - "int64", - "float32", - "float64", - "complex64", - "complex128", - ], + +_Group = type( + f"_Group_{uuid.uuid4()}".replace("-", ""), + (GroupBase,), + { + "__annotations__": { + "x": Dataset, + "y": Dict[str, Dataset], + "z": Optional[Dataset], + }, + "y": {}, + "z": None, + }, ) -def test_serialize_deserialize(dtype): - data = np.ones([10, 10]).astype(dtype) - dataset = SinaraRawDataGroup(camera_images=Dataset(data=data)) - data = Datastore(groups={"test": dataset}) - filepath = pathlib.Path("test.h5") - data.model_dump_hdf5(filepath) - data_reload = Datastore.model_validate_hdf5(filepath) +class TestDatastore: + @pytest.mark.parametrize( + ("dtype", "np_dtype"), + [ + ("bool", np.dtypes.BoolDType), + ("int16", np.dtypes.Int16DType), + ("int32", np.dtypes.Int32DType), + ("int64", np.dtypes.Int64DType), + ("uint16", np.dtypes.UInt16DType), + ("uint32", np.dtypes.UInt32DType), + ("uint64", np.dtypes.UInt64DType), + ("float16", np.dtypes.Float16DType), + ("float32", np.dtypes.Float32DType), + ("float64", np.dtypes.Float64DType), + ("complex64", np.dtypes.Complex64DType), + ("complex128", np.dtypes.Complex128DType), + ("str", np.dtypes.StrDType), + ("bytes", np.dtypes.BytesDType), + ("string", np.dtypes.StringDType), + ], + ) + def test_serialize_deserialize_dtypes(self, dtype, np_dtype, tmp_path): + f = tmp_path / f"tmp{uuid.uuid4()}.h5" + + datastore = Datastore( + groups={"g1": _Group(x=Dataset(data=np.random.rand(1).astype(np_dtype)))} + ) - assert ( - type(data_reload.groups["test"].camera_images.data.dtype) - is DTypes.get(dtype).value + datastore.model_dump_hdf5(f) + + Datastore.model_validate_hdf5(f) + + @pytest.mark.parametrize( + ("x", "y", "z"), + [ + ( + Dataset(data=np.random.rand(10)), + {}, + None, + ), + ( + Dataset(data=np.random.rand(10)), + {"f1": Dataset(data=np.random.rand(10))}, + None, + ), + ( + Dataset(data=np.random.rand(10)), + {"f1": Dataset(data=np.random.rand(10))}, + Dataset(data=np.random.rand(10)), + ), + ( + Dataset(data=np.random.rand(10)), + { + "f1": Dataset(data=np.random.rand(10)), + "f2": Dataset(data=np.random.rand(10)), + }, + Dataset(data=np.random.rand(10)), + ), + ], ) + def test_serialize_deserialize_dataset_types(self, x, y, z, tmp_path): + f = tmp_path / f"tmp{uuid.uuid4()}.h5" + datastore = Datastore(groups={"g1": _Group(x=x, y=y, z=z)}) -# %% + datastore.model_dump_hdf5(f) + + Datastore.model_validate_hdf5(f) diff --git a/tests/test_group.py b/tests/test_group.py new file mode 100644 index 0000000..ca87fe6 --- /dev/null +++ b/tests/test_group.py @@ -0,0 +1,170 @@ +# Copyright 2024-2025 Open Quantum Design + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +# %% + +import uuid +from typing import Any, ClassVar, Dict, List, Literal, Optional, Tuple + +import numpy as np +import pytest + +from oqd_dataschema import CastDataset, Dataset, GroupBase, condataset + +######################################################################################## + + +class TestGroupDefinition: + @pytest.mark.parametrize( + "field_type", + [ + Dataset, + CastDataset, + Dict[str, Dataset], + Dict[str, CastDataset], + condataset(dtype_constraint="float32"), + condataset(dtype_constraint=("float16", "float32", "float64")), + condataset(min_dim=1), + condataset(max_dim=1), + condataset(min_dim=1, max_dim=2), + condataset(shape_constraint=(1,)), + condataset(shape_constraint=(None,)), + condataset(shape_constraint=(None, 1)), + condataset(shape_constraint=(None, None)), + Optional[Dataset], + ], + ) + def test_data_field_definition(self, field_type): + type( + f"_Group_{uuid.uuid4()}".replace("-", ""), + (GroupBase,), + {"__annotations__": {"x": field_type}}, + ) + + @pytest.mark.xfail(raises=TypeError) + @pytest.mark.parametrize( + "field_type", + [ + Any, + int, + List[int], + Tuple[int], + List[Dataset], + Tuple[Dataset], + Dict[int, Dataset], + ], + ) + def test_invalid_data_field_definition(self, field_type): + type( + f"_Group_{uuid.uuid4()}".replace("-", ""), + (GroupBase,), + {"__annotations__": {"x": field_type}}, + ) + + @pytest.mark.xfail(raises=AttributeError) + def test_overwriting_attrs(self): + type( + f"_Group_{uuid.uuid4()}".replace("-", ""), + (GroupBase,), + {"__annotations__": {"attrs": Dict[str, Any]}}, + ) + + @pytest.mark.xfail(raises=AttributeError) + def test_overwriting_class_(self): + groupname = f"_Group_{uuid.uuid4()}".replace("-", "") + type( + groupname, + (GroupBase,), + {"__annotations__": {"class_": Literal[groupname]}}, + ) + + @pytest.mark.parametrize( + ("field_type", "data"), + [ + (Dataset, Dataset(data=np.random.rand(100))), + (CastDataset, Dataset(data=np.random.rand(100))), + ( + Dict[str, Dataset], + { + "1": Dataset(data=np.random.rand(100)), + "2": Dataset(data=np.random.rand(100)), + }, + ), + ( + Dict[str, CastDataset], + { + "1": Dataset(data=np.random.rand(100)), + "2": Dataset(data=np.random.rand(100)), + }, + ), + (condataset(dtype_constraint="float64"), Dataset(data=np.random.rand(100))), + ( + condataset(dtype_constraint=("float16", "float32", "float64")), + Dataset(data=np.random.rand(100)), + ), + (Optional[Dataset], Dataset(data=np.random.rand(100))), + (Optional[Dataset], None), + ], + ) + def test_group_instantiation(self, field_type, data): + _Group = type( + f"_Group_{uuid.uuid4()}".replace("-", ""), + (GroupBase,), + {"__annotations__": {"x": field_type}}, + ) + + _Group(x=data) + + @pytest.mark.parametrize( + ("classvar_type"), + [ + ClassVar, + ClassVar[int], + ], + ) + def test_class_variable(self, classvar_type): + type( + f"_Group_{uuid.uuid4()}".replace("-", ""), + (GroupBase,), + {"__annotations__": {"x": classvar_type}}, + ) + + @pytest.mark.parametrize( + ("dataset"), + [ + Dataset(), + Dataset(data=np.random.rand(10)), + Dataset(dtype="float64", shape=(10,)), + Dataset(dtype="float64", shape=(10,), data=np.random.rand(10)), + ], + ) + def test_default_dataset(self, dataset): + _Group = type( + f"_Group_{uuid.uuid4()}".replace("-", ""), + (GroupBase,), + {"__annotations__": {"x": Dataset}, "x": dataset}, + ) + + g = _Group() + + assert ( + ( + (g.x.data == dataset.data).all() + and 
g.x.dtype == dataset.dtype + and g.x.shape == dataset.shape + and g.x.attrs == dataset.attrs + ) + if isinstance(dataset.data, np.ndarray) + else g.x == dataset + ) diff --git a/tests/test_groupregistry.py b/tests/test_groupregistry.py new file mode 100644 index 0000000..1425781 --- /dev/null +++ b/tests/test_groupregistry.py @@ -0,0 +1,77 @@ +# Copyright 2024-2025 Open Quantum Design + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# %% + +import pytest + +from oqd_dataschema.base import ( + CastDataset, + Dataset, + GroupBase, + GroupRegistry, + condataset, +) + + +class TestGroupRegistry: + def test_clear(self): + GroupRegistry.clear() + + GroupRegistry.groups = dict() + + def test_add_group(self): + GroupRegistry.clear() + + groups = set() + for k in "ABCDE": + groups.add( + type(f"_Group{k}", (GroupBase,), {"__annotations__": {"x": Dataset}}) + ) + + assert set(GroupRegistry.groups.values()) == groups + + def test_overwrite_group(self): + GroupRegistry.clear() + + _GroupA = type("_GroupA", (GroupBase,), {"__annotations__": {"x": Dataset}}) + + assert set(GroupRegistry.groups.values()) == {_GroupA} + + with pytest.warns(UserWarning): + _mGroupA = type( + "_GroupA", (GroupBase,), {"__annotations__": {"x": CastDataset}} + ) + + assert set(GroupRegistry.groups.values()) == {_mGroupA} + + @pytest.fixture + def group_generator(self): + def _groupgen(): + groups = [] + for k, dtype in zip( + "ABCDE", + ["str", "float64", "bytes", "bool", ("int16", "int32", "int64")], + ): + groups.append( + type( + f"_Group{k}", + (GroupBase,), + {"__annotations__": {"x": condataset(dtype_constraint=dtype)}}, + ) + ) + return groups + + return _groupgen diff --git a/tests/test_typeadapt.py b/tests/test_typeadapt.py deleted file mode 100644 index 609f09c..0000000 --- a/tests/test_typeadapt.py +++ /dev/null @@ -1,51 +0,0 @@ -# Copyright 2024-2025 Open Quantum Design - -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- - -# %% -import pathlib - -import numpy as np - -from oqd_dataschema.base import Dataset, GroupBase -from oqd_dataschema.datastore import Datastore -from oqd_dataschema.groups import ( - SinaraRawDataGroup, -) - - -# %% -def test_adapt(): - class TestNewGroup(GroupBase): - """ """ - - array: Dataset - - filepath = pathlib.Path("test.h5") - - data = np.ones([10, 10]).astype("int64") - group1 = TestNewGroup(array=Dataset(data=data)) - - data = np.ones([10, 10]).astype("int32") - group2 = SinaraRawDataGroup(camera_images=Dataset(data=data)) - - datastore = Datastore( - groups={ - "group1": group1, - "group2": group2, - } - ) - datastore.model_dump_hdf5(filepath, mode="w") - - Datastore.model_validate_hdf5(filepath) From 11abee3c7f292df50b7ed7e703c748538afdcb46 Mon Sep 17 00:00:00 2001 From: yhteoh Date: Fri, 26 Sep 2025 15:14:38 -0400 Subject: [PATCH 12/48] [refactor] refactor datastore group validation and added comments --- src/oqd_dataschema/datastore.py | 94 ++++++++++++++++++--------------- 1 file changed, 52 insertions(+), 42 deletions(-) diff --git a/src/oqd_dataschema/datastore.py b/src/oqd_dataschema/datastore.py index 78670e8..9b43f30 100644 --- a/src/oqd_dataschema/datastore.py +++ b/src/oqd_dataschema/datastore.py @@ -49,69 +49,69 @@ class Datastore(BaseModel, extra="forbid"): attrs: Attrs = {} - @field_validator("groups", mode="before") @classmethod - def validate_groups(cls, data): - if isinstance(data, dict): - # Get the current adapter from registry - try: - validated_groups = {} - - for key, group_data in data.items(): - if isinstance(group_data, GroupBase): - # Already a Group instance - validated_groups[key] = group_data - elif isinstance(group_data, dict): - # Parse dict using discriminated union - validated_groups[key] = GroupRegistry.adapter.validate_python( - group_data - ) - else: - raise ValueError( - f"Invalid group data for key '{key}': {type(group_data)}" - ) + def _validate_group(cls, key, group): + if isinstance(group, GroupBase): + return group - data = validated_groups + if isinstance(group, dict): + return GroupRegistry.adapter.validate_python(group) - except ValueError as e: - if "No group types registered" in str(e): - raise ValueError( - "No group types available. Register group types before creating Datastore." - ) - raise + raise ValueError(f"Key `{key}` contains invalid group data.") - return data + @field_validator("groups", mode="before") + @classmethod + def validate_groups(cls, data): + if GroupRegistry.groups == {}: + raise ValueError( + "No group types available. Register group types before creating Datastore." 
+ ) + + validated_groups = {k: cls._validate_group(k, v) for k, v in data.items()} + return validated_groups def _dump_group(self, h5datastore, gkey, group): + # remove existing group if gkey in h5datastore.keys(): del h5datastore[gkey] + + # create group h5_group = h5datastore.create_group(gkey) + # dump group schema h5_group.attrs["_group_schema"] = json.dumps( group.model_json_schema(), indent=2 ) + + # dump group attributes for akey, attr in group.attrs.items(): h5_group.attrs[akey] = attr + # dump group data for dkey, dataset in group.__dict__.items(): + # if group field contain dictionary of Dataset if isinstance(dataset, dict): h5_subgroup = h5_group.create_group(dkey) for ddkey, ddataset in dataset.items(): self._dump_dataset(h5_subgroup, ddkey, ddataset) - - self._dump_dataset(h5_group, dkey, dataset) + else: + self._dump_dataset(h5_group, dkey, dataset) def _dump_dataset(self, h5group, dkey, dataset): - if isinstance(dataset, Dataset): - if dataset.dtype in "str": - h5_dataset = h5group.create_dataset( - dkey, data=dataset.data.astype(np.dtypes.BytesDType) - ) - else: - h5_dataset = h5group.create_dataset(dkey, data=dataset.data) + if not isinstance(dataset, Dataset): + raise ValueError("Group data field is not a Dataset.") + + # dtype str converted to bytes when dumped (h5 compatibility) + if dataset.dtype in "str": + h5_dataset = h5group.create_dataset( + dkey, data=dataset.data.astype(np.dtypes.BytesDType) + ) + else: + h5_dataset = h5group.create_dataset(dkey, data=dataset.data) - for akey, attr in dataset.attrs.items(): - h5_dataset.attrs[akey] = attr + # dump dataset attributes + for akey, attr in dataset.attrs.items(): + h5_dataset.attrs[akey] = attr def model_dump_hdf5(self, filepath: pathlib.Path, mode: Literal["w", "a"] = "w"): """ @@ -125,12 +125,12 @@ def model_dump_hdf5(self, filepath: pathlib.Path, mode: Literal["w", "a"] = "w") filepath.parent.mkdir(exist_ok=True, parents=True) with h5py.File(filepath, mode) as f: - # store the model JSON schema + # dump the datastore signature f.attrs["_datastore_signature"] = self.model_dump_json(indent=2) for akey, attr in self.attrs.items(): f.attrs[akey] = attr - # store each group + # dump each group for gkey, group in self.groups.items(): self._dump_group(f, gkey, group) @@ -143,19 +143,24 @@ def model_validate_hdf5(cls, filepath: pathlib.Path): filepath (pathlib.Path): The path to the HDF5 file where the model data will be read and validated from. """ with h5py.File(filepath, "r") as f: + # Load datastore signature self = cls.model_validate_json(f.attrs["_datastore_signature"]) - # loop through all groups in the model schema and load HDF5 store + # loop through all groups in the model schema and load the data for gkey, group in self: for dkey in group.__class__.model_fields: + # ignore attrs and class_ fields if dkey in ("attrs", "class_"): continue + # load Dataset data if isinstance(group.__dict__[dkey], Dataset): group.__dict__[dkey].data = np.array(f[gkey][dkey][()]).astype( DTypes.get(group.__dict__[dkey].dtype).value ) + continue + # load data for dict of Dataset if isinstance(group.__dict__[dkey], dict): for ddkey in group.__dict__[dkey]: group.__dict__[dkey][ddkey].data = np.array( @@ -163,6 +168,11 @@ def model_validate_hdf5(cls, filepath: pathlib.Path): ).astype( DTypes.get(group.__dict__[dkey][ddkey].dtype).value ) + continue + + raise TypeError( + "Group data fields must be of type Dataset or dict of Dataset." 
+ ) return self From 1d82e3eac0e9568f7c7f80c30b1705992cb36a71 Mon Sep 17 00:00:00 2001 From: yhteoh Date: Thu, 2 Oct 2025 13:49:16 -0400 Subject: [PATCH 13/48] [fix, feat] Added method to Datastore to add and update groups. fixed datastore model_dump_hdf5 ignore attrs and class_ when iterating through fields --- src/oqd_dataschema/base.py | 13 +++++++++++ src/oqd_dataschema/datastore.py | 41 ++++++++++++++++++++++++++++----- src/oqd_dataschema/utils.py | 3 +++ 3 files changed, 51 insertions(+), 6 deletions(-) diff --git a/src/oqd_dataschema/base.py b/src/oqd_dataschema/base.py index 7274a13..767382d 100644 --- a/src/oqd_dataschema/base.py +++ b/src/oqd_dataschema/base.py @@ -189,6 +189,7 @@ def _is_dataset_type(cls, type_): @_validator_from_condition def _constrain_dtype(dataset, *, dtype_constraint=None): + """Constrains the dtype of a dataset""" if (not isinstance(dtype_constraint, str)) and isinstance( dtype_constraint, Sequence ): @@ -204,6 +205,7 @@ def _constrain_dtype(dataset, *, dtype_constraint=None): @_validator_from_condition def _constraint_dim(dataset, *, min_dim=None, max_dim=None): + """Constrains the dimension of a dataset""" if min_dim is not None and max_dim is not None and min_dim > max_dim: raise ValueError("Impossible to satisfy dimension constraints on dataset.") @@ -219,6 +221,7 @@ def _constraint_dim(dataset, *, min_dim=None, max_dim=None): @_validator_from_condition def _constraint_shape(dataset, *, shape_constraint=None): + """Constrains the shape of a dataset""" if shape_constraint and not _flex_shape_equal(shape_constraint, dataset.shape): raise ValueError( f"Expected shape to be {shape_constraint}, but got {dataset.shape}." @@ -228,6 +231,7 @@ def _constraint_shape(dataset, *, shape_constraint=None): def condataset( *, shape_constraint=None, dtype_constraint=None, min_dim=None, max_dim=None ): + """Implements dtype, dimension and shape constrains on the dataset.""" return Annotated[ CastDataset, AfterValidator(_constrain_dtype(dtype_constraint=dtype_constraint)), @@ -312,11 +316,16 @@ def __init_subclass__(cls, **kwargs): class MetaGroupRegistry(type): + """ + Metaclass for the GroupRegistry + """ + def __new__(cls, clsname, superclasses, attributedict): attributedict["groups"] = dict() return super().__new__(cls, clsname, superclasses, attributedict) def register(cls, group): + """Registers a group into the GroupRegistry.""" if not issubclass(group, GroupBase): raise TypeError("You may only register subclasses of GroupBase.") @@ -347,6 +356,10 @@ def adapter(cls): class GroupRegistry(metaclass=MetaGroupRegistry): + """ + Represents the GroupRegistry + """ + pass diff --git a/src/oqd_dataschema/datastore.py b/src/oqd_dataschema/datastore.py index 9b43f30..8d54ff8 100644 --- a/src/oqd_dataschema/datastore.py +++ b/src/oqd_dataschema/datastore.py @@ -37,12 +37,11 @@ # %% class Datastore(BaseModel, extra="forbid"): """ - Saves the model and its associated data to an HDF5 file. - This method serializes the model's data and attributes into an HDF5 file - at the specified filepath. + Class representing a datastore with restricted HDF5 format. Attributes: - filepath (pathlib.Path): The path to the HDF5 file where the model data will be saved. + groups (Dict[str,Group]): groups of data. + attrs (Attrs): attributes of the datastore. 
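+
+    Example:
+        A minimal sketch of a round trip (the group and key names are
+        illustrative; any registered `GroupBase` subclass works):
+        ```
+        datastore = Datastore(
+            groups={"raw": SinaraRawDataGroup(camera_images=Dataset(data=data))}
+        )
+        datastore.model_dump_hdf5(pathlib.Path("data.h5"))
+        restored = Datastore.model_validate_hdf5(pathlib.Path("data.h5"))
+        ```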
""" groups: Dict[str, Any] @@ -51,6 +50,7 @@ class Datastore(BaseModel, extra="forbid"): @classmethod def _validate_group(cls, key, group): + """Helper function for validating group to be of type Group registered in the GroupRegistry.""" if isinstance(group, GroupBase): return group @@ -62,6 +62,7 @@ def _validate_group(cls, key, group): @field_validator("groups", mode="before") @classmethod def validate_groups(cls, data): + """Validates groups to be of type Group registered in the GroupRegistry.""" if GroupRegistry.groups == {}: raise ValueError( "No group types available. Register group types before creating Datastore." @@ -71,6 +72,7 @@ def validate_groups(cls, data): return validated_groups def _dump_group(self, h5datastore, gkey, group): + """Helper function for dumping Group.""" # remove existing group if gkey in h5datastore.keys(): del h5datastore[gkey] @@ -89,15 +91,20 @@ def _dump_group(self, h5datastore, gkey, group): # dump group data for dkey, dataset in group.__dict__.items(): + if dkey in ["attr", "class_"]: + continue + # if group field contain dictionary of Dataset if isinstance(dataset, dict): h5_subgroup = h5_group.create_group(dkey) for ddkey, ddataset in dataset.items(): self._dump_dataset(h5_subgroup, ddkey, ddataset) - else: - self._dump_dataset(h5_group, dkey, dataset) + continue + + self._dump_dataset(h5_group, dkey, dataset) def _dump_dataset(self, h5group, dkey, dataset): + """Helper function for dumping Dataset.""" if not isinstance(dataset, Dataset): raise ValueError("Group data field is not a Dataset.") @@ -132,6 +139,9 @@ def model_dump_hdf5(self, filepath: pathlib.Path, mode: Literal["w", "a"] = "w") # dump each group for gkey, group in self.groups.items(): + if gkey in ["attr", "class_"]: + continue + self._dump_group(f, gkey, group) @classmethod @@ -177,7 +187,26 @@ def model_validate_hdf5(cls, filepath: pathlib.Path): return self def __getitem__(self, key): + """Overloads indexing to retrieve elements in groups.""" return self.groups.__getitem__(key) def __iter__(self): + """Overloads iter to iterate over elements in groups.""" return self.groups.items().__iter__() + + def add(self, **groups): + """Adds a new groups to the datastore.""" + for k, v in groups.items(): + if k in self.groups.keys(): + raise ValueError( + "Key already exist in the datastore, use `update` instead if intending to overwrite past data." 
+            )
+            self.groups[k] = v
+
+    def update(self, **groups):
+        """Updates groups in the datastore, overwriting past values."""
+        for k, v in groups.items():
+            self.groups[k] = v
+
+
+# %%
diff --git a/src/oqd_dataschema/utils.py b/src/oqd_dataschema/utils.py
index 90832c5..74ae782 100644
--- a/src/oqd_dataschema/utils.py
+++ b/src/oqd_dataschema/utils.py
@@ -23,6 +23,7 @@
 
 
 def _flex_shape_equal(shape1, shape2):
+    """Helper function for comparing concrete and flex shapes."""
     return len(shape1) == len(shape2) and reduce(
         lambda x, y: x and y,
         map(
@@ -33,6 +34,8 @@ def _flex_shape_equal(shape1, shape2):
 
 
 def _validator_from_condition(f):
+    """Helper decorator for turning a condition into a validator."""
+
     def _wrapped_validator(*args, **kwargs):
         def _wrapped_condition(model):
             f(model, *args, **kwargs)

From 7ae3459fb11e08b34105acb067b400afcf9f7f85 Mon Sep 17 00:00:00 2001
From: yhteoh
Date: Thu, 2 Oct 2025 13:59:21 -0400
Subject: [PATCH 14/48] [fix] bug in Datastore.add that could partially update
 groups before raising an error

---
 src/oqd_dataschema/datastore.py | 20 +++++++++++---------
 1 file changed, 11 insertions(+), 9 deletions(-)

diff --git a/src/oqd_dataschema/datastore.py b/src/oqd_dataschema/datastore.py
index 8d54ff8..13073b7 100644
--- a/src/oqd_dataschema/datastore.py
+++ b/src/oqd_dataschema/datastore.py
@@ -194,19 +194,21 @@ def __iter__(self):
         """Overloads iter to iterate over elements in groups."""
         return self.groups.items().__iter__()
 
-    def add(self, **groups):
-        """Adds a new groups to the datastore."""
-        for k, v in groups.items():
-            if k in self.groups.keys():
-                raise ValueError(
-                    "Key already exist in the datastore, use `update` instead if intending to overwrite past data."
-                )
-            self.groups[k] = v
-
     def update(self, **groups):
         """Updates groups in the datastore, overwriting past values."""
         for k, v in groups.items():
             self.groups[k] = v
 
+    def add(self, **groups):
+        """Adds new groups to the datastore."""
+
+        existing_keys = set(groups.keys()).intersection(set(self.groups.keys()))
+        if existing_keys:
+            raise ValueError(
+                f"Keys {existing_keys} already exist in the datastore, use `update` instead if intending to overwrite past data."
+            )
+
+        self.update(**groups)
+
 
 # %%

From b76239eb56ff68de1163e3a2c74f04d8f901ef9c Mon Sep 17 00:00:00 2001
From: yhteoh
Date: Mon, 6 Oct 2025 08:50:32 -0400
Subject: [PATCH 15/48] [gitignore] updated gitignore

---
 .gitignore | 1 +
 1 file changed, 1 insertion(+)

diff --git a/.gitignore b/.gitignore
index c6e81a6..6b99cc7 100644
--- a/.gitignore
+++ b/.gitignore
@@ -174,3 +174,4 @@ cython_debug/
 *.h5
 *.code-workspace
 .pre-commit-config.yaml
+_scripts

From 453bf96765036c0ca4c026102cccd9c951fdf691 Mon Sep 17 00:00:00 2001
From: yhteoh
Date: Mon, 6 Oct 2025 08:51:14 -0400
Subject: [PATCH 16/48] [fix] Saving and loading of optional datasets

---
 src/oqd_dataschema/datastore.py | 17 +++++++++++++----
 1 file changed, 13 insertions(+), 4 deletions(-)

diff --git a/src/oqd_dataschema/datastore.py b/src/oqd_dataschema/datastore.py
index 13073b7..d11aed3 100644
--- a/src/oqd_dataschema/datastore.py
+++ b/src/oqd_dataschema/datastore.py
@@ -44,7 +44,7 @@ class Datastore(BaseModel, extra="forbid"):
         attrs (Attrs): attributes of the datastore.
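 
     Example:
         A minimal sketch of `add` and `update` (the key name is illustrative):
         ```
         datastore.add(calibration=group)     # raises ValueError if the key exists
         datastore.update(calibration=group)  # overwrites an existing key
         group = datastore["calibration"]     # indexing retrieves a group
         ```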
""" - groups: Dict[str, Any] + groups: Dict[str, Any] = {} attrs: Attrs = {} @@ -91,7 +91,7 @@ def _dump_group(self, h5datastore, gkey, group): # dump group data for dkey, dataset in group.__dict__.items(): - if dkey in ["attr", "class_"]: + if dkey in ["attrs", "class_"]: continue # if group field contain dictionary of Dataset @@ -105,9 +105,15 @@ def _dump_group(self, h5datastore, gkey, group): def _dump_dataset(self, h5group, dkey, dataset): """Helper function for dumping Dataset.""" - if not isinstance(dataset, Dataset): + + if dataset is not None and not isinstance(dataset, Dataset): raise ValueError("Group data field is not a Dataset.") + # handle optional dataset + if dataset is None: + h5_dataset = h5group.create_dataset(dkey, data=h5py.Empty("f")) + return + # dtype str converted to bytes when dumped (h5 compatibility) if dataset.dtype in "str": h5_dataset = h5group.create_dataset( @@ -139,7 +145,7 @@ def model_dump_hdf5(self, filepath: pathlib.Path, mode: Literal["w", "a"] = "w") # dump each group for gkey, group in self.groups.items(): - if gkey in ["attr", "class_"]: + if gkey in ["attrs", "class_"]: continue self._dump_group(f, gkey, group) @@ -163,6 +169,9 @@ def model_validate_hdf5(cls, filepath: pathlib.Path): if dkey in ("attrs", "class_"): continue + if group.__dict__[dkey] is None: + continue + # load Dataset data if isinstance(group.__dict__[dkey], Dataset): group.__dict__[dkey].data = np.array(f[gkey][dkey][()]).astype( From 3c425278c3bc697bbf5f9599761b2176d0e3fc2c Mon Sep 17 00:00:00 2001 From: yhteoh Date: Mon, 6 Oct 2025 08:59:31 -0400 Subject: [PATCH 17/48] [fix] protected attrs should be _group_schema --- src/oqd_dataschema/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/oqd_dataschema/base.py b/src/oqd_dataschema/base.py index 767382d..b0ce63e 100644 --- a/src/oqd_dataschema/base.py +++ b/src/oqd_dataschema/base.py @@ -69,7 +69,7 @@ def names(cls): ######################################################################################## -invalid_attrs = ["_datastore_signature", "_group_json"] +invalid_attrs = ["_datastore_signature", "_group_schema"] def _valid_attr_key(value): From 0c9ecbf8cec8439c5c057af7ce8788826b8e61b4 Mon Sep 17 00:00:00 2001 From: yhteoh Date: Mon, 6 Oct 2025 09:11:52 -0400 Subject: [PATCH 18/48] [test] Added test for unmatched flex shape dataset --- tests/test_dataset.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index a1a09a0..ad9fdfa 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -79,18 +79,21 @@ def test_dtype_mutation(self): class TestDatasetShape: @pytest.mark.xfail(raises=ValueError) @pytest.mark.parametrize( - "shape", + ("shape", "data_shape"), [ - (0,), - (1,), - (99,), - (1, 1), + ((0,), (100,)), + ((1,), (100,)), + ((99,), (100,)), + ((1, 1), (100,)), + ((100, None), (100,)), + ((None, None), (100,)), + ((None, 100), (100,)), ], ) - def test_unmatched_shape_data(self, shape): - ds = Dataset(dtype="float64", shape=(100,)) + def test_unmatched_shape_data(self, shape, data_shape): + ds = Dataset(dtype="float64", shape=shape) - data = np.random.rand(*shape) + data = np.random.rand(*data_shape) ds.data = data @pytest.mark.parametrize( From caa91af7b02d2288fcc3e4961c0414897059efd6 Mon Sep 17 00:00:00 2001 From: yhteoh Date: Mon, 6 Oct 2025 12:40:25 -0400 Subject: [PATCH 19/48] [refactor, feat] refactored package layout and added support for Tables in Datastore --- src/oqd_dataschema/__init__.py | 9 
+- src/oqd_dataschema/base.py | 299 +------------------------------- src/oqd_dataschema/dataset.py | 194 +++++++++++++++++++++ src/oqd_dataschema/datastore.py | 92 +++++++--- src/oqd_dataschema/group.py | 206 ++++++++++++++++++++++ src/oqd_dataschema/groups.py | 61 ------- src/oqd_dataschema/table.py | 160 +++++++++++++++++ src/oqd_dataschema/utils.py | 18 +- tests/test_dataset.py | 3 +- tests/test_datastore.py | 3 +- tests/test_groupregistry.py | 2 +- 11 files changed, 661 insertions(+), 386 deletions(-) create mode 100644 src/oqd_dataschema/dataset.py create mode 100644 src/oqd_dataschema/group.py delete mode 100644 src/oqd_dataschema/groups.py create mode 100644 src/oqd_dataschema/table.py diff --git a/src/oqd_dataschema/__init__.py b/src/oqd_dataschema/__init__.py index 74c42db..df34151 100644 --- a/src/oqd_dataschema/__init__.py +++ b/src/oqd_dataschema/__init__.py @@ -12,14 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .base import CastDataset, Dataset, GroupBase, GroupRegistry, condataset +from .dataset import CastDataset, Dataset, condataset from .datastore import Datastore -from .groups import ( +from .group import ( ExpectationValueDataGroup, + GroupBase, + GroupRegistry, MeasurementOutcomesDataGroup, OQDTestbenchDataGroup, SinaraRawDataGroup, ) +from .table import Table, contable ######################################################################################## @@ -35,4 +38,6 @@ "SinaraRawDataGroup", "condataset", "CastDataset", + "Table", + "contable", ] diff --git a/src/oqd_dataschema/base.py b/src/oqd_dataschema/base.py index b0ce63e..d91e7b8 100644 --- a/src/oqd_dataschema/base.py +++ b/src/oqd_dataschema/base.py @@ -13,30 +13,20 @@ # limitations under the License. # %% -import typing -import warnings from enum import Enum -from types import NoneType -from typing import Annotated, Any, ClassVar, Literal, Optional, Sequence, Tuple, Union +from typing import Annotated, Optional, Union import numpy as np from pydantic import ( - AfterValidator, - BaseModel, BeforeValidator, - ConfigDict, - Discriminator, - Field, - TypeAdapter, - field_validator, - model_validator, ) -from .utils import _flex_shape_equal, _validator_from_condition - ######################################################################################## -__all__ = ["GroupBase", "Dataset", "GroupRegistry", "condataset", "CastDataset"] +__all__ = [ + "Attrs", + "DTypes", +] ######################################################################################## @@ -85,282 +75,3 @@ def _valid_attr_key(value): Union[int, float, str, complex], ] ] - -######################################################################################## - - -class Dataset(BaseModel, extra="forbid"): - """ - Schema representation for a dataset object to be saved within an HDF5 file. - - Attributes: - dtype: The datatype of the dataset, such as `int32`, `float32`, `int64`, `float64`, etc. - Types are inferred from the `data` attribute if provided. - shape: The shape of the dataset. - data: The numpy ndarray of the data, from which `dtype` and `shape` are inferred. - - attrs: A dictionary of attributes to append to the dataset. 
- - Example: - ``` - dataset = Dataset(data=np.array([1, 2, 3, 4])) - - dataset = Dataset(dtype='int64', shape=[4,]) - dataset.data = np.array([1, 2, 3, 4]) - ``` - """ - - dtype: Optional[Literal[DTypes.names()]] = None # type: ignore - shape: Optional[Tuple[Union[int, None], ...]] = None - data: Optional[Any] = Field(default=None, exclude=True) - - attrs: Attrs = {} - - model_config = ConfigDict( - use_enum_values=False, arbitrary_types_allowed=True, validate_assignment=True - ) - - @field_validator("data", mode="before") - @classmethod - def validate_and_update(cls, value): - # check if data exist - if value is None: - return value - - # check if data is a numpy array - if not isinstance(value, np.ndarray): - raise TypeError("`data` must be a numpy.ndarray.") - - return value - - @model_validator(mode="after") - def validate_data_matches_shape_dtype(self): - """Ensure that `data` matches `dtype` and `shape`.""" - - # check if data exist - if self.data is None: - return self - - # check if dtype matches data - if ( - self.dtype is not None - and type(self.data.dtype) is not DTypes.get(self.dtype).value - ): - raise ValueError( - f"Expected data dtype `{self.dtype}`, but got `{self.data.dtype.name}`." - ) - - # check if shape mataches data - if self.shape is not None and not _flex_shape_equal( - self.data.shape, self.shape - ): - raise ValueError(f"Expected shape {self.shape}, but got {self.data.shape}.") - - # reassign dtype if it is None - if self.dtype != DTypes(type(self.data.dtype)).name.lower(): - self.dtype = DTypes(type(self.data.dtype)).name.lower() - - # resassign shape to concrete value if it is None or a flexible shape - if self.shape != self.data.shape: - self.shape = self.data.shape - - return self - - @classmethod - def cast(cls, data): - if isinstance(data, np.ndarray): - return cls(data=data) - return data - - def __getitem__(self, idx): - return self.data[idx] - - @classmethod - def _is_dataset_type(cls, type_): - return type_ == cls or ( - typing.get_origin(type_) is Annotated and type_.__origin__ is cls - ) - - -CastDataset = Annotated[Dataset, BeforeValidator(Dataset.cast)] - -######################################################################################## - - -@_validator_from_condition -def _constrain_dtype(dataset, *, dtype_constraint=None): - """Constrains the dtype of a dataset""" - if (not isinstance(dtype_constraint, str)) and isinstance( - dtype_constraint, Sequence - ): - dtype_constraint = set(dtype_constraint) - elif isinstance(dtype_constraint, str): - dtype_constraint = {dtype_constraint} - - if dtype_constraint and dataset.dtype not in dtype_constraint: - raise ValueError( - f"Expected dtype to be of type one of {dtype_constraint}, but got {dataset.dtype}." - ) - - -@_validator_from_condition -def _constraint_dim(dataset, *, min_dim=None, max_dim=None): - """Constrains the dimension of a dataset""" - if min_dim is not None and max_dim is not None and min_dim > max_dim: - raise ValueError("Impossible to satisfy dimension constraints on dataset.") - - min_dim = 0 if min_dim is None else min_dim - - dims = len(dataset.shape) - - if dims < min_dim or (max_dim is not None and dims > max_dim): - raise ValueError( - f"Expected {min_dim} <= dimension of shape{f' <= {max_dim}'}, but got shape = {dataset.shape}." 
- ) - - -@_validator_from_condition -def _constraint_shape(dataset, *, shape_constraint=None): - """Constrains the shape of a dataset""" - if shape_constraint and not _flex_shape_equal(shape_constraint, dataset.shape): - raise ValueError( - f"Expected shape to be {shape_constraint}, but got {dataset.shape}." - ) - - -def condataset( - *, shape_constraint=None, dtype_constraint=None, min_dim=None, max_dim=None -): - """Implements dtype, dimension and shape constrains on the dataset.""" - return Annotated[ - CastDataset, - AfterValidator(_constrain_dtype(dtype_constraint=dtype_constraint)), - AfterValidator(_constraint_dim(min_dim=min_dim, max_dim=max_dim)), - AfterValidator(_constraint_shape(shape_constraint=shape_constraint)), - ] - - -######################################################################################## - - -class GroupBase(BaseModel, extra="forbid"): - """ - Schema representation for a group object within an HDF5 file. - - Each grouping of data should be defined as a subclass of `Group`, and specify the datasets that it will contain. - This base object only has attributes, `attrs`, which are associated to the HDF5 group. - - Attributes: - attrs: A dictionary of attributes to append to the dataset. - - Example: - ``` - group = Group(attrs={'version': 2, 'date': '2025-01-01'}) - ``` - """ - - attrs: Attrs = {} - - @classmethod - def _is_allowed_field_type(cls, v): - is_dataset = Dataset._is_dataset_type(v) - - is_annotated_dataset = typing.get_origin( - v - ) is Annotated and Dataset._is_dataset_type(v.__origin__) - - is_optional_dataset = typing.get_origin(v) is Union and ( - (v.__args__[0] == NoneType and Dataset._is_dataset_type(v.__args__[1])) - or (v.__args__[1] == NoneType and Dataset._is_dataset_type(v.__args__[0])) - ) - - is_dict_dataset = ( - typing.get_origin(v) is dict - and v.__args__[0] is str - and Dataset._is_dataset_type(v.__args__[1]) - ) - - return ( - is_dataset or is_annotated_dataset or is_optional_dataset or is_dict_dataset - ) - - @classmethod - def _is_classvar(cls, v): - return v is ClassVar or typing.get_origin(v) is ClassVar - - def __init_subclass__(cls, **kwargs): - super().__init_subclass__(**kwargs) - - for k, v in cls.__annotations__.items(): - if k in ["class_", "attrs"]: - raise AttributeError( - "`class_` and `attrs` attribute should not be set manually." - ) - - if cls._is_classvar(v): - continue - - if not cls._is_allowed_field_type(v): - raise TypeError( - "All fields of `GroupBase` have to be of type `Dataset`." 
- ) - - cls.__annotations__["class_"] = Literal[cls.__name__] - setattr(cls, "class_", cls.__name__) - - # Auto-register new group types - GroupRegistry.register(cls) - - -######################################################################################## - - -class MetaGroupRegistry(type): - """ - Metaclass for the GroupRegistry - """ - - def __new__(cls, clsname, superclasses, attributedict): - attributedict["groups"] = dict() - return super().__new__(cls, clsname, superclasses, attributedict) - - def register(cls, group): - """Registers a group into the GroupRegistry.""" - if not issubclass(group, GroupBase): - raise TypeError("You may only register subclasses of GroupBase.") - - if group.__name__ in cls.groups.keys(): - warnings.warn( - f"Overwriting previously registered `{group.__name__}` group of the same name.", - UserWarning, - stacklevel=2, - ) - - cls.groups[group.__name__] = group - - def clear(cls): - """Clear all registered types (useful for testing)""" - cls.groups.clear() - - @property - def union(cls): - """Get the current Union of all registered types""" - return Annotated[ - Union[tuple(cls.groups.values())], Discriminator(discriminator="class_") - ] - - @property - def adapter(cls): - """Get TypeAdapter for current registered types""" - return TypeAdapter(cls.union) - - -class GroupRegistry(metaclass=MetaGroupRegistry): - """ - Represents the GroupRegistry - """ - - pass - - -# %% diff --git a/src/oqd_dataschema/dataset.py b/src/oqd_dataschema/dataset.py new file mode 100644 index 0000000..56e4d67 --- /dev/null +++ b/src/oqd_dataschema/dataset.py @@ -0,0 +1,194 @@ +# Copyright 2024-2025 Open Quantum Design + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# %% +import typing +from typing import Annotated, Any, Literal, Optional, Sequence, Tuple, Union + +import numpy as np +from pydantic import ( + AfterValidator, + BaseModel, + BeforeValidator, + ConfigDict, + Field, + field_validator, + model_validator, +) + +from oqd_dataschema.base import Attrs, DTypes + +from .utils import _flex_shape_equal, _validator_from_condition + +######################################################################################## + +__all__ = [ + "Dataset", + "CastDataset", + "condataset", +] + +######################################################################################## + + +class Dataset(BaseModel, extra="forbid"): + """ + Schema representation for a dataset object to be saved within an HDF5 file. + + Attributes: + dtype: The datatype of the dataset, such as `int32`, `float32`, `int64`, `float64`, etc. + Types are inferred from the `data` attribute if provided. + shape: The shape of the dataset. + data: The numpy ndarray of the data, from which `dtype` and `shape` are inferred. + + attrs: A dictionary of attributes to append to the dataset. 
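+
+    A `shape` entry of `None` is treated as a flexible (wildcard) dimension,
+    e.g. `shape=(None, 3)` accepts data of shape `(1, 3)` or `(100, 3)`;
+    `shape` is narrowed to the concrete shape once `data` is assigned.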
+
+    Example:
+        ```
+        dataset = Dataset(data=np.array([1, 2, 3, 4]))
+
+        dataset = Dataset(dtype='int64', shape=[4,])
+        dataset.data = np.array([1, 2, 3, 4])
+        ```
+    """
+
+    dtype: Optional[Literal[DTypes.names()]] = None  # type: ignore
+    shape: Optional[Tuple[Union[int, None], ...]] = None
+    data: Optional[Any] = Field(default=None, exclude=True)
+
+    attrs: Attrs = {}
+
+    model_config = ConfigDict(
+        use_enum_values=False, arbitrary_types_allowed=True, validate_assignment=True
+    )
+
+    @field_validator("data", mode="before")
+    @classmethod
+    def validate_and_update(cls, value):
+        # check if data exists
+        if value is None:
+            return value
+
+        # check if data is a numpy array
+        if not isinstance(value, np.ndarray):
+            raise TypeError("`data` must be a numpy.ndarray.")
+
+        return value
+
+    @model_validator(mode="after")
+    def validate_data_matches_shape_dtype(self):
+        """Ensure that `data` matches `dtype` and `shape`."""
+
+        # check if data exists
+        if self.data is None:
+            return self
+
+        # check if dtype matches data
+        if (
+            self.dtype is not None
+            and type(self.data.dtype) is not DTypes.get(self.dtype).value
+        ):
+            raise ValueError(
+                f"Expected data dtype `{self.dtype}`, but got `{self.data.dtype.name}`."
+            )
+
+        # check if shape matches data
+        if self.shape is not None and not _flex_shape_equal(
+            self.data.shape, self.shape
+        ):
+            raise ValueError(f"Expected shape {self.shape}, but got {self.data.shape}.")
+
+        # reassign dtype if it is None
+        if self.dtype != DTypes(type(self.data.dtype)).name.lower():
+            self.dtype = DTypes(type(self.data.dtype)).name.lower()
+
+        # reassign shape to concrete value if it is None or a flexible shape
+        if self.shape != self.data.shape:
+            self.shape = self.data.shape
+
+        return self
+
+    @classmethod
+    def cast(cls, data):
+        if isinstance(data, np.ndarray):
+            return cls(data=data)
+        return data
+
+    def __getitem__(self, idx):
+        return self.data[idx]
+
+    @classmethod
+    def _is_dataset_type(cls, type_):
+        return type_ == cls or (
+            typing.get_origin(type_) is Annotated and type_.__origin__ is cls
+        )
+
+
+CastDataset = Annotated[Dataset, BeforeValidator(Dataset.cast)]
+
+
+########################################################################################
+
+
+@_validator_from_condition
+def _constrain_dtype(dataset, *, dtype_constraint=None):
+    """Constrains the dtype of a dataset"""
+    if (not isinstance(dtype_constraint, str)) and isinstance(
+        dtype_constraint, Sequence
+    ):
+        dtype_constraint = set(dtype_constraint)
+    elif isinstance(dtype_constraint, str):
+        dtype_constraint = {dtype_constraint}
+
+    if dtype_constraint and dataset.dtype not in dtype_constraint:
+        raise ValueError(
+            f"Expected dtype to be of type one of {dtype_constraint}, but got {dataset.dtype}."
+        )
+
+
+@_validator_from_condition
+def _constraint_dim(dataset, *, min_dim=None, max_dim=None):
+    """Constrains the dimension of a dataset"""
+    if min_dim is not None and max_dim is not None and min_dim > max_dim:
+        raise ValueError("Impossible to satisfy dimension constraints on dataset.")
+
+    min_dim = 0 if min_dim is None else min_dim
+
+    dims = len(dataset.shape)
+
+    if dims < min_dim or (max_dim is not None and dims > max_dim):
+        raise ValueError(
+            f"Expected {min_dim} <= dimension of shape{f' <= {max_dim}'}, but got shape = {dataset.shape}."
+ ) + + +@_validator_from_condition +def _constraint_shape(dataset, *, shape_constraint=None): + """Constrains the shape of a dataset""" + if shape_constraint and not _flex_shape_equal(shape_constraint, dataset.shape): + raise ValueError( + f"Expected shape to be {shape_constraint}, but got {dataset.shape}." + ) + + +def condataset( + *, shape_constraint=None, dtype_constraint=None, min_dim=None, max_dim=None +): + """Implements dtype, dimension and shape constrains on the Dataset.""" + return Annotated[ + CastDataset, + AfterValidator(_constrain_dtype(dtype_constraint=dtype_constraint)), + AfterValidator(_constraint_dim(min_dim=min_dim, max_dim=max_dim)), + AfterValidator(_constraint_shape(shape_constraint=shape_constraint)), + ] diff --git a/src/oqd_dataschema/datastore.py b/src/oqd_dataschema/datastore.py index d11aed3..f581adb 100644 --- a/src/oqd_dataschema/datastore.py +++ b/src/oqd_dataschema/datastore.py @@ -25,7 +25,10 @@ field_validator, ) -from oqd_dataschema.base import Attrs, Dataset, DTypes, GroupBase, GroupRegistry +from oqd_dataschema.base import Attrs, DTypes +from oqd_dataschema.dataset import Dataset +from oqd_dataschema.group import GroupBase, GroupRegistry +from oqd_dataschema.table import Table ######################################################################################## @@ -106,21 +109,45 @@ def _dump_group(self, h5datastore, gkey, group): def _dump_dataset(self, h5group, dkey, dataset): """Helper function for dumping Dataset.""" - if dataset is not None and not isinstance(dataset, Dataset): - raise ValueError("Group data field is not a Dataset.") + if ( + dataset is not None + and not isinstance(dataset, Dataset) + and not isinstance(dataset, Table) + ): + raise ValueError("Group data field is not a Dataset or a Table.") # handle optional dataset if dataset is None: h5_dataset = h5group.create_dataset(dkey, data=h5py.Empty("f")) return - # dtype str converted to bytes when dumped (h5 compatibility) - if dataset.dtype in "str": + if isinstance(dataset, Dataset): + # dtype str converted to bytes when dumped (h5 compatibility) + np_dtype = ( + np.dtypes.BytesDType + if dataset.dtype == "str" + else DTypes.get(dataset.dtype).value + ) + + h5_dataset = h5group.create_dataset( + dkey, data=dataset.data.astype(np_dtype) + ) + + if isinstance(dataset, Table): + # dtype str converted to bytes when dumped (h5 compatibility) + np_dtype = np.dtype( + [ + (k, np.empty(0, dtype=v).astype(np.dtypes.BytesDType).dtype) + if dict(dataset.columns)[k] == "str" + else (k, v) + for k, (v, _) in dataset.data.dtype.fields.items() + ] + ) + h5_dataset = h5group.create_dataset( - dkey, data=dataset.data.astype(np.dtypes.BytesDType) + dkey, + data=dataset.data.astype(np_dtype), ) - else: - h5_dataset = h5group.create_dataset(dkey, data=dataset.data) # dump dataset attributes for akey, attr in dataset.attrs.items(): @@ -150,6 +177,35 @@ def model_dump_hdf5(self, filepath: pathlib.Path, mode: Literal["w", "a"] = "w") self._dump_group(f, gkey, group) + @classmethod + def _load_data(cls, group, h5group, dkey, ikey=None): + field = group.__dict__[ikey] if ikey else group.__dict__ + h5field = h5group[ikey] if ikey else h5group + + if isinstance(field[dkey], Dataset): + field[dkey].data = np.array(h5field[dkey][()]).astype( + DTypes.get(field[dkey].dtype).value + ) + return + if isinstance(field[dkey], Table): + np_dtype = np.dtype( + [ + ( + k, + np.empty(0, dtype=v).astype(np.dtypes.StrDType).dtype, + ) + if dict(field[dkey].columns)[k] == "str" + else (k, v) + for k, (v, _) in 
np.array(h5field[dkey][()]).dtype.fields.items() + ] + ) + field[dkey].data = np.array(h5field[dkey][()]).astype(np_dtype) + return + + raise ValueError( + "Attempted to load Group data field that is neither Dataset nor Table." + ) + @classmethod def model_validate_hdf5(cls, filepath: pathlib.Path): """ @@ -172,26 +228,14 @@ def model_validate_hdf5(cls, filepath: pathlib.Path): if group.__dict__[dkey] is None: continue - # load Dataset data - if isinstance(group.__dict__[dkey], Dataset): - group.__dict__[dkey].data = np.array(f[gkey][dkey][()]).astype( - DTypes.get(group.__dict__[dkey].dtype).value - ) - continue - - # load data for dict of Dataset + # load data for dict of Dataset or dict of Table if isinstance(group.__dict__[dkey], dict): for ddkey in group.__dict__[dkey]: - group.__dict__[dkey][ddkey].data = np.array( - f[gkey][dkey][ddkey][()] - ).astype( - DTypes.get(group.__dict__[dkey][ddkey].dtype).value - ) + cls._load_data(group, f[gkey], dkey=ddkey, ikey=dkey) continue - raise TypeError( - "Group data fields must be of type Dataset or dict of Dataset." - ) + # load Dataset or Table data + cls._load_data(group, f[gkey], dkey=dkey) return self diff --git a/src/oqd_dataschema/group.py b/src/oqd_dataschema/group.py new file mode 100644 index 0000000..dc88871 --- /dev/null +++ b/src/oqd_dataschema/group.py @@ -0,0 +1,206 @@ +# Copyright 2024-2025 Open Quantum Design + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import typing +import warnings +from types import NoneType +from typing import Annotated, ClassVar, Literal, Union + +from pydantic import ( + BaseModel, + Discriminator, + TypeAdapter, +) + +from oqd_dataschema.base import Attrs +from oqd_dataschema.dataset import CastDataset, Dataset +from oqd_dataschema.table import Table + +######################################################################################## + +__all__ = [ + "GroupBase", + "GroupRegistry", + "SinaraRawDataGroup", + "MeasurementOutcomesDataGroup", + "ExpectationValueDataGroup", + "OQDTestbenchDataGroup", +] + + +######################################################################################## + + +class GroupBase(BaseModel, extra="forbid"): + """ + Schema representation for a group object within an HDF5 file. + + Each grouping of data should be defined as a subclass of `Group`, and specify the datasets that it will contain. + This base object only has attributes, `attrs`, which are associated to the HDF5 group. + + Attributes: + attrs: A dictionary of attributes to append to the dataset. 
+ + Example: + ``` + group = Group(attrs={'version': 2, 'date': '2025-01-01'}) + ``` + """ + + attrs: Attrs = {} + + @staticmethod + def _is_datafield_type(v): + return Dataset._is_dataset_type(v) or Table._is_table_type(v) + + @classmethod + def _is_allowed_field_type(cls, v): + is_datafield = cls._is_datafield_type(v) + + is_annotated_datafield = typing.get_origin( + v + ) is Annotated and cls._is_datafield_type(v.__origin__) + + is_optional_datafield = typing.get_origin(v) is Union and ( + (v.__args__[0] == NoneType and cls._is_datafield_type(v.__args__[1])) + or (v.__args__[1] == NoneType and cls._is_datafield_type(v.__args__[0])) + ) + + is_dict_datafield = ( + typing.get_origin(v) is dict + and v.__args__[0] is str + and cls._is_datafield_type(v.__args__[1]) + ) + + return ( + is_datafield + or is_annotated_datafield + or is_optional_datafield + or is_dict_datafield + ) + + @classmethod + def _is_classvar(cls, v): + return v is ClassVar or typing.get_origin(v) is ClassVar + + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + + for k, v in cls.__annotations__.items(): + if k in ["class_", "attrs"]: + raise AttributeError( + "`class_` and `attrs` attribute should not be set manually." + ) + + if cls._is_classvar(v): + continue + + if not cls._is_allowed_field_type(v): + raise TypeError( + "All fields of `GroupBase` have to be of type `Dataset`." + ) + + cls.__annotations__["class_"] = Literal[cls.__name__] + setattr(cls, "class_", cls.__name__) + + # Auto-register new group types + GroupRegistry.register(cls) + + +######################################################################################## + + +class MetaGroupRegistry(type): + """ + Metaclass for the GroupRegistry + """ + + def __new__(cls, clsname, superclasses, attributedict): + attributedict["groups"] = dict() + return super().__new__(cls, clsname, superclasses, attributedict) + + def register(cls, group): + """Registers a group into the GroupRegistry.""" + if not issubclass(group, GroupBase): + raise TypeError("You may only register subclasses of GroupBase.") + + if group.__name__ in cls.groups.keys(): + warnings.warn( + f"Overwriting previously registered `{group.__name__}` group of the same name.", + UserWarning, + stacklevel=2, + ) + + cls.groups[group.__name__] = group + + def clear(cls): + """Clear all registered types (useful for testing)""" + cls.groups.clear() + + @property + def union(cls): + """Get the current Union of all registered types""" + return Annotated[ + Union[tuple(cls.groups.values())], Discriminator(discriminator="class_") + ] + + @property + def adapter(cls): + """Get TypeAdapter for current registered types""" + return TypeAdapter(cls.union) + + +class GroupRegistry(metaclass=MetaGroupRegistry): + """ + Represents the GroupRegistry + """ + + pass + + +######################################################################################## + + +class SinaraRawDataGroup(GroupBase): + """ + Example `Group` for raw data from the Sinara real-time control system. + This is a placeholder for demonstration and development. + """ + + camera_images: CastDataset + + +class MeasurementOutcomesDataGroup(GroupBase): + """ + Example `Group` for processed data classifying the readout of the state. + This is a placeholder for demonstration and development. + """ + + outcomes: CastDataset + + +class ExpectationValueDataGroup(GroupBase): + """ + Example `Group` for processed data calculating the expectation values. + This is a placeholder for demonstration and development. 
+ """ + + expectation_value: CastDataset + + +class OQDTestbenchDataGroup(GroupBase): + """ """ + + time: CastDataset + voltages: CastDataset diff --git a/src/oqd_dataschema/groups.py b/src/oqd_dataschema/groups.py deleted file mode 100644 index b244f62..0000000 --- a/src/oqd_dataschema/groups.py +++ /dev/null @@ -1,61 +0,0 @@ -# Copyright 2024-2025 Open Quantum Design - -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from oqd_dataschema.base import CastDataset, GroupBase - -######################################################################################## - -__all__ = [ - "SinaraRawDataGroup", - "MeasurementOutcomesDataGroup", - "ExpectationValueDataGroup", - "OQDTestbenchDataGroup", -] - -######################################################################################## - - -class SinaraRawDataGroup(GroupBase): - """ - Example `Group` for raw data from the Sinara real-time control system. - This is a placeholder for demonstration and development. - """ - - camera_images: CastDataset - - -class MeasurementOutcomesDataGroup(GroupBase): - """ - Example `Group` for processed data classifying the readout of the state. - This is a placeholder for demonstration and development. - """ - - outcomes: CastDataset - - -class ExpectationValueDataGroup(GroupBase): - """ - Example `Group` for processed data calculating the expectation values. - This is a placeholder for demonstration and development. - """ - - expectation_value: CastDataset - - -class OQDTestbenchDataGroup(GroupBase): - """ """ - - time: CastDataset - voltages: CastDataset diff --git a/src/oqd_dataschema/table.py b/src/oqd_dataschema/table.py new file mode 100644 index 0000000..2f8b076 --- /dev/null +++ b/src/oqd_dataschema/table.py @@ -0,0 +1,160 @@ +# Copyright 2024-2025 Open Quantum Design + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import typing +from types import MappingProxyType +from typing import Annotated, Any, List, Literal, Optional, Tuple, Union + +import numpy as np +from pydantic import ( + AfterValidator, + BaseModel, + BeforeValidator, + ConfigDict, + Field, + field_validator, + model_validator, +) + +from oqd_dataschema.base import Attrs, DTypes +from oqd_dataschema.dataset import _constraint_dim, _constraint_shape +from oqd_dataschema.utils import _flex_shape_equal, _is_list_unique + +######################################################################################## + +__all__ = [ + "Table", + "CastTable", + "contable", +] + +######################################################################################## + + +Column = Tuple[str, Optional[Literal[DTypes.names()]]] + + +class Table(BaseModel, extra="forbid"): + columns: List[Column] # type: ignore + shape: Optional[Tuple[Union[int, None], ...]] = None + data: Optional[Any] = Field(default=None, exclude=True) + + attrs: Attrs = {} + + model_config = ConfigDict( + use_enum_values=False, arbitrary_types_allowed=True, validate_assignment=True + ) + + @field_validator("columns", mode="before") + @classmethod + def validate_unique(cls, value): + column_names = [c[0] for c in value] + + is_unique, duplicates = _is_list_unique(column_names) + if not is_unique: + raise ValueError(f"More than one column with the same name ({duplicates}).") + + return value + + @field_validator("data", mode="before") + @classmethod + def validate_and_update(cls, value): + # check if data exist + if value is None: + return value + + # check if data is a numpy array + if not isinstance(value, np.ndarray): + raise TypeError("`data` must be a numpy.ndarray.") + + if not isinstance(value.dtype.fields, MappingProxyType): + raise TypeError("dtype of data must be a structured dtype.") + + return value + + @model_validator(mode="after") + def validate_data_matches_shape_dtype(self): + """Ensure that `data` matches `dtype` and `shape`.""" + + # check if data exist + if self.data is None: + return self + + if set(self.data.dtype.fields.keys()) != set([c[0] for c in self.columns]): + raise ValueError("Fields of data do not match expected field for Table.") + + # check if dtype matches data + for k, v in self.data.dtype.fields.items(): + if ( + dict(self.columns)[k] is not None + and type(v[0]) is not DTypes.get(dict(self.columns)[k]).value + ): + raise ValueError( + f"Expected data dtype `{dict(self.columns)[k]}`, but got `{v[0].name}`." 
+                )
+
+        # check if shape matches data
+        if self.shape is not None and not _flex_shape_equal(
+            self.data.shape, self.shape
+        ):
+            raise ValueError(f"Expected shape {self.shape}, but got {self.data.shape}.")
+
+        # reassign dtype if it is None
+        for n, (k, v) in enumerate(self.columns):
+            if v != DTypes(type(self.data.dtype.fields[k][0])).name.lower():
+                self.columns[n] = (
+                    k,
+                    DTypes(type(self.data.dtype.fields[k][0])).name.lower(),
+                )
+
+        # reassign shape to concrete value if it is None or a flexible shape
+        if self.shape != self.data.shape:
+            self.shape = self.data.shape
+
+        return self
+
+    @classmethod
+    def cast(cls, data):
+        if isinstance(data, np.ndarray):
+            if not isinstance(data.dtype.fields, MappingProxyType):
+                raise TypeError("dtype of data must be a structured dtype.")
+
+            columns = [
+                (k, DTypes(type(v)).name.lower())
+                for k, (v, _) in data.dtype.fields.items()
+            ]
+
+            return cls(columns=columns, data=data)
+        return data
+
+    @classmethod
+    def _is_table_type(cls, type_):
+        return type_ == cls or (
+            typing.get_origin(type_) is Annotated and type_.__origin__ is cls
+        )
+
+
+CastTable = Annotated[Table, BeforeValidator(Table.cast)]
+
+########################################################################################
+
+
+def contable(*, shape_constraint=None, min_dim=None, max_dim=None):
+    """Implements dtype, dimension and shape constrains on the Table."""
+    return Annotated[
+        Table,
+        AfterValidator(_constraint_dim(min_dim=min_dim, max_dim=max_dim)),
+        AfterValidator(_constraint_shape(shape_constraint=shape_constraint)),
+    ]
diff --git a/src/oqd_dataschema/utils.py b/src/oqd_dataschema/utils.py
index 74ae782..57a7747 100644
--- a/src/oqd_dataschema/utils.py
+++ b/src/oqd_dataschema/utils.py
@@ -16,7 +16,7 @@
 
 ########################################################################################
 
-__all__ = ["_flex_shape_equal", "_validator_from_condition"]
+__all__ = ["_flex_shape_equal", "_validator_from_condition", "_is_list_unique"]
 
 ########################################################################################
 
@@ -44,3 +44,19 @@ def _wrapped_condition(model):
         return _wrapped_condition
 
     return _wrapped_validator
+
+
+def _is_list_unique(data):
+    seen = set()
+    duplicates = set()
+    for element in data:
+        if element in duplicates:
+            continue
+
+        if element in seen:
+            duplicates.add(element)
+            continue
+
+        seen.add(element)
+
+    return (duplicates == set(), duplicates)
diff --git a/tests/test_dataset.py b/tests/test_dataset.py
index ad9fdfa..5ccd65d 100644
--- a/tests/test_dataset.py
+++ b/tests/test_dataset.py
@@ -18,7 +18,8 @@
 import pytest
 from pydantic import TypeAdapter
 
-from oqd_dataschema.base import CastDataset, Dataset, DTypes, condataset
+from oqd_dataschema import CastDataset, Dataset, condataset
+from oqd_dataschema.base import DTypes
 
 ########################################################################################
 
diff --git a/tests/test_datastore.py b/tests/test_datastore.py
index c39f49e..8499609 100644
--- a/tests/test_datastore.py
+++ b/tests/test_datastore.py
@@ -19,8 +19,7 @@
 import numpy as np
 import pytest
 
-from oqd_dataschema import Datastore, GroupBase
-from oqd_dataschema.base import Dataset
+from oqd_dataschema import Dataset, Datastore, GroupBase
 
 # %%
 
diff --git a/tests/test_groupregistry.py b/tests/test_groupregistry.py
index 1425781..989ff4b 100644
--- a/tests/test_groupregistry.py
+++ b/tests/test_groupregistry.py
@@ -17,7 +17,7 @@
 
 import pytest
 
-from oqd_dataschema.base import (
+from oqd_dataschema import (
CastDataset, Dataset, GroupBase, From af44790186e001b0b0370b39180d2cd864c2f8ab Mon Sep 17 00:00:00 2001 From: yhteoh Date: Mon, 6 Oct 2025 14:17:07 -0400 Subject: [PATCH 20/48] [feat] Table supports casting to and from pandas.DataFrame --- pyproject.toml | 6 ++- src/oqd_dataschema/table.py | 46 +++++++++++++++++++- uv.lock | 87 ++++++++++++++++++++++++++++++++++++- 3 files changed, 134 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 7ccc3bd..6e17436 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,11 @@ classifiers = [ "Programming Language :: Python :: 3.12", ] -dependencies = ["h5py>=3.14.0", "pydantic>=2.10.6"] +dependencies = [ + "h5py>=3.14.0", + "pandas>=2.3.3", + "pydantic>=2.10.6", +] [project.optional-dependencies] docs = [ diff --git a/src/oqd_dataschema/table.py b/src/oqd_dataschema/table.py index 2f8b076..ddbcd18 100644 --- a/src/oqd_dataschema/table.py +++ b/src/oqd_dataschema/table.py @@ -18,6 +18,7 @@ from typing import Annotated, Any, List, Literal, Optional, Tuple, Union import numpy as np +import pandas as pd from pydantic import ( AfterValidator, BaseModel, @@ -68,6 +69,41 @@ def validate_unique(cls, value): return value + @property + def pd(self): + if len(self.shape) > 1: + raise ValueError( + "Conversion to pandas DataFrame only supported on 1D Table." + ) + return pd.DataFrame( + data=self.data, columns=[c[0] for c in self.columns] + ).astype({k: v for k, v in self.columns}) + + @staticmethod + def _pd_to_np(df): + np_dtype = [] + for k, v in df.dtypes.items(): + if type(v) is not np.dtypes.ObjectDType: + field_np_dtype = (k, v) + np_dtype.append(field_np_dtype) + continue + + # Check if column of object dtype is actually str dtype + if (np.vectorize(lambda x: isinstance(x, str))(df[k].to_numpy())).all(): + dt = df[k].to_numpy().astype(np.dtypes.StrDType).dtype + field_np_dtype = (k, dt) + + np_dtype.append(field_np_dtype) + continue + + raise ValueError(f"Unsupported datatype for column {k}") + + return np.rec.fromarrays( + df.to_numpy().transpose(), + names=[dt[0] for dt in np_dtype], + formats=[dt[1] for dt in np_dtype], + ).astype(np.dtype(np_dtype)) + @field_validator("data", mode="before") @classmethod def validate_and_update(cls, value): @@ -76,8 +112,11 @@ def validate_and_update(cls, value): return value # check if data is a numpy array - if not isinstance(value, np.ndarray): - raise TypeError("`data` must be a numpy.ndarray.") + if not isinstance(value, (np.ndarray, pd.DataFrame)): + raise TypeError("`data` must be a numpy.ndarray or pandas.DataFrame.") + + if isinstance(value, pd.DataFrame): + value = cls._pd_to_np(value) if not isinstance(value.dtype.fields, MappingProxyType): raise TypeError("dtype of data must be a structured dtype.") @@ -127,6 +166,9 @@ def validate_data_matches_shape_dtype(self): @classmethod def cast(cls, data): + if isinstance(data, pd.DataFrame): + data = cls._pd_to_np(data) + if isinstance(data, np.ndarray): if not isinstance(data.dtype.fields, MappingProxyType): raise TypeError("dtype of data must be a structured dtype.") diff --git a/uv.lock b/uv.lock index f651324..e015315 100644 --- a/uv.lock +++ b/uv.lock @@ -2,7 +2,8 @@ version = 1 requires-python = ">=3.10" resolution-markers = [ "python_full_version >= '3.14'", - "python_full_version >= '3.11' and python_full_version < '3.14'", + "python_full_version >= '3.12' and python_full_version < '3.14'", + "python_full_version == '3.11.*'", "python_full_version < '3.11'", ] @@ -656,7 +657,8 @@ version = "9.5.0" source = { registry = 
"https://pypi.org/simple" } resolution-markers = [ "python_full_version >= '3.14'", - "python_full_version >= '3.11' and python_full_version < '3.14'", + "python_full_version >= '3.12' and python_full_version < '3.14'", + "python_full_version == '3.11.*'", ] dependencies = [ { name = "colorama", marker = "python_full_version >= '3.11' and sys_platform == 'win32'" }, @@ -1428,6 +1430,7 @@ version = "0.1.0" source = { editable = "." } dependencies = [ { name = "h5py" }, + { name = "pandas" }, { name = "pydantic" }, ] @@ -1458,6 +1461,7 @@ requires-dist = [ { name = "mkdocs-material", marker = "extra == 'docs'" }, { name = "mkdocstrings", marker = "extra == 'docs'" }, { name = "mkdocstrings-python", marker = "extra == 'docs'" }, + { name = "pandas", specifier = ">=2.3.3" }, { name = "pydantic", specifier = ">=2.10.6" }, { name = "pymdown-extensions", marker = "extra == 'docs'" }, { name = "pytest", marker = "extra == 'tests'" }, @@ -1498,6 +1502,67 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/90/96/04b8e52da071d28f5e21a805b19cb9390aa17a47462ac87f5e2696b9566d/paginate-0.5.7-py2.py3-none-any.whl", hash = "sha256:b885e2af73abcf01d9559fd5216b57ef722f8c42affbb63942377668e35c7591", size = 13746 }, ] +[[package]] +name = "pandas" +version = "2.3.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, + { name = "python-dateutil" }, + { name = "pytz" }, + { name = "tzdata" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/33/01/d40b85317f86cf08d853a4f495195c73815fdf205eef3993821720274518/pandas-2.3.3.tar.gz", hash = "sha256:e05e1af93b977f7eafa636d043f9f94c7ee3ac81af99c13508215942e64c993b", size = 4495223 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3d/f7/f425a00df4fcc22b292c6895c6831c0c8ae1d9fac1e024d16f98a9ce8749/pandas-2.3.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:376c6446ae31770764215a6c937f72d917f214b43560603cd60da6408f183b6c", size = 11555763 }, + { url = "https://files.pythonhosted.org/packages/13/4f/66d99628ff8ce7857aca52fed8f0066ce209f96be2fede6cef9f84e8d04f/pandas-2.3.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:e19d192383eab2f4ceb30b412b22ea30690c9e618f78870357ae1d682912015a", size = 10801217 }, + { url = "https://files.pythonhosted.org/packages/1d/03/3fc4a529a7710f890a239cc496fc6d50ad4a0995657dccc1d64695adb9f4/pandas-2.3.3-cp310-cp310-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:5caf26f64126b6c7aec964f74266f435afef1c1b13da3b0636c7518a1fa3e2b1", size = 12148791 }, + { url = "https://files.pythonhosted.org/packages/40/a8/4dac1f8f8235e5d25b9955d02ff6f29396191d4e665d71122c3722ca83c5/pandas-2.3.3-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:dd7478f1463441ae4ca7308a70e90b33470fa593429f9d4c578dd00d1fa78838", size = 12769373 }, + { url = "https://files.pythonhosted.org/packages/df/91/82cc5169b6b25440a7fc0ef3a694582418d875c8e3ebf796a6d6470aa578/pandas-2.3.3-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:4793891684806ae50d1288c9bae9330293ab4e083ccd1c5e383c34549c6e4250", size = 13200444 }, + { url = "https://files.pythonhosted.org/packages/10/ae/89b3283800ab58f7af2952704078555fa60c807fff764395bb57ea0b0dbd/pandas-2.3.3-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:28083c648d9a99a5dd035ec125d42439c6c1c525098c58af0fc38dd1a7a1b3d4", size = 13858459 }, + { url = "https://files.pythonhosted.org/packages/85/72/530900610650f54a35a19476eca5104f38555afccda1aa11a92ee14cb21d/pandas-2.3.3-cp310-cp310-win_amd64.whl", hash = 
"sha256:503cf027cf9940d2ceaa1a93cfb5f8c8c7e6e90720a2850378f0b3f3b1e06826", size = 11346086 }, + { url = "https://files.pythonhosted.org/packages/c1/fa/7ac648108144a095b4fb6aa3de1954689f7af60a14cf25583f4960ecb878/pandas-2.3.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:602b8615ebcc4a0c1751e71840428ddebeb142ec02c786e8ad6b1ce3c8dec523", size = 11578790 }, + { url = "https://files.pythonhosted.org/packages/9b/35/74442388c6cf008882d4d4bdfc4109be87e9b8b7ccd097ad1e7f006e2e95/pandas-2.3.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:8fe25fc7b623b0ef6b5009149627e34d2a4657e880948ec3c840e9402e5c1b45", size = 10833831 }, + { url = "https://files.pythonhosted.org/packages/fe/e4/de154cbfeee13383ad58d23017da99390b91d73f8c11856f2095e813201b/pandas-2.3.3-cp311-cp311-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b468d3dad6ff947df92dcb32ede5b7bd41a9b3cceef0a30ed925f6d01fb8fa66", size = 12199267 }, + { url = "https://files.pythonhosted.org/packages/bf/c9/63f8d545568d9ab91476b1818b4741f521646cbdd151c6efebf40d6de6f7/pandas-2.3.3-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b98560e98cb334799c0b07ca7967ac361a47326e9b4e5a7dfb5ab2b1c9d35a1b", size = 12789281 }, + { url = "https://files.pythonhosted.org/packages/f2/00/a5ac8c7a0e67fd1a6059e40aa08fa1c52cc00709077d2300e210c3ce0322/pandas-2.3.3-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:1d37b5848ba49824e5c30bedb9c830ab9b7751fd049bc7914533e01c65f79791", size = 13240453 }, + { url = "https://files.pythonhosted.org/packages/27/4d/5c23a5bc7bd209231618dd9e606ce076272c9bc4f12023a70e03a86b4067/pandas-2.3.3-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:db4301b2d1f926ae677a751eb2bd0e8c5f5319c9cb3f88b0becbbb0b07b34151", size = 13890361 }, + { url = "https://files.pythonhosted.org/packages/8e/59/712db1d7040520de7a4965df15b774348980e6df45c129b8c64d0dbe74ef/pandas-2.3.3-cp311-cp311-win_amd64.whl", hash = "sha256:f086f6fe114e19d92014a1966f43a3e62285109afe874f067f5abbdcbb10e59c", size = 11348702 }, + { url = "https://files.pythonhosted.org/packages/9c/fb/231d89e8637c808b997d172b18e9d4a4bc7bf31296196c260526055d1ea0/pandas-2.3.3-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:6d21f6d74eb1725c2efaa71a2bfc661a0689579b58e9c0ca58a739ff0b002b53", size = 11597846 }, + { url = "https://files.pythonhosted.org/packages/5c/bd/bf8064d9cfa214294356c2d6702b716d3cf3bb24be59287a6a21e24cae6b/pandas-2.3.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:3fd2f887589c7aa868e02632612ba39acb0b8948faf5cc58f0850e165bd46f35", size = 10729618 }, + { url = "https://files.pythonhosted.org/packages/57/56/cf2dbe1a3f5271370669475ead12ce77c61726ffd19a35546e31aa8edf4e/pandas-2.3.3-cp312-cp312-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ecaf1e12bdc03c86ad4a7ea848d66c685cb6851d807a26aa245ca3d2017a1908", size = 11737212 }, + { url = "https://files.pythonhosted.org/packages/e5/63/cd7d615331b328e287d8233ba9fdf191a9c2d11b6af0c7a59cfcec23de68/pandas-2.3.3-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b3d11d2fda7eb164ef27ffc14b4fcab16a80e1ce67e9f57e19ec0afaf715ba89", size = 12362693 }, + { url = "https://files.pythonhosted.org/packages/a6/de/8b1895b107277d52f2b42d3a6806e69cfef0d5cf1d0ba343470b9d8e0a04/pandas-2.3.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:a68e15f780eddf2b07d242e17a04aa187a7ee12b40b930bfdd78070556550e98", size = 12771002 }, + { url = 
"https://files.pythonhosted.org/packages/87/21/84072af3187a677c5893b170ba2c8fbe450a6ff911234916da889b698220/pandas-2.3.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:371a4ab48e950033bcf52b6527eccb564f52dc826c02afd9a1bc0ab731bba084", size = 13450971 }, + { url = "https://files.pythonhosted.org/packages/86/41/585a168330ff063014880a80d744219dbf1dd7a1c706e75ab3425a987384/pandas-2.3.3-cp312-cp312-win_amd64.whl", hash = "sha256:a16dcec078a01eeef8ee61bf64074b4e524a2a3f4b3be9326420cabe59c4778b", size = 10992722 }, + { url = "https://files.pythonhosted.org/packages/cd/4b/18b035ee18f97c1040d94debd8f2e737000ad70ccc8f5513f4eefad75f4b/pandas-2.3.3-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:56851a737e3470de7fa88e6131f41281ed440d29a9268dcbf0002da5ac366713", size = 11544671 }, + { url = "https://files.pythonhosted.org/packages/31/94/72fac03573102779920099bcac1c3b05975c2cb5f01eac609faf34bed1ca/pandas-2.3.3-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:bdcd9d1167f4885211e401b3036c0c8d9e274eee67ea8d0758a256d60704cfe8", size = 10680807 }, + { url = "https://files.pythonhosted.org/packages/16/87/9472cf4a487d848476865321de18cc8c920b8cab98453ab79dbbc98db63a/pandas-2.3.3-cp313-cp313-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e32e7cc9af0f1cc15548288a51a3b681cc2a219faa838e995f7dc53dbab1062d", size = 11709872 }, + { url = "https://files.pythonhosted.org/packages/15/07/284f757f63f8a8d69ed4472bfd85122bd086e637bf4ed09de572d575a693/pandas-2.3.3-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:318d77e0e42a628c04dc56bcef4b40de67918f7041c2b061af1da41dcff670ac", size = 12306371 }, + { url = "https://files.pythonhosted.org/packages/33/81/a3afc88fca4aa925804a27d2676d22dcd2031c2ebe08aabd0ae55b9ff282/pandas-2.3.3-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:4e0a175408804d566144e170d0476b15d78458795bb18f1304fb94160cabf40c", size = 12765333 }, + { url = "https://files.pythonhosted.org/packages/8d/0f/b4d4ae743a83742f1153464cf1a8ecfafc3ac59722a0b5c8602310cb7158/pandas-2.3.3-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:93c2d9ab0fc11822b5eece72ec9587e172f63cff87c00b062f6e37448ced4493", size = 13418120 }, + { url = "https://files.pythonhosted.org/packages/4f/c7/e54682c96a895d0c808453269e0b5928a07a127a15704fedb643e9b0a4c8/pandas-2.3.3-cp313-cp313-win_amd64.whl", hash = "sha256:f8bfc0e12dc78f777f323f55c58649591b2cd0c43534e8355c51d3fede5f4dee", size = 10993991 }, + { url = "https://files.pythonhosted.org/packages/f9/ca/3f8d4f49740799189e1395812f3bf23b5e8fc7c190827d55a610da72ce55/pandas-2.3.3-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:75ea25f9529fdec2d2e93a42c523962261e567d250b0013b16210e1d40d7c2e5", size = 12048227 }, + { url = "https://files.pythonhosted.org/packages/0e/5a/f43efec3e8c0cc92c4663ccad372dbdff72b60bdb56b2749f04aa1d07d7e/pandas-2.3.3-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:74ecdf1d301e812db96a465a525952f4dde225fdb6d8e5a521d47e1f42041e21", size = 11411056 }, + { url = "https://files.pythonhosted.org/packages/46/b1/85331edfc591208c9d1a63a06baa67b21d332e63b7a591a5ba42a10bb507/pandas-2.3.3-cp313-cp313t-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6435cb949cb34ec11cc9860246ccb2fdc9ecd742c12d3304989017d53f039a78", size = 11645189 }, + { url = "https://files.pythonhosted.org/packages/44/23/78d645adc35d94d1ac4f2a3c4112ab6f5b8999f4898b8cdf01252f8df4a9/pandas-2.3.3-cp313-cp313t-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = 
"sha256:900f47d8f20860de523a1ac881c4c36d65efcb2eb850e6948140fa781736e110", size = 12121912 }, + { url = "https://files.pythonhosted.org/packages/53/da/d10013df5e6aaef6b425aa0c32e1fc1f3e431e4bcabd420517dceadce354/pandas-2.3.3-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:a45c765238e2ed7d7c608fc5bc4a6f88b642f2f01e70c0c23d2224dd21829d86", size = 12712160 }, + { url = "https://files.pythonhosted.org/packages/bd/17/e756653095a083d8a37cbd816cb87148debcfcd920129b25f99dd8d04271/pandas-2.3.3-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:c4fc4c21971a1a9f4bdb4c73978c7f7256caa3e62b323f70d6cb80db583350bc", size = 13199233 }, + { url = "https://files.pythonhosted.org/packages/04/fd/74903979833db8390b73b3a8a7d30d146d710bd32703724dd9083950386f/pandas-2.3.3-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:ee15f284898e7b246df8087fc82b87b01686f98ee67d85a17b7ab44143a3a9a0", size = 11540635 }, + { url = "https://files.pythonhosted.org/packages/21/00/266d6b357ad5e6d3ad55093a7e8efc7dd245f5a842b584db9f30b0f0a287/pandas-2.3.3-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:1611aedd912e1ff81ff41c745822980c49ce4a7907537be8692c8dbc31924593", size = 10759079 }, + { url = "https://files.pythonhosted.org/packages/ca/05/d01ef80a7a3a12b2f8bbf16daba1e17c98a2f039cbc8e2f77a2c5a63d382/pandas-2.3.3-cp314-cp314-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6d2cefc361461662ac48810cb14365a365ce864afe85ef1f447ff5a1e99ea81c", size = 11814049 }, + { url = "https://files.pythonhosted.org/packages/15/b2/0e62f78c0c5ba7e3d2c5945a82456f4fac76c480940f805e0b97fcbc2f65/pandas-2.3.3-cp314-cp314-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ee67acbbf05014ea6c763beb097e03cd629961c8a632075eeb34247120abcb4b", size = 12332638 }, + { url = "https://files.pythonhosted.org/packages/c5/33/dd70400631b62b9b29c3c93d2feee1d0964dc2bae2e5ad7a6c73a7f25325/pandas-2.3.3-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:c46467899aaa4da076d5abc11084634e2d197e9460643dd455ac3db5856b24d6", size = 12886834 }, + { url = "https://files.pythonhosted.org/packages/d3/18/b5d48f55821228d0d2692b34fd5034bb185e854bdb592e9c640f6290e012/pandas-2.3.3-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:6253c72c6a1d990a410bc7de641d34053364ef8bcd3126f7e7450125887dffe3", size = 13409925 }, + { url = "https://files.pythonhosted.org/packages/a6/3d/124ac75fcd0ecc09b8fdccb0246ef65e35b012030defb0e0eba2cbbbe948/pandas-2.3.3-cp314-cp314-win_amd64.whl", hash = "sha256:1b07204a219b3b7350abaae088f451860223a52cfb8a6c53358e7948735158e5", size = 11109071 }, + { url = "https://files.pythonhosted.org/packages/89/9c/0e21c895c38a157e0faa1fb64587a9226d6dd46452cac4532d80c3c4a244/pandas-2.3.3-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:2462b1a365b6109d275250baaae7b760fd25c726aaca0054649286bcfbb3e8ec", size = 12048504 }, + { url = "https://files.pythonhosted.org/packages/d7/82/b69a1c95df796858777b68fbe6a81d37443a33319761d7c652ce77797475/pandas-2.3.3-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:0242fe9a49aa8b4d78a4fa03acb397a58833ef6199e9aa40a95f027bb3a1b6e7", size = 11410702 }, + { url = "https://files.pythonhosted.org/packages/f9/88/702bde3ba0a94b8c73a0181e05144b10f13f29ebfc2150c3a79062a8195d/pandas-2.3.3-cp314-cp314t-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a21d830e78df0a515db2b3d2f5570610f5e6bd2e27749770e8bb7b524b89b450", size = 11634535 }, + { url = 
"https://files.pythonhosted.org/packages/a4/1e/1bac1a839d12e6a82ec6cb40cda2edde64a2013a66963293696bbf31fbbb/pandas-2.3.3-cp314-cp314t-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2e3ebdb170b5ef78f19bfb71b0dc5dc58775032361fa188e814959b74d726dd5", size = 12121582 }, + { url = "https://files.pythonhosted.org/packages/44/91/483de934193e12a3b1d6ae7c8645d083ff88dec75f46e827562f1e4b4da6/pandas-2.3.3-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:d051c0e065b94b7a3cea50eb1ec32e912cd96dba41647eb24104b6c6c14c5788", size = 12699963 }, + { url = "https://files.pythonhosted.org/packages/70/44/5191d2e4026f86a2a109053e194d3ba7a31a2d10a9c2348368c63ed4e85a/pandas-2.3.3-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:3869faf4bd07b3b66a9f462417d0ca3a9df29a9f6abd5d0d0dbab15dac7abe87", size = 13202175 }, +] + [[package]] name = "pandocfilters" version = "1.5.1" @@ -1785,6 +1850,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/08/20/0f2523b9e50a8052bc6a8b732dfc8568abbdc42010aef03a2d750bdab3b2/python_json_logger-3.3.0-py3-none-any.whl", hash = "sha256:dd980fae8cffb24c13caf6e158d3d61c0d6d22342f932cb6e9deedab3d35eec7", size = 15163 }, ] +[[package]] +name = "pytz" +version = "2025.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f8/bf/abbd3cdfb8fbc7fb3d4d38d320f2441b1e7cbe29be4f23797b4a2b5d8aac/pytz-2025.2.tar.gz", hash = "sha256:360b9e3dbb49a209c21ad61809c7fb453643e048b38924c765813546746e81c3", size = 320884 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/81/c4/34e93fe5f5429d7570ec1fa436f1986fb1f00c3e0f43a589fe2bbcd22c3f/pytz-2025.2-py2.py3-none-any.whl", hash = "sha256:5ddf76296dd8c44c26eb8f4b6f35488f3ccbf6fbbd7adee0b7262d43f0ec2f00", size = 509225 }, +] + [[package]] name = "pywin32" version = "311" @@ -2355,6 +2429,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/26/9f/ad63fc0248c5379346306f8668cda6e2e2e9c95e01216d2b8ffd9ff037d0/typing_extensions-4.12.2-py3-none-any.whl", hash = "sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d", size = 37438 }, ] +[[package]] +name = "tzdata" +version = "2025.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/95/32/1a225d6164441be760d75c2c42e2780dc0873fe382da3e98a2e1e48361e5/tzdata-2025.2.tar.gz", hash = "sha256:b60a638fcc0daffadf82fe0f57e53d06bdec2f36c4df66280ae79bce6bd6f2b9", size = 196380 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5c/23/c7abc0ca0a1526a0774eca151daeb8de62ec457e77262b66b359c3c7679e/tzdata-2025.2-py2.py3-none-any.whl", hash = "sha256:1a403fada01ff9221ca8044d701868fa132215d84beb92242d9acd2147f667a8", size = 347839 }, +] + [[package]] name = "uri-template" version = "1.3.0" From bed844a6df3fe78087f15edee4c8e4b65d3d7653 Mon Sep 17 00:00:00 2001 From: yhteoh Date: Mon, 6 Oct 2025 15:13:28 -0400 Subject: [PATCH 21/48] [refactor,feat] constrained types moved to separate module, Added required_field and dtype constraints for contable. 
---
 src/oqd_dataschema/__init__.py    |  11 +-
 src/oqd_dataschema/constrained.py | 174 ++++++++++++++++++++++++++++++
 src/oqd_dataschema/dataset.py     |  62 +----------
 src/oqd_dataschema/table.py       |  19 +---
 4 files changed, 186 insertions(+), 80 deletions(-)
 create mode 100644 src/oqd_dataschema/constrained.py

diff --git a/src/oqd_dataschema/__init__.py b/src/oqd_dataschema/__init__.py
index df34151..8352fec 100644
--- a/src/oqd_dataschema/__init__.py
+++ b/src/oqd_dataschema/__init__.py
@@ -12,7 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .dataset import CastDataset, Dataset, condataset
+from .constrained import condataset, contable
+from .dataset import CastDataset, Dataset
 from .datastore import Datastore
 from .group import (
     ExpectationValueDataGroup,
@@ -22,13 +23,11 @@
     OQDTestbenchDataGroup,
     SinaraRawDataGroup,
 )
-from .table import Table, contable
+from .table import CastTable, Table
 
 ########################################################################################
 
 __all__ = [
-    "CastDataset",
-    "Dataset",
     "Datastore",
     "GroupBase",
     "GroupRegistry",
@@ -36,8 +35,10 @@
     "MeasurementOutcomesDataGroup",
     "OQDTestbenchDataGroup",
     "SinaraRawDataGroup",
-    "condataset",
+    "Dataset",
     "CastDataset",
+    "condataset",
     "Table",
+    "CastTable",
     "contable",
 ]
diff --git a/src/oqd_dataschema/constrained.py b/src/oqd_dataschema/constrained.py
new file mode 100644
index 0000000..6a919db
--- /dev/null
+++ b/src/oqd_dataschema/constrained.py
@@ -0,0 +1,174 @@
+# Copyright 2024-2025 Open Quantum Design
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import Annotated, Sequence
+
+from pydantic import AfterValidator
+
+from oqd_dataschema.dataset import CastDataset
+from oqd_dataschema.table import CastTable
+from oqd_dataschema.utils import _flex_shape_equal, _validator_from_condition
+
+########################################################################################
+
+__all__ = [
+    "contable",
+    "condataset",
+]
+
+########################################################################################
+
+
+@_validator_from_condition
+def _constraint_dim(model, *, min_dim=None, max_dim=None):
+    """Constrains the dimension of a Dataset or Table."""
+
+    if min_dim is not None and max_dim is not None and min_dim > max_dim:
+        raise ValueError("Impossible to satisfy dimension constraints.")
+
+    min_dim = 0 if min_dim is None else min_dim
+
+    # fast escape
+    if min_dim == 0 and max_dim is None:
+        return
+
+    dims = len(model.shape)
+    if dims < min_dim or (max_dim is not None and dims > max_dim):
+        raise ValueError(
+            f"Expected {min_dim} <= dimension of shape{f' <= {max_dim}' if max_dim is not None else ''}, but got shape = {model.shape}."
+        )
+
+
+@_validator_from_condition
+def _constraint_shape(model, *, shape_constraint=None):
+    """Constrains the shape of a Dataset or Table."""
+
+    # fast escape
+    if shape_constraint is None:
+        return
+
+    if not _flex_shape_equal(shape_constraint, model.shape):
+        raise ValueError(
+            f"Expected shape to be {shape_constraint}, but got {model.shape}."
+        )
+
+
+########################################################################################
+
+
+@_validator_from_condition
+def _constrain_dtype_dataset(dataset, *, dtype_constraint=None):
+    """Constrains the dtype of a Dataset."""
+
+    # fast escape
+    if dtype_constraint is None:
+        return
+
+    # convert dtype constraint to set
+    if (not isinstance(dtype_constraint, str)) and isinstance(
+        dtype_constraint, Sequence
+    ):
+        dtype_constraint = set(dtype_constraint)
+    elif isinstance(dtype_constraint, str):
+        dtype_constraint = {dtype_constraint}
+
+    # apply dtype constraint
+    if dataset.dtype not in dtype_constraint:
+        raise ValueError(
+            f"Expected dtype to be of type one of {dtype_constraint}, but got {dataset.dtype}."
+        )
+
+
+def condataset(
+    *, shape_constraint=None, dtype_constraint=None, min_dim=None, max_dim=None
+):
+    """Implements dtype, dimension and shape constraints on the Dataset."""
+    return Annotated[
+        CastDataset,
+        AfterValidator(_constrain_dtype_dataset(dtype_constraint=dtype_constraint)),
+        AfterValidator(_constraint_dim(min_dim=min_dim, max_dim=max_dim)),
+        AfterValidator(_constraint_shape(shape_constraint=shape_constraint)),
+    ]
+
+
+########################################################################################
+
+
+@_validator_from_condition
+def _constrain_dtype_table(table, *, dtype_constraint={}):
+    """Constrains the dtype of a Table."""
+
+    for k, v in dtype_constraint.items():
+        if (not isinstance(v, str)) and isinstance(v, Sequence):
+            _v = set(v)
+        else:
+            _v = {v}
+
+        if dict(table.columns)[k] not in _v:
+            raise ValueError(
+                f"Expected dtype to be of type one of {_v}, but got {dict(table.columns)[k]}."
+            )
+
+
+@_validator_from_condition
+def _constrain_required_field(table, *, required_fields=None, strict_fields=False):
+    """Constrains the fields of a Table."""
+
+    if strict_fields and required_fields is None:
+        raise ValueError("Constraints force an empty Table.")
+
+    # fast escape
+    if required_fields is None:
+        return
+
+    # convert required fields to set
+    if (not isinstance(required_fields, str)) and isinstance(required_fields, Sequence):
+        required_fields = set(required_fields)
+    elif isinstance(required_fields, str):
+        required_fields = {required_fields}
+
+    diff = required_fields.difference(set([c[0] for c in table.columns]))
+    reverse_diff = set([c[0] for c in table.columns]).difference(required_fields)
+
+    if len(diff) > 0:
+        raise ValueError(f"Missing required fields {diff}.")
+
+    if strict_fields and len(reverse_diff):
+        raise ValueError(
+            f"Extra fields in the table are forbidden by constraints {reverse_diff}."
+        )
+
+
+def contable(
+    *,
+    required_fields=None,
+    strict_fields=False,
+    dtype_constraint={},
+    shape_constraint=None,
+    min_dim=None,
+    max_dim=None,
+):
+    """Implements dtype, dimension and shape constraints on the Table."""
+    return Annotated[
+        CastTable,
+        AfterValidator(
+            _constrain_required_field(
+                required_fields=required_fields, strict_fields=strict_fields
+            )
+        ),
+        AfterValidator(_constrain_dtype_table(dtype_constraint=dtype_constraint)),
+        AfterValidator(_constraint_dim(min_dim=min_dim, max_dim=max_dim)),
+        AfterValidator(_constraint_shape(shape_constraint=shape_constraint)),
+    ]
diff --git a/src/oqd_dataschema/dataset.py b/src/oqd_dataschema/dataset.py
index 56e4d67..ed836f7 100644
--- a/src/oqd_dataschema/dataset.py
+++ b/src/oqd_dataschema/dataset.py
@@ -14,11 +14,10 @@
 
 # %%
 import typing
-from typing import Annotated, Any, Literal, Optional, Sequence, Tuple, Union
+from typing import Annotated, Any, Literal, Optional, Tuple, Union
 
 import numpy as np
 from pydantic import (
-    AfterValidator,
     BaseModel,
     BeforeValidator,
     ConfigDict,
@@ -29,14 +28,13 @@
 
 from oqd_dataschema.base import Attrs, DTypes
 
-from .utils import _flex_shape_equal, _validator_from_condition
+from .utils import _flex_shape_equal
 
 ########################################################################################
 
 __all__ = [
     "Dataset",
     "CastDataset",
-    "condataset",
 ]
 
 ########################################################################################
@@ -136,59 +134,3 @@ def _is_dataset_type(cls, type_):
 
 
 CastDataset = Annotated[Dataset, BeforeValidator(Dataset.cast)]
-
-
-########################################################################################
-
-
-@_validator_from_condition
-def _constrain_dtype(dataset, *, dtype_constraint=None):
-    """Constrains the dtype of a dataset"""
-    if (not isinstance(dtype_constraint, str)) and isinstance(
-        dtype_constraint, Sequence
-    ):
-        dtype_constraint = set(dtype_constraint)
-    elif isinstance(dtype_constraint, str):
-        dtype_constraint = {dtype_constraint}
-
-    if dtype_constraint and dataset.dtype not in dtype_constraint:
-        raise ValueError(
-            f"Expected dtype to be of type one of {dtype_constraint}, but got {dataset.dtype}."
-        )
-
-
-@_validator_from_condition
-def _constraint_dim(dataset, *, min_dim=None, max_dim=None):
-    """Constrains the dimension of a dataset"""
-    if min_dim is not None and max_dim is not None and min_dim > max_dim:
-        raise ValueError("Impossible to satisfy dimension constraints on dataset.")
-
-    min_dim = 0 if min_dim is None else min_dim
-
-    dims = len(dataset.shape)
-
-    if dims < min_dim or (max_dim is not None and dims > max_dim):
-        raise ValueError(
-            f"Expected {min_dim} <= dimension of shape{f' <= {max_dim}'}, but got shape = {dataset.shape}."
-        )
-
-
-@_validator_from_condition
-def _constraint_shape(dataset, *, shape_constraint=None):
-    """Constrains the shape of a dataset"""
-    if shape_constraint and not _flex_shape_equal(shape_constraint, dataset.shape):
-        raise ValueError(
-            f"Expected shape to be {shape_constraint}, but got {dataset.shape}."
- ) - - -def condataset( - *, shape_constraint=None, dtype_constraint=None, min_dim=None, max_dim=None -): - """Implements dtype, dimension and shape constrains on the Dataset.""" - return Annotated[ - CastDataset, - AfterValidator(_constrain_dtype(dtype_constraint=dtype_constraint)), - AfterValidator(_constraint_dim(min_dim=min_dim, max_dim=max_dim)), - AfterValidator(_constraint_shape(shape_constraint=shape_constraint)), - ] diff --git a/src/oqd_dataschema/table.py b/src/oqd_dataschema/table.py index ddbcd18..0a54357 100644 --- a/src/oqd_dataschema/table.py +++ b/src/oqd_dataschema/table.py @@ -20,7 +20,6 @@ import numpy as np import pandas as pd from pydantic import ( - AfterValidator, BaseModel, BeforeValidator, ConfigDict, @@ -30,15 +29,16 @@ ) from oqd_dataschema.base import Attrs, DTypes -from oqd_dataschema.dataset import _constraint_dim, _constraint_shape -from oqd_dataschema.utils import _flex_shape_equal, _is_list_unique +from oqd_dataschema.utils import ( + _flex_shape_equal, + _is_list_unique, +) ######################################################################################## __all__ = [ "Table", "CastTable", - "contable", ] ######################################################################################## @@ -189,14 +189,3 @@ def _is_table_type(cls, type_): CastTable = Annotated[Table, BeforeValidator(Table.cast)] - -######################################################################################## - - -def contable(*, shape_constraint=None, min_dim=None, max_dim=None): - """Implements dtype, dimension and shape constrains on the Table.""" - return Annotated[ - Table, - AfterValidator(_constraint_dim(min_dim=min_dim, max_dim=max_dim)), - AfterValidator(_constraint_shape(shape_constraint=shape_constraint)), - ] From 3fd87d9d24f3f4464b481efe747d19751abde5d5 Mon Sep 17 00:00:00 2001 From: yhteoh Date: Mon, 6 Oct 2025 15:33:59 -0400 Subject: [PATCH 22/48] [feat] Table uses np.recarray for data --- src/oqd_dataschema/table.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/oqd_dataschema/table.py b/src/oqd_dataschema/table.py index 0a54357..dadfafe 100644 --- a/src/oqd_dataschema/table.py +++ b/src/oqd_dataschema/table.py @@ -121,6 +121,9 @@ def validate_and_update(cls, value): if not isinstance(value.dtype.fields, MappingProxyType): raise TypeError("dtype of data must be a structured dtype.") + if isinstance(value, np.ndarray): + value = value.view(np.recarray) + return value @model_validator(mode="after") From c0ced5459737a1d99741c4d14a9a6b9af69cef09 Mon Sep 17 00:00:00 2001 From: yhteoh Date: Mon, 6 Oct 2025 17:55:13 -0400 Subject: [PATCH 23/48] [fix] error message for GroupBase, datafield can be Dataset or Table --- src/oqd_dataschema/group.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/oqd_dataschema/group.py b/src/oqd_dataschema/group.py index dc88871..61cbeb9 100644 --- a/src/oqd_dataschema/group.py +++ b/src/oqd_dataschema/group.py @@ -108,7 +108,7 @@ def __init_subclass__(cls, **kwargs): if not cls._is_allowed_field_type(v): raise TypeError( - "All fields of `GroupBase` have to be of type `Dataset`." + "All fields of `GroupBase` have to be of type `Dataset` or `Table`." ) cls.__annotations__["class_"] = Literal[cls.__name__] From a9a29bec956504eff182ac37da2ee75c46ce083c Mon Sep 17 00:00:00 2001 From: yhteoh Date: Mon, 6 Oct 2025 18:03:58 -0400 Subject: [PATCH 24/48] [feat] Implemented Folder for document type datastore. 
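
A minimal sketch of the intended use, assuming a nested structured numpy dtype
that mirrors the document schema (the schema and field names below are examples
only):

```python
import numpy as np

from oqd_dataschema.folder import Folder

# Hypothetical nested record layout for illustration only.
dtype = np.dtype(
    [("position", [("x", np.float64), ("y", np.float64)]), ("label", np.int64)]
)

folder = Folder(
    document_schema={
        "position": {"x": "float64", "y": "float64"},
        "label": "int64",
    },
    data=np.zeros(5, dtype=dtype),  # validated against document_schema
)
```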
--- src/oqd_dataschema/folder.py | 138 +++++++++++++++++++++++++++++++++++ src/oqd_dataschema/group.py | 9 ++- 2 files changed, 145 insertions(+), 2 deletions(-) create mode 100644 src/oqd_dataschema/folder.py diff --git a/src/oqd_dataschema/folder.py b/src/oqd_dataschema/folder.py new file mode 100644 index 0000000..c4e7821 --- /dev/null +++ b/src/oqd_dataschema/folder.py @@ -0,0 +1,138 @@ +# Copyright 2024-2025 Open Quantum Design + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import typing +from types import MappingProxyType +from typing import Annotated, Any, Dict, Literal, Optional, Tuple, Union + +import numpy as np +from pydantic import ( + BaseModel, + ConfigDict, + Field, + field_validator, + model_validator, +) +from typing_extensions import TypeAliasType + +from oqd_dataschema.base import Attrs, DTypes +from oqd_dataschema.utils import _flex_shape_equal + +######################################################################################## + +__all__ = [ + "Folder", +] + +######################################################################################## + +DocumentSchema = TypeAliasType( + "DocumentSchema", + Dict[str, Union["DocumentSchema", Optional[Literal[DTypes.names()]]]], # type: ignore +) + + +class Folder(BaseModel, extra="forbid"): + document_schema: DocumentSchema + shape: Optional[Tuple[Union[int, None], ...]] = None + data: Optional[Any] = Field(default=None, exclude=True) + + attrs: Attrs = {} + + model_config = ConfigDict( + use_enum_values=False, arbitrary_types_allowed=True, validate_assignment=True + ) + + @field_validator("data", mode="before") + @classmethod + def validate_and_update(cls, value): + # check if data exist + if value is None: + return value + + # check if data is a numpy array + if not isinstance(value, np.ndarray): + raise TypeError("`data` must be a numpy.ndarray.") + + if not isinstance(value.dtype.fields, MappingProxyType): + raise TypeError("dtype of data must be a structured dtype.") + + value = value.view(np.recarray) + + return value + + @staticmethod + def _is_valid_array(document_schema, data_dtype, position=""): + # check if data_dtype is a structured dtype + if not isinstance(data_dtype.fields, MappingProxyType): + raise TypeError( + f"Error {f'in key `{position}`' if position else 'at root'}, expected structured dtype matching {document_schema = } but got unstructured dtype {data_dtype = }." + ) + + # check if fields all match + if set(document_schema.keys()) != set(data_dtype.fields.keys()): + diff = set(document_schema.keys()).difference(set(data_dtype.fields.keys())) + rv_diff = set(data_dtype.fields.keys()).difference( + set(document_schema.keys()) + ) + raise ValueError( + f"Error {f'in key `{position}`' if position else 'at root '}, mismatched {'subkeys' if position else 'keys'} between `document_schema` (unmatched = {diff}) and numpy data structured dtype (unmatched = {rv_diff})." 
+            )
+
+        # recursively check document_schema matches structured dtype data_dtype
+        for k, v in document_schema.items():
+            if isinstance(v, dict):
+                Folder._is_valid_array(
+                    v, data_dtype.fields[k][0], position + "." + k if position else k
+                )
+                continue
+
+            # check if dtypes match
+            if (
+                v is not None
+                and type(data_dtype.fields[k][0]) is not DTypes.get(v).value
+            ):
+                raise ValueError(
+                    f"Error {f'in key `{position}`' if position else 'at root '}, expected {'subkey' if position else 'key'} `{k}` to be of dtype compatible with {v} but got dtype {data_dtype.fields[k][0]}."
+                )
+
+    @model_validator(mode="after")
+    def validate_data_matches_shape_dtype(self):
+        """Ensure that `data` matches `dtype` and `shape`."""
+
+        # check if data exist
+        if self.data is None:
+            return self
+
+        # check if document_schema matches the data's structured dtype
+        self._is_valid_array(self.document_schema, self.data.dtype)
+
+        # check if shape matches data
+        if self.shape is not None and not _flex_shape_equal(
+            self.data.shape, self.shape
+        ):
+            raise ValueError(f"Expected shape {self.shape}, but got {self.data.shape}.")
+
+        # reassign shape to concrete value if it is None or a flexible shape
+        if self.shape != self.data.shape:
+            self.shape = self.data.shape
+
+        return self
+
+    @classmethod
+    def _is_folder_type(cls, type_):
+        return type_ == cls or (
+            typing.get_origin(type_) is Annotated and type_.__origin__ is cls
+        )
diff --git a/src/oqd_dataschema/group.py b/src/oqd_dataschema/group.py
index 61cbeb9..3f478ce 100644
--- a/src/oqd_dataschema/group.py
+++ b/src/oqd_dataschema/group.py
@@ -25,6 +25,7 @@
 
 from oqd_dataschema.base import Attrs
 from oqd_dataschema.dataset import CastDataset, Dataset
+from oqd_dataschema.folder import Folder
 from oqd_dataschema.table import Table
 
 ########################################################################################
@@ -62,7 +63,11 @@ class GroupBase(BaseModel, extra="forbid"):
 
     @staticmethod
     def _is_datafield_type(v):
-        return Dataset._is_dataset_type(v) or Table._is_table_type(v)
+        return (
+            Dataset._is_dataset_type(v)
+            or Table._is_table_type(v)
+            or Folder._is_folder_type(v)
+        )
 
     @classmethod
     def _is_allowed_field_type(cls, v):
@@ -108,7 +113,7 @@ def __init_subclass__(cls, **kwargs):
 
         if not cls._is_allowed_field_type(v):
             raise TypeError(
-                "All fields of `GroupBase` have to be of type `Dataset` or `Table`."
+                "All fields of `GroupBase` have to be of type `Dataset`, `Table` or `Folder`."
             )
 
         cls.__annotations__["class_"] = Literal[cls.__name__]
From d42791cd91f00f19dc4a6d41a72312e88ad05c0d Mon Sep 17 00:00:00 2001
From: yhteoh
Date: Tue, 7 Oct 2025 08:51:54 -0400
Subject: [PATCH 25/48] [refactor] rename pd property of Table to dataframe

---
 src/oqd_dataschema/table.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/oqd_dataschema/table.py b/src/oqd_dataschema/table.py
index dadfafe..3df9b16 100644
--- a/src/oqd_dataschema/table.py
+++ b/src/oqd_dataschema/table.py
@@ -70,7 +70,7 @@ def validate_unique(cls, value):
         return value
 
     @property
-    def pd(self):
+    def dataframe(self):
         if len(self.shape) > 1:
             raise ValueError(
                 "Conversion to pandas DataFrame only supported on 1D Table."
From 868cb06b37e8c5ae13c2c27096c1aae67d498d2e Mon Sep 17 00:00:00 2001 From: yhteoh Date: Tue, 7 Oct 2025 09:50:12 -0400 Subject: [PATCH 26/48] [format] formatted files --- .gitattributes | 2 +- .github/ISSUE_TEMPLATE.md | 2 +- .github/workflows/check_copyright.yml | 4 ++-- .github/workflows/check_mkdocs_build.yml | 2 +- .github/workflows/copyright.txt | 2 +- README.md | 2 +- docs/api.md | 2 +- docs/index.md | 2 +- docs/stylesheets/admonition_template.css | 2 +- docs/stylesheets/admonitions.css | 3 --- docs/stylesheets/brand.css | 18 +++++++++--------- docs/tutorial.md | 2 +- 12 files changed, 20 insertions(+), 23 deletions(-) diff --git a/.gitattributes b/.gitattributes index 235b1a2..a76e4dc 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1 +1 @@ -postprocessing/** linguist-vendored \ No newline at end of file +postprocessing/** linguist-vendored diff --git a/.github/ISSUE_TEMPLATE.md b/.github/ISSUE_TEMPLATE.md index cda0ced..7de49f3 100644 --- a/.github/ISSUE_TEMPLATE.md +++ b/.github/ISSUE_TEMPLATE.md @@ -26,4 +26,4 @@ - Subsystem: -* **Other information** (e.g. detailed explanation, stacktraces, related issues, suggestions how to fix, links for us to have context, eg. stackoverflow, gitter, etc) \ No newline at end of file +* **Other information** (e.g. detailed explanation, stacktraces, related issues, suggestions how to fix, links for us to have context, eg. stackoverflow, gitter, etc) diff --git a/.github/workflows/check_copyright.yml b/.github/workflows/check_copyright.yml index 9ed4c48..fdc46fc 100644 --- a/.github/workflows/check_copyright.yml +++ b/.github/workflows/check_copyright.yml @@ -10,6 +10,6 @@ jobs: - name: Check license & copyright headers uses: viperproject/check-license-header@v2 with: - path: + path: config: .github/workflows/check_copyright_config.json - # strict: true \ No newline at end of file + # strict: true diff --git a/.github/workflows/check_mkdocs_build.yml b/.github/workflows/check_mkdocs_build.yml index 116f6d4..2f9c08b 100644 --- a/.github/workflows/check_mkdocs_build.yml +++ b/.github/workflows/check_mkdocs_build.yml @@ -21,4 +21,4 @@ jobs: uses: astral-sh/setup-uv@v4 # - run: cp -r examples/ docs/examples/ - run: uv pip install .[docs] --system - - run: mkdocs build \ No newline at end of file + - run: mkdocs build diff --git a/.github/workflows/copyright.txt b/.github/workflows/copyright.txt index 67eb334..084ae79 100644 --- a/.github/workflows/copyright.txt +++ b/.github/workflows/copyright.txt @@ -10,4 +10,4 @@ # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and -# limitations under the License. \ No newline at end of file +# limitations under the License. diff --git a/README.md b/README.md index 9dc494c..ab3b8b2 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,7 @@ The design goals are to have: - Minimizes *a priori* knowledge that is needed of the internal heirarchical structure, reducing friction for users to load data. - Transparently return both raw and processed data, where the levels of post-processing can be selected by the user. 
-To install, +To install, ```bash pip install git+https://github.com/OpenQuantumDesign/oqd-dataschema.git ``` diff --git a/docs/api.md b/docs/api.md index 1621f49..0aca668 100644 --- a/docs/api.md +++ b/docs/api.md @@ -30,4 +30,4 @@ "SinaraRawDataGroup", "MeasurementOutcomesDataGroup", "ExpectationValueDataGroup", - ] \ No newline at end of file + ] diff --git a/docs/index.md b/docs/index.md index 1355f50..9d7bf57 100644 --- a/docs/index.md +++ b/docs/index.md @@ -1,4 +1,4 @@ -# +#

Logo diff --git a/docs/stylesheets/admonition_template.css b/docs/stylesheets/admonition_template.css index f32ebeb..94fee5c 100644 --- a/docs/stylesheets/admonition_template.css +++ b/docs/stylesheets/admonition_template.css @@ -14,4 +14,4 @@ background-color: #FFFFFF; -webkit-mask-image: var(--md-admonition-icon--template); mask-image: var(--md-admonition-icon--template); - } \ No newline at end of file + } diff --git a/docs/stylesheets/admonitions.css b/docs/stylesheets/admonitions.css index ff86542..eb1babd 100644 --- a/docs/stylesheets/admonitions.css +++ b/docs/stylesheets/admonitions.css @@ -130,6 +130,3 @@ -webkit-mask-image: var(--md-admonition-icon--acknowledgement); mask-image: var(--md-admonition-icon--acknowledgement); } - - - diff --git a/docs/stylesheets/brand.css b/docs/stylesheets/brand.css index d2fdeeb..34d6112 100644 --- a/docs/stylesheets/brand.css +++ b/docs/stylesheets/brand.css @@ -36,12 +36,12 @@ h1, h2, h3, h4, h5, h6, /* Apply Raleway to all navigation and sidebar elements */ -.md-nav, -.md-nav__title, -.md-nav__link, -.md-header, -.md-tabs, -.md-sidebar, +.md-nav, +.md-nav__title, +.md-nav__link, +.md-header, +.md-tabs, +.md-sidebar, .md-sidebar__inner, .md-nav__item, .md-footer, @@ -79,7 +79,7 @@ h1, h2, h3, h4, h5, h6, /* Light mode nav/ToC font color */ -[data-md-color-scheme="default"] .md-nav, +[data-md-color-scheme="default"] .md-nav, [data-md-color-scheme="default"] .md-nav__link, [data-md-color-scheme="default"] .md-header, [data-md-color-scheme="default"] .md-tabs { @@ -88,7 +88,7 @@ h1, h2, h3, h4, h5, h6, } /* Dark mode nav/ToC font color */ -[data-md-color-scheme="slate"] .md-nav, +[data-md-color-scheme="slate"] .md-nav, [data-md-color-scheme="slate"] .md-nav__link, [data-md-color-scheme="slate"] .md-header, [data-md-color-scheme="slate"] .md-tabs { @@ -113,4 +113,4 @@ h1, h2, h3, h4, h5, h6, .md-header .md-tabs__link:hover { color: #ffffff !important; text-decoration: underline; -} \ No newline at end of file +} diff --git a/docs/tutorial.md b/docs/tutorial.md index daa4bfa..943dc1c 100644 --- a/docs/tutorial.md +++ b/docs/tutorial.md @@ -105,4 +105,4 @@ data.model_dump_hdf5(filepath) ```python data_reload = Datastore.model_validate_hdf5(filepath) pprint(data_reload) -``` \ No newline at end of file +``` From 9a70409f525b8563a744dd7d0f3b7ab166374399 Mon Sep 17 00:00:00 2001 From: yhteoh Date: Tue, 7 Oct 2025 09:57:16 -0400 Subject: [PATCH 27/48] [refactor,feat] Implemented abstract superclass GroupField for (Dataset, Table, Folder) and move str dtype dump and load logic from Datastore to GroupField subclass --- src/oqd_dataschema/base.py | 28 ++++++++++++--- src/oqd_dataschema/dataset.py | 22 +++++++----- src/oqd_dataschema/datastore.py | 62 +++++---------------------------- src/oqd_dataschema/folder.py | 54 ++++++++++++++++++++++------ src/oqd_dataschema/group.py | 33 +++++++++--------- src/oqd_dataschema/table.py | 34 +++++++++++++----- 6 files changed, 133 insertions(+), 100 deletions(-) diff --git a/src/oqd_dataschema/base.py b/src/oqd_dataschema/base.py index d91e7b8..0dd9098 100644 --- a/src/oqd_dataschema/base.py +++ b/src/oqd_dataschema/base.py @@ -13,20 +13,20 @@ # limitations under the License. 
# %% +import typing +from abc import ABC, abstractmethod from enum import Enum from typing import Annotated, Optional, Union import numpy as np from pydantic import ( + BaseModel, BeforeValidator, ) ######################################################################################## -__all__ = [ - "Attrs", - "DTypes", -] +__all__ = ["Attrs", "DTypes", "GroupField"] ######################################################################################## @@ -75,3 +75,23 @@ def _valid_attr_key(value): Union[int, float, str, complex], ] ] + +######################################################################################## + + +class GroupField(BaseModel, ABC): + attrs: Attrs + + @classmethod + def _is_supported_type(cls, type_): + return type_ == cls or ( + typing.get_origin(type_) is Annotated and type_.__origin__ is cls + ) + + @abstractmethod + def _handle_data_dump(self, data): + pass + + @abstractmethod + def _handle_data_load(self, data): + pass diff --git a/src/oqd_dataschema/dataset.py b/src/oqd_dataschema/dataset.py index ed836f7..1afa6ca 100644 --- a/src/oqd_dataschema/dataset.py +++ b/src/oqd_dataschema/dataset.py @@ -13,12 +13,10 @@ # limitations under the License. # %% -import typing from typing import Annotated, Any, Literal, Optional, Tuple, Union import numpy as np from pydantic import ( - BaseModel, BeforeValidator, ConfigDict, Field, @@ -26,7 +24,7 @@ model_validator, ) -from oqd_dataschema.base import Attrs, DTypes +from oqd_dataschema.base import Attrs, DTypes, GroupField from .utils import _flex_shape_equal @@ -40,7 +38,7 @@ ######################################################################################## -class Dataset(BaseModel, extra="forbid"): +class Dataset(GroupField, extra="forbid"): """ Schema representation for a dataset object to be saved within an HDF5 file. 
@@ -126,11 +124,19 @@ def cast(cls, data): def __getitem__(self, idx): return self.data[idx] - @classmethod - def _is_dataset_type(cls, type_): - return type_ == cls or ( - typing.get_origin(type_) is Annotated and type_.__origin__ is cls + def _handle_data_dump(self, data): + np_dtype = ( + np.dtypes.BytesDType if type(data.dtype) is np.dtypes.StrDType else None ) + if np_dtype is None: + return data + + return data.astype(np_dtype) + + def _handle_data_load(self, data): + np_dtype = DTypes.get(self.dtype).value + return data.astype(np_dtype) + CastDataset = Annotated[Dataset, BeforeValidator(Dataset.cast)] diff --git a/src/oqd_dataschema/datastore.py b/src/oqd_dataschema/datastore.py index f581adb..88b5760 100644 --- a/src/oqd_dataschema/datastore.py +++ b/src/oqd_dataschema/datastore.py @@ -19,16 +19,13 @@ from typing import Any, Dict, Literal import h5py -import numpy as np from pydantic import ( BaseModel, field_validator, ) -from oqd_dataschema.base import Attrs, DTypes -from oqd_dataschema.dataset import Dataset +from oqd_dataschema.base import Attrs, GroupField from oqd_dataschema.group import GroupBase, GroupRegistry -from oqd_dataschema.table import Table ######################################################################################## @@ -109,11 +106,7 @@ def _dump_group(self, h5datastore, gkey, group): def _dump_dataset(self, h5group, dkey, dataset): """Helper function for dumping Dataset.""" - if ( - dataset is not None - and not isinstance(dataset, Dataset) - and not isinstance(dataset, Table) - ): + if dataset is not None and not isinstance(dataset, GroupField): raise ValueError("Group data field is not a Dataset or a Table.") # handle optional dataset @@ -121,33 +114,10 @@ def _dump_dataset(self, h5group, dkey, dataset): h5_dataset = h5group.create_dataset(dkey, data=h5py.Empty("f")) return - if isinstance(dataset, Dataset): - # dtype str converted to bytes when dumped (h5 compatibility) - np_dtype = ( - np.dtypes.BytesDType - if dataset.dtype == "str" - else DTypes.get(dataset.dtype).value - ) - - h5_dataset = h5group.create_dataset( - dkey, data=dataset.data.astype(np_dtype) - ) - - if isinstance(dataset, Table): - # dtype str converted to bytes when dumped (h5 compatibility) - np_dtype = np.dtype( - [ - (k, np.empty(0, dtype=v).astype(np.dtypes.BytesDType).dtype) - if dict(dataset.columns)[k] == "str" - else (k, v) - for k, (v, _) in dataset.data.dtype.fields.items() - ] - ) - - h5_dataset = h5group.create_dataset( - dkey, - data=dataset.data.astype(np_dtype), - ) + # dtype str converted to bytes when dumped (h5 compatibility) + h5_dataset = h5group.create_dataset( + dkey, data=dataset._handle_data_dump(dataset.data) + ) # dump dataset attributes for akey, attr in dataset.attrs.items(): @@ -182,24 +152,8 @@ def _load_data(cls, group, h5group, dkey, ikey=None): field = group.__dict__[ikey] if ikey else group.__dict__ h5field = h5group[ikey] if ikey else h5group - if isinstance(field[dkey], Dataset): - field[dkey].data = np.array(h5field[dkey][()]).astype( - DTypes.get(field[dkey].dtype).value - ) - return - if isinstance(field[dkey], Table): - np_dtype = np.dtype( - [ - ( - k, - np.empty(0, dtype=v).astype(np.dtypes.StrDType).dtype, - ) - if dict(field[dkey].columns)[k] == "str" - else (k, v) - for k, (v, _) in np.array(h5field[dkey][()]).dtype.fields.items() - ] - ) - field[dkey].data = np.array(h5field[dkey][()]).astype(np_dtype) + if isinstance(field[dkey], GroupField): + field[dkey].data = field[dkey]._handle_data_load(h5field[dkey][()]) return raise 
ValueError( diff --git a/src/oqd_dataschema/folder.py b/src/oqd_dataschema/folder.py index c4e7821..c00230d 100644 --- a/src/oqd_dataschema/folder.py +++ b/src/oqd_dataschema/folder.py @@ -13,13 +13,11 @@ # limitations under the License. -import typing from types import MappingProxyType -from typing import Annotated, Any, Dict, Literal, Optional, Tuple, Union +from typing import Any, Dict, Literal, Optional, Tuple, Union import numpy as np from pydantic import ( - BaseModel, ConfigDict, Field, field_validator, @@ -27,7 +25,7 @@ ) from typing_extensions import TypeAliasType -from oqd_dataschema.base import Attrs, DTypes +from oqd_dataschema.base import Attrs, DTypes, GroupField from oqd_dataschema.utils import _flex_shape_equal ######################################################################################## @@ -44,7 +42,7 @@ ) -class Folder(BaseModel, extra="forbid"): +class Folder(GroupField, extra="forbid"): document_schema: DocumentSchema shape: Optional[Tuple[Union[int, None], ...]] = None data: Optional[Any] = Field(default=None, exclude=True) @@ -131,8 +129,44 @@ def validate_data_matches_shape_dtype(self): return self - @classmethod - def _is_folder_type(cls, type_): - return type_ == cls or ( - typing.get_origin(type_) is Annotated and type_.__origin__ is cls - ) + @staticmethod + def _dump_dtype_str_to_bytes(dtype): + np_dtype = [] + + for k, (v, _) in dtype.fields.items(): + if isinstance(v.fields, MappingProxyType): + dt = Folder._dump_dtype_str_to_bytes(v) + elif type(v) is np.dtypes.StrDType: + dt = np.empty(0, dtype=v).astype(np.dtypes.BytesDType).dtype + else: + dt = v + + np_dtype.append((k, dt)) + + return np.dtype(np_dtype) + + def _handle_data_dump(self, data): + np_dtype = self._dump_dtype_str_to_bytes(data.dtype) + + return data.astype(np_dtype) + + @staticmethod + def _load_dtype_bytes_to_str(document_schema, dtype): + np_dtype = [] + + for k, (v, _) in dtype.fields.items(): + if isinstance(v.fields, MappingProxyType): + dt = Folder._load_dtype_bytes_to_str(document_schema[k], v) + elif document_schema[k] == "str": + dt = np.empty(0, dtype=v).astype(np.dtypes.StrDType).dtype + else: + dt = v + + np_dtype.append((k, dt)) + + return np.dtype(np_dtype) + + def _handle_data_load(self, data): + np_dtype = self._load_dtype_bytes_to_str(self.document_schema, data.dtype) + + return data.astype(np_dtype) diff --git a/src/oqd_dataschema/group.py b/src/oqd_dataschema/group.py index 3f478ce..2217b77 100644 --- a/src/oqd_dataschema/group.py +++ b/src/oqd_dataschema/group.py @@ -14,6 +14,7 @@ import typing import warnings +from functools import reduce from types import NoneType from typing import Annotated, ClassVar, Literal, Union @@ -23,10 +24,8 @@ TypeAdapter, ) -from oqd_dataschema.base import Attrs -from oqd_dataschema.dataset import CastDataset, Dataset -from oqd_dataschema.folder import Folder -from oqd_dataschema.table import Table +from oqd_dataschema.base import Attrs, GroupField +from oqd_dataschema.dataset import CastDataset ######################################################################################## @@ -62,30 +61,32 @@ class GroupBase(BaseModel, extra="forbid"): attrs: Attrs = {} @staticmethod - def _is_datafield_type(v): - return ( - Dataset._is_dataset_type(v) - or Table._is_table_type(v) - or Folder._is_folder_type(v) + def _is_basic_groupfield_type(v): + return reduce( + lambda x, y: x or y, + (gf._is_supported_type(v) for gf in GroupField.__subclasses__()), ) @classmethod - def _is_allowed_field_type(cls, v): - is_datafield = 
cls._is_datafield_type(v) + def _is_groupfield_type(cls, v): + is_datafield = cls._is_basic_groupfield_type(v) is_annotated_datafield = typing.get_origin( v - ) is Annotated and cls._is_datafield_type(v.__origin__) + ) is Annotated and cls._is_basic_groupfield_type(v.__origin__) is_optional_datafield = typing.get_origin(v) is Union and ( - (v.__args__[0] == NoneType and cls._is_datafield_type(v.__args__[1])) - or (v.__args__[1] == NoneType and cls._is_datafield_type(v.__args__[0])) + (v.__args__[0] == NoneType and cls._is_basic_groupfield_type(v.__args__[1])) + or ( + v.__args__[1] == NoneType + and cls._is_basic_groupfield_type(v.__args__[0]) + ) ) is_dict_datafield = ( typing.get_origin(v) is dict and v.__args__[0] is str - and cls._is_datafield_type(v.__args__[1]) + and cls._is_basic_groupfield_type(v.__args__[1]) ) return ( @@ -111,7 +112,7 @@ def __init_subclass__(cls, **kwargs): if cls._is_classvar(v): continue - if not cls._is_allowed_field_type(v): + if not cls._is_groupfield_type(v): raise TypeError( "All fields of `GroupBase` have to be of type `Dataset`, `Table` or `Folder`." ) diff --git a/src/oqd_dataschema/table.py b/src/oqd_dataschema/table.py index 3df9b16..064026d 100644 --- a/src/oqd_dataschema/table.py +++ b/src/oqd_dataschema/table.py @@ -13,14 +13,12 @@ # limitations under the License. -import typing from types import MappingProxyType from typing import Annotated, Any, List, Literal, Optional, Tuple, Union import numpy as np import pandas as pd from pydantic import ( - BaseModel, BeforeValidator, ConfigDict, Field, @@ -28,7 +26,7 @@ model_validator, ) -from oqd_dataschema.base import Attrs, DTypes +from oqd_dataschema.base import Attrs, DTypes, GroupField from oqd_dataschema.utils import ( _flex_shape_equal, _is_list_unique, @@ -47,7 +45,7 @@ Column = Tuple[str, Optional[Literal[DTypes.names()]]] -class Table(BaseModel, extra="forbid"): +class Table(GroupField, extra="forbid"): columns: List[Column] # type: ignore shape: Optional[Tuple[Union[int, None], ...]] = None data: Optional[Any] = Field(default=None, exclude=True) @@ -184,11 +182,31 @@ def cast(cls, data): return cls(columns=columns, data=data) return data - @classmethod - def _is_table_type(cls, type_): - return type_ == cls or ( - typing.get_origin(type_) is Annotated and type_.__origin__ is cls + def _handle_data_dump(self, data): + np_dtype = np.dtype( + [ + (k, np.empty(0, dtype=v).astype(np.dtypes.BytesDType).dtype) + if type(v) is np.dtypes.StrDType + else (k, v) + for k, (v, _) in data.dtype.fields.items() + ] + ) + + return data.astype(np_dtype) + + def _handle_data_load(self, data): + np_dtype = np.dtype( + [ + ( + k, + np.empty(0, dtype=v).astype(np.dtypes.StrDType).dtype, + ) + if dict(self.columns)[k] == "str" + else (k, v) + for k, (v, _) in np.array(data).dtype.fields.items() + ] ) + return data.astype(np_dtype) CastTable = Annotated[Table, BeforeValidator(Table.cast)] From 3bfcda31dcc5354ee861f6c1cdf277af3f7143ba Mon Sep 17 00:00:00 2001 From: yhteoh Date: Tue, 7 Oct 2025 11:05:31 -0400 Subject: [PATCH 28/48] [fix] Modify document schema None dtypes to concrete values when loading data in Folder --- src/oqd_dataschema/__init__.py | 2 ++ src/oqd_dataschema/folder.py | 21 +++++++++++++++++++++ 2 files changed, 23 insertions(+) diff --git a/src/oqd_dataschema/__init__.py b/src/oqd_dataschema/__init__.py index 8352fec..ce79068 100644 --- a/src/oqd_dataschema/__init__.py +++ b/src/oqd_dataschema/__init__.py @@ -15,6 +15,7 @@ from .constrained import condataset, contable from .dataset import 
CastDataset, Dataset from .datastore import Datastore +from .folder import Folder from .group import ( ExpectationValueDataGroup, GroupBase, @@ -41,4 +42,5 @@ "Table", "CastTable", "contable", + "Folder", ] diff --git a/src/oqd_dataschema/folder.py b/src/oqd_dataschema/folder.py index c00230d..e05d899 100644 --- a/src/oqd_dataschema/folder.py +++ b/src/oqd_dataschema/folder.py @@ -123,12 +123,33 @@ def validate_data_matches_shape_dtype(self): ): raise ValueError(f"Expected shape {self.shape}, but got {self.data.shape}.") + # reassign dtype if it is None + document_schema_from_dtype = self._get_document_schema_from_dtype( + self.data.dtype + ) + if self.document_schema != document_schema_from_dtype: + self.document_schema = document_schema_from_dtype + # resassign shape to concrete value if it is None or a flexible shape if self.shape != self.data.shape: self.shape = self.data.shape return self + @staticmethod + def _get_document_schema_from_dtype(dtype): + document_schema = {} + + for k, (v, _) in dtype.fields.items(): + if isinstance(v.fields, MappingProxyType): + dt = Folder._get_document_schema_from_dtype(v) + else: + dt = DTypes(type(v)).name.lower() + + document_schema[k] = dt + + return document_schema + @staticmethod def _dump_dtype_str_to_bytes(dtype): np_dtype = [] From b6c69b45ab964604b2b5ab0d4f3438e3345d8453 Mon Sep 17 00:00:00 2001 From: yhteoh Date: Tue, 7 Oct 2025 12:33:42 -0400 Subject: [PATCH 29/48] [feat] Added function for constructing numpy datatype from Table and Folder instance --- src/oqd_dataschema/folder.py | 30 ++++++++++++++++++++++++++++++ src/oqd_dataschema/table.py | 19 +++++++++++++++++++ 2 files changed, 49 insertions(+) diff --git a/src/oqd_dataschema/folder.py b/src/oqd_dataschema/folder.py index e05d899..2479176 100644 --- a/src/oqd_dataschema/folder.py +++ b/src/oqd_dataschema/folder.py @@ -150,6 +150,36 @@ def _get_document_schema_from_dtype(dtype): return document_schema + @staticmethod + def _numpy_dtype(document_schema, *, str_size=64, bytes_size=64): + np_dtype = [] + + for k, v in document_schema.items(): + if v is None: + raise ValueError( + "Method numpy_dtype can only be called on concrete types." + ) + + if isinstance(v, dict): + dt = Folder._numpy_dtype( + document_schema[k], str_size=str_size, bytes_size=bytes_size + ) + elif v == "str": + dt = np.dtypes.StrDType(str_size) + elif v == "bytes": + dt = np.dtypes.BytesDType(bytes_size) + else: + dt = DTypes.get(v).value() + + np_dtype.append((k, dt)) + + return np.dtype(np_dtype) + + def numpy_dtype(self, *, str_size=64, bytes_size=64): + return self._numpy_dtype( + self.document_schema, str_size=str_size, bytes_size=bytes_size + ) + @staticmethod def _dump_dtype_str_to_bytes(dtype): np_dtype = [] diff --git a/src/oqd_dataschema/table.py b/src/oqd_dataschema/table.py index 064026d..fc5a8da 100644 --- a/src/oqd_dataschema/table.py +++ b/src/oqd_dataschema/table.py @@ -165,6 +165,25 @@ def validate_data_matches_shape_dtype(self): return self + def numpy_dtype(self, *, str_size=64, bytes_size=64): + np_dtype = [] + + for k, v in self.columns: + if v is None: + raise ValueError( + "Method numpy_dtype can only be called on concrete types." 
+ ) + if v == "str": + dt = np.dtypes.StrDType(str_size) + elif v == "bytes": + dt = np.dtypes.BytesDType(bytes_size) + else: + dt = DTypes.get(v).value() + + np_dtype.append((k, dt)) + + return np.dtype(np_dtype) + @classmethod def cast(cls, data): if isinstance(data, pd.DataFrame): From ab22d7ac4e878e28db6dc32bd881df515e470204 Mon Sep 17 00:00:00 2001 From: yhteoh Date: Tue, 7 Oct 2025 12:39:54 -0400 Subject: [PATCH 30/48] [feat] Implemented CastFolder automatic casting of structured numpy array and confolder with shape and dimension constraints --- src/oqd_dataschema/__init__.py | 6 ++++-- src/oqd_dataschema/constrained.py | 23 +++++++++++++++++++---- src/oqd_dataschema/folder.py | 21 +++++++++++++++++---- 3 files changed, 40 insertions(+), 10 deletions(-) diff --git a/src/oqd_dataschema/__init__.py b/src/oqd_dataschema/__init__.py index ce79068..da9b007 100644 --- a/src/oqd_dataschema/__init__.py +++ b/src/oqd_dataschema/__init__.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .constrained import condataset, contable +from .constrained import condataset, confolder, contable from .dataset import CastDataset, Dataset from .datastore import Datastore -from .folder import Folder +from .folder import CastFolder, Folder from .group import ( ExpectationValueDataGroup, GroupBase, @@ -43,4 +43,6 @@ "CastTable", "contable", "Folder", + "CastFolder", + "confolder", ] diff --git a/src/oqd_dataschema/constrained.py b/src/oqd_dataschema/constrained.py index 6a919db..871d190 100644 --- a/src/oqd_dataschema/constrained.py +++ b/src/oqd_dataschema/constrained.py @@ -18,15 +18,13 @@ from pydantic import AfterValidator from oqd_dataschema.dataset import CastDataset +from oqd_dataschema.folder import Folder from oqd_dataschema.table import CastTable from oqd_dataschema.utils import _flex_shape_equal, _validator_from_condition ######################################################################################## -__all__ = [ - "contable", - "condataset", -] +__all__ = ["contable", "condataset", "confolder"] ######################################################################################## @@ -172,3 +170,20 @@ def contable( AfterValidator(_constraint_dim(min_dim=min_dim, max_dim=max_dim)), AfterValidator(_constraint_shape(shape_constraint=shape_constraint)), ] + + +######################################################################################## + + +def confolder( + *, + shape_constraint=None, + min_dim=None, + max_dim=None, +): + """Implements dtype, dimension and shape constrains on the Folder.""" + return Annotated[ + Folder, + AfterValidator(_constraint_dim(min_dim=min_dim, max_dim=max_dim)), + AfterValidator(_constraint_shape(shape_constraint=shape_constraint)), + ] diff --git a/src/oqd_dataschema/folder.py b/src/oqd_dataschema/folder.py index 2479176..96b03e9 100644 --- a/src/oqd_dataschema/folder.py +++ b/src/oqd_dataschema/folder.py @@ -14,10 +14,11 @@ from types import MappingProxyType -from typing import Any, Dict, Literal, Optional, Tuple, Union +from typing import Annotated, Any, Dict, Literal, Optional, Tuple, Union import numpy as np from pydantic import ( + BeforeValidator, ConfigDict, Field, field_validator, @@ -30,9 +31,7 @@ ######################################################################################## -__all__ = [ - "Folder", -] +__all__ = ["Folder", "CastFolder"] ######################################################################################## @@ -221,3 +220,17 @@ def 
_handle_data_load(self, data): np_dtype = self._load_dtype_bytes_to_str(self.document_schema, data.dtype) return data.astype(np_dtype) + + @classmethod + def cast(cls, data): + if isinstance(data, np.ndarray): + if not isinstance(data.dtype.fields, MappingProxyType): + raise TypeError("dtype of data must be a structured dtype.") + + document_schema = cls._get_document_schema_from_dtype(data.dtype) + + return cls(document_schema=document_schema, data=data) + return data + + +CastFolder = Annotated[Folder, BeforeValidator(Folder.cast)] From a0da48f34d903d546bd5f9b95edc437795d8155e Mon Sep 17 00:00:00 2001 From: yhteoh Date: Tue, 7 Oct 2025 14:11:05 -0400 Subject: [PATCH 31/48] [feat] Added helper functions unstructured_to_structured and dict_to_structured to convert unstructured numpy array and dict of numpy arrays to a structured numpy array --- src/oqd_dataschema/__init__.py | 3 ++ src/oqd_dataschema/utils.py | 73 ++++++++++++++++++++++++++++++++++ 2 files changed, 76 insertions(+) diff --git a/src/oqd_dataschema/__init__.py b/src/oqd_dataschema/__init__.py index da9b007..23ab782 100644 --- a/src/oqd_dataschema/__init__.py +++ b/src/oqd_dataschema/__init__.py @@ -25,6 +25,7 @@ SinaraRawDataGroup, ) from .table import CastTable, Table +from .utils import dict_to_structured, unstructured_to_structured ######################################################################################## @@ -45,4 +46,6 @@ "Folder", "CastFolder", "confolder", + "dict_to_structured", + "unstructured_to_structured", ] diff --git a/src/oqd_dataschema/utils.py b/src/oqd_dataschema/utils.py index 57a7747..2ab92c8 100644 --- a/src/oqd_dataschema/utils.py +++ b/src/oqd_dataschema/utils.py @@ -13,6 +13,10 @@ # limitations under the License. from functools import reduce +from types import MappingProxyType + +import numpy as np +from numpy.lib import recfunctions as rfn ######################################################################################## @@ -22,6 +26,75 @@ ######################################################################################## +def _unstructured_to_structured_helper(new_data, data, dtype, counter=0): + for k, (v, _) in dtype.fields.items(): + if isinstance(v.fields, MappingProxyType): + _unstructured_to_structured_helper(new_data[k], data, v, counter=counter) + counter += len(rfn.flatten_descr(v)) + + continue + + new_data[k] = data[..., counter].astype(v) + counter += 1 + + return new_data + + +def unstructured_to_structured(data, dtype): + leaves = len(rfn.flatten_descr(dtype)) + if data.shape[-1] != leaves: + raise ValueError( + f"Incompatible shape, last dimension of data ({data.shape[-1]}) must match number of leaves in structured dtype ({leaves})." 
+ ) + + new_data = np.empty(data.shape[:-1], dtype=dtype) + _unstructured_to_structured_helper(new_data, data, dtype) + + return new_data + + +def _dtype_from_dict(data): + np_dtype = [] + + for k, v in data.items(): + if isinstance(v, dict): + dt = _dtype_from_dict(v) + else: + dt = v.dtype + + np_dtype.append((k, dt)) + + return np.dtype(np_dtype) + + +def _dict_to_structured_helper(new_data, data, dtype): + for k, (v, _) in dtype.fields.items(): + if isinstance(v.fields, MappingProxyType): + _dict_to_structured_helper(new_data[k], data[k], v) + continue + + new_data[k] = data[k] + return new_data + + +def dict_to_structured(data): + data_dtype = _dtype_from_dict(data) + + print(rfn.get_names(data_dtype)) + + example_data = data + key = rfn.get_names(data_dtype)[0] + while isinstance(key, tuple): + example_data = example_data[key[0]] + key = key[1][0] + example_data = example_data[key] + + new_data = np.empty(example_data.shape, dtype=data_dtype) + _dict_to_structured_helper(new_data, data, dtype=data_dtype) + + return new_data + + def _flex_shape_equal(shape1, shape2): """Helper function for comparing concrete and flex shapes.""" return len(shape1) == len(shape2) and reduce( From d75e9ed7e9808a297f3bf0eb5cb9dd73f806e584 Mon Sep 17 00:00:00 2001 From: yhteoh Date: Tue, 7 Oct 2025 22:16:39 -0400 Subject: [PATCH 32/48] [clean] remove debugging print statement --- src/oqd_dataschema/utils.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/oqd_dataschema/utils.py b/src/oqd_dataschema/utils.py index 2ab92c8..059b5ec 100644 --- a/src/oqd_dataschema/utils.py +++ b/src/oqd_dataschema/utils.py @@ -80,8 +80,6 @@ def _dict_to_structured_helper(new_data, data, dtype): def dict_to_structured(data): data_dtype = _dtype_from_dict(data) - print(rfn.get_names(data_dtype)) - example_data = data key = rfn.get_names(data_dtype)[0] while isinstance(key, tuple): From 8129d6e109b08e6ebfbf1f29932b82ecc7c0ce70 Mon Sep 17 00:00:00 2001 From: yhteoh Date: Tue, 7 Oct 2025 23:12:45 -0400 Subject: [PATCH 33/48] [fix] unstructured_to_structured infers the size str dtype from the data --- src/oqd_dataschema/utils.py | 28 +++++++++++++++++++--------- 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/src/oqd_dataschema/utils.py b/src/oqd_dataschema/utils.py index 059b5ec..191dec4 100644 --- a/src/oqd_dataschema/utils.py +++ b/src/oqd_dataschema/utils.py @@ -26,16 +26,27 @@ ######################################################################################## -def _unstructured_to_structured_helper(new_data, data, dtype, counter=0): - for k, (v, _) in dtype.fields.items(): +def _unstructured_to_structured_helper(data, dtype, counter=0): + for n, (k, (v, _)) in enumerate(dtype.fields.items()): if isinstance(v.fields, MappingProxyType): - _unstructured_to_structured_helper(new_data[k], data, v, counter=counter) + x = _unstructured_to_structured_helper(data, v, counter=counter) counter += len(rfn.flatten_descr(v)) - continue - - new_data[k] = data[..., counter].astype(v) - counter += 1 + else: + x = data[..., counter].astype(type(v)) + x = x.astype( + np.dtype( + [ + (k, x.dtype), + ] + ) + ) + counter += 1 + + if n == 0: + new_data = x + else: + new_data = rfn.append_fields(new_data, k, x, usemask=False) return new_data @@ -47,8 +58,7 @@ def unstructured_to_structured(data, dtype): f"Incompatible shape, last dimension of data ({data.shape[-1]}) must match number of leaves in structured dtype ({leaves})." 
) - new_data = np.empty(data.shape[:-1], dtype=dtype) - _unstructured_to_structured_helper(new_data, data, dtype) + new_data = _unstructured_to_structured_helper(data, dtype) return new_data From fbb1dcac586bfdf5c9aa51f00e4e1f3b6db1af1d Mon Sep 17 00:00:00 2001 From: yhteoh Date: Tue, 7 Oct 2025 23:59:34 -0400 Subject: [PATCH 34/48] [fix] use append_fields in dict_to_structured and also fixed append_fields for arbitrary shaped data --- src/oqd_dataschema/utils.py | 70 +++++++++++++++++++++++-------------- 1 file changed, 43 insertions(+), 27 deletions(-) diff --git a/src/oqd_dataschema/utils.py b/src/oqd_dataschema/utils.py index 191dec4..c390a3f 100644 --- a/src/oqd_dataschema/utils.py +++ b/src/oqd_dataschema/utils.py @@ -26,34 +26,40 @@ ######################################################################################## -def _unstructured_to_structured_helper(data, dtype, counter=0): +def _unstructured_to_structured_helper(data, dtype): for n, (k, (v, _)) in enumerate(dtype.fields.items()): if isinstance(v.fields, MappingProxyType): - x = _unstructured_to_structured_helper(data, v, counter=counter) - counter += len(rfn.flatten_descr(v)) + x = _unstructured_to_structured_helper(data, v) else: - x = data[..., counter].astype(type(v)) - x = x.astype( + x = data.pop(0).astype(type(v)) + + if n == 0: + new_data = x.astype( np.dtype( [ (k, x.dtype), ] ) ) - counter += 1 - - if n == 0: - new_data = x else: - new_data = rfn.append_fields(new_data, k, x, usemask=False) + if new_data.shape != x.shape: + raise ValueError( + f"Incompatible shape, expected {new_data.shape} but got {x.shape}." + ) - return new_data + new_data = rfn.append_fields( + new_data.flatten(), k, x.flatten(), usemask=False + ).reshape(x.shape) + + return new_data.view(np.recarray) def unstructured_to_structured(data, dtype): + data = list(np.moveaxis(data, -1, 0)) + leaves = len(rfn.flatten_descr(dtype)) - if data.shape[-1] != leaves: + if len(data) != leaves: raise ValueError( f"Incompatible shape, last dimension of data ({data.shape[-1]}) must match number of leaves in structured dtype ({leaves})." ) @@ -77,28 +83,38 @@ def _dtype_from_dict(data): return np.dtype(np_dtype) -def _dict_to_structured_helper(new_data, data, dtype): - for k, (v, _) in dtype.fields.items(): +def _dict_to_structured_helper(data, dtype): + for n, (k, (v, _)) in enumerate(dtype.fields.items()): if isinstance(v.fields, MappingProxyType): - _dict_to_structured_helper(new_data[k], data[k], v) - continue + x = _dict_to_structured_helper(data[k], v) + else: + x = data[k] - new_data[k] = data[k] - return new_data + if n == 0: + new_data = x.astype( + np.dtype( + [ + (k, x.dtype), + ] + ) + ) + else: + if new_data.shape != x.shape: + raise ValueError( + f"Incompatible shape, expected {new_data.shape} but got {x.shape}." 
+ ) + + new_data = rfn.append_fields( + new_data.flatten(), k, x.flatten(), usemask=False + ).reshape(x.shape) + + return new_data.view(np.recarray) def dict_to_structured(data): data_dtype = _dtype_from_dict(data) - example_data = data - key = rfn.get_names(data_dtype)[0] - while isinstance(key, tuple): - example_data = example_data[key[0]] - key = key[1][0] - example_data = example_data[key] - - new_data = np.empty(example_data.shape, dtype=data_dtype) - _dict_to_structured_helper(new_data, data, dtype=data_dtype) + new_data = _dict_to_structured_helper(data, dtype=data_dtype) return new_data From 7cb8806771bfc6b9e5ec002f0eec8a20db5224fb Mon Sep 17 00:00:00 2001 From: yhteoh Date: Wed, 8 Oct 2025 10:23:25 -0400 Subject: [PATCH 35/48] [format] Formatted code --- src/oqd_dataschema/utils.py | 30 ++++++++++++++---------------- 1 file changed, 14 insertions(+), 16 deletions(-) diff --git a/src/oqd_dataschema/utils.py b/src/oqd_dataschema/utils.py index c390a3f..fc30253 100644 --- a/src/oqd_dataschema/utils.py +++ b/src/oqd_dataschema/utils.py @@ -35,13 +35,7 @@ def _unstructured_to_structured_helper(data, dtype): x = data.pop(0).astype(type(v)) if n == 0: - new_data = x.astype( - np.dtype( - [ - (k, x.dtype), - ] - ) - ) + new_data = x.astype(np.dtype([(k, x.dtype)])) else: if new_data.shape != x.shape: raise ValueError( @@ -69,6 +63,9 @@ def unstructured_to_structured(data, dtype): return new_data +######################################################################################## + + def _dtype_from_dict(data): np_dtype = [] @@ -91,13 +88,7 @@ def _dict_to_structured_helper(data, dtype): x = data[k] if n == 0: - new_data = x.astype( - np.dtype( - [ - (k, x.dtype), - ] - ) - ) + new_data = x.astype(np.dtype([(k, x.dtype)])) else: if new_data.shape != x.shape: raise ValueError( @@ -113,12 +104,13 @@ def _dict_to_structured_helper(data, dtype): def dict_to_structured(data): data_dtype = _dtype_from_dict(data) - new_data = _dict_to_structured_helper(data, dtype=data_dtype) - return new_data +######################################################################################## + + def _flex_shape_equal(shape1, shape2): """Helper function for comparing concrete and flex shapes.""" return len(shape1) == len(shape2) and reduce( @@ -130,6 +122,9 @@ def _flex_shape_equal(shape1, shape2): ) +######################################################################################## + + def _validator_from_condition(f): """Helper decorator for turning a condition into a validation.""" @@ -143,6 +138,9 @@ def _wrapped_condition(model): return _wrapped_validator +######################################################################################## + + def _is_list_unique(data): seen = set() duplicates = set() From ba51d01515f4cf579e242d0be130c26bee764165 Mon Sep 17 00:00:00 2001 From: yhteoh Date: Tue, 14 Oct 2025 12:23:07 -0400 Subject: [PATCH 36/48] [gitignore] updated gitignore --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 6b99cc7..fb5db85 100644 --- a/.gitignore +++ b/.gitignore @@ -175,3 +175,4 @@ cython_debug/ *.code-workspace .pre-commit-config.yaml _scripts +.vscode From d261e2cc51bbb76306e58c499fd5cc000d7f5539 Mon Sep 17 00:00:00 2001 From: yhteoh Date: Tue, 14 Oct 2025 13:07:37 -0400 Subject: [PATCH 37/48] [docs, format] Updated documentation, docstrings and default field initializations --- docs/api/base.md | 10 + docs/api/datastore.md | 9 + docs/{api.md => api/group.md} | 23 +- docs/api/groupfield.md | 51 ++++ mkdocs.yaml | 9 +- 
src/oqd_dataschema/base.py | 27 ++- src/oqd_dataschema/constrained.py | 4 +- src/oqd_dataschema/dataset.py | 13 +- src/oqd_dataschema/datastore.py | 5 +- src/oqd_dataschema/folder.py | 49 +++- src/oqd_dataschema/group.py | 16 +- src/oqd_dataschema/table.py | 43 +++- tests/test_table.py | 372 ++++++++++++++++++++++++++++++ 13 files changed, 577 insertions(+), 54 deletions(-) create mode 100644 docs/api/base.md create mode 100644 docs/api/datastore.md rename docs/{api.md => api/group.md} (50%) create mode 100644 docs/api/groupfield.md create mode 100644 tests/test_table.py diff --git a/docs/api/base.md b/docs/api/base.md new file mode 100644 index 0000000..a85c83d --- /dev/null +++ b/docs/api/base.md @@ -0,0 +1,10 @@ +## Base Objects + + +::: oqd_dataschema.base + options: + heading_level: 3 + members: [ + "AttrKey", + "Attrs", + ] diff --git a/docs/api/datastore.md b/docs/api/datastore.md new file mode 100644 index 0000000..63add4c --- /dev/null +++ b/docs/api/datastore.md @@ -0,0 +1,9 @@ +## Datastore + + +::: oqd_dataschema.datastore + options: + heading_level: 3 + members: [ + "Datastore", + ] diff --git a/docs/api.md b/docs/api/group.md similarity index 50% rename from docs/api.md rename to docs/api/group.md index 0aca668..9ed4174 100644 --- a/docs/api.md +++ b/docs/api/group.md @@ -1,29 +1,16 @@ -## Datastore - - -::: oqd_dataschema.datastore - options: - heading_level: 3 - members: [ - "Datastore", - ] - - -## Base HDF5 Objects - -::: oqd_dataschema.base +::: oqd_dataschema.group options: heading_level: 3 members: [ - "Group", - "Dataset", + "GroupBase", + "GroupRegistry", ] -## Specified Groups +## Concrete groups -::: oqd_dataschema.groups +::: oqd_dataschema.group options: heading_level: 3 members: [ diff --git a/docs/api/groupfield.md b/docs/api/groupfield.md new file mode 100644 index 0000000..dea5d5b --- /dev/null +++ b/docs/api/groupfield.md @@ -0,0 +1,51 @@ +## Group Field + + +::: oqd_dataschema.base + options: + heading_level: 3 + members: [ + "GroupField", + ] + +## Dataset + + +::: oqd_dataschema.dataset + options: + heading_level: 3 + members: [ + "Dataset", + ] + +## Table + + +::: oqd_dataschema.table + options: + heading_level: 3 + members: [ + "Table", + ] + +## Folder + + +::: oqd_dataschema.folder + options: + heading_level: 3 + members: [ + "Folder", + ] + +## Constrained Group Fields + + +::: oqd_dataschema.constrained + options: + heading_level: 3 + members: [ + "condataset", + "contable", + "confolder", + ] diff --git a/mkdocs.yaml b/mkdocs.yaml index fca9c52..ea016d6 100644 --- a/mkdocs.yaml +++ b/mkdocs.yaml @@ -20,7 +20,11 @@ nav: - Get Started: index.md - Tutorial: tutorial.md # - Explanation: explanation.md - - API Reference: api.md + - API Reference: + - Base: api/base.md + - Group Field: api/groupfield.md + - Group: api/group.md + - Datastore: api/datastore.md theme: name: material @@ -59,6 +63,7 @@ theme: - toc.follow plugins: + - search - mkdocstrings: handlers: python: @@ -79,7 +84,7 @@ plugins: separate_signature: false group_by_category: true members_order: "source" - import: + inventories: - https://docs.python.org/3/objects.inv - https://docs.pydantic.dev/latest/objects.inv - https://pandas.pydata.org/docs/objects.inv diff --git a/src/oqd_dataschema/base.py b/src/oqd_dataschema/base.py index 0dd9098..8a1bd8e 100644 --- a/src/oqd_dataschema/base.py +++ b/src/oqd_dataschema/base.py @@ -16,12 +16,13 @@ import typing from abc import ABC, abstractmethod from enum import Enum -from typing import Annotated, Optional, Union +from typing import Annotated, 
Union import numpy as np from pydantic import ( BaseModel, BeforeValidator, + Field, ) ######################################################################################## @@ -69,18 +70,28 @@ def _valid_attr_key(value): return value -Attrs = Optional[ - dict[ - Annotated[str, BeforeValidator(_valid_attr_key)], - Union[int, float, str, complex], - ] -] +AttrKey = Annotated[str, BeforeValidator(_valid_attr_key)] +""" +Annotated type that represents a valid key for attributes (prevents overwriting of protected attrs). +""" + +Attrs = dict[AttrKey, Union[int, float, str, complex]] +""" +Type that represents attributes of an object. +""" ######################################################################################## class GroupField(BaseModel, ABC): - attrs: Attrs + """ + Abstract class for a valid data field of Group. + + Attributes: + attrs: A dictionary of attributes to append to the object. + """ + + attrs: Attrs = Field(default_factory=lambda: {}) @classmethod def _is_supported_type(cls, type_): diff --git a/src/oqd_dataschema/constrained.py b/src/oqd_dataschema/constrained.py index 871d190..b2cf8d6 100644 --- a/src/oqd_dataschema/constrained.py +++ b/src/oqd_dataschema/constrained.py @@ -158,7 +158,7 @@ def contable( min_dim=None, max_dim=None, ): - """Implements dtype, dimension and shape constrains on the Table.""" + """Implements field, dtype, dimension and shape constrains on the Table.""" return Annotated[ CastTable, AfterValidator( @@ -181,7 +181,7 @@ def confolder( min_dim=None, max_dim=None, ): - """Implements dtype, dimension and shape constrains on the Folder.""" + """Implements dimension and shape constrains on the Folder.""" return Annotated[ Folder, AfterValidator(_constraint_dim(min_dim=min_dim, max_dim=max_dim)), diff --git a/src/oqd_dataschema/dataset.py b/src/oqd_dataschema/dataset.py index 1afa6ca..678ae90 100644 --- a/src/oqd_dataschema/dataset.py +++ b/src/oqd_dataschema/dataset.py @@ -13,6 +13,9 @@ # limitations under the License. 
# %% + +from __future__ import annotations + from typing import Annotated, Any, Literal, Optional, Tuple, Union import numpy as np @@ -63,7 +66,7 @@ class Dataset(GroupField, extra="forbid"): shape: Optional[Tuple[Union[int, None], ...]] = None data: Optional[Any] = Field(default=None, exclude=True) - attrs: Attrs = {} + attrs: Attrs = Field(default_factory=lambda: {}) model_config = ConfigDict( use_enum_values=False, arbitrary_types_allowed=True, validate_assignment=True @@ -71,7 +74,7 @@ class Dataset(GroupField, extra="forbid"): @field_validator("data", mode="before") @classmethod - def validate_and_update(cls, value): + def _validate_and_update(cls, value): # check if data exist if value is None: return value @@ -83,7 +86,7 @@ def validate_and_update(cls, value): return value @model_validator(mode="after") - def validate_data_matches_shape_dtype(self): + def _validate_data_matches_shape_dtype(self): """Ensure that `data` matches `dtype` and `shape`.""" # check if data exist @@ -116,7 +119,8 @@ def validate_data_matches_shape_dtype(self): return self @classmethod - def cast(cls, data): + def cast(cls, data: np.ndarray) -> Dataset: + """Casts data from numpy array to Dataset.""" if isinstance(data, np.ndarray): return cls(data=data) return data @@ -140,3 +144,4 @@ def _handle_data_load(self, data): CastDataset = Annotated[Dataset, BeforeValidator(Dataset.cast)] +"""Annotated type that automatically executes Dataset.cast""" diff --git a/src/oqd_dataschema/datastore.py b/src/oqd_dataschema/datastore.py index 88b5760..fd3660d 100644 --- a/src/oqd_dataschema/datastore.py +++ b/src/oqd_dataschema/datastore.py @@ -21,6 +21,7 @@ import h5py from pydantic import ( BaseModel, + Field, field_validator, ) @@ -44,9 +45,9 @@ class Datastore(BaseModel, extra="forbid"): attrs (Attrs): attributes of the datastore. """ - groups: Dict[str, Any] = {} + groups: Dict[str, Any] = Field(default_factory=lambda: {}) - attrs: Attrs = {} + attrs: Attrs = Field(default_factory=lambda: {}) @classmethod def _validate_group(cls, key, group): diff --git a/src/oqd_dataschema/folder.py b/src/oqd_dataschema/folder.py index 96b03e9..c693673 100644 --- a/src/oqd_dataschema/folder.py +++ b/src/oqd_dataschema/folder.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations from types import MappingProxyType from typing import Annotated, Any, Dict, Literal, Optional, Tuple, Union @@ -42,11 +43,47 @@ class Folder(GroupField, extra="forbid"): + """ + Schema representation for a table object to be saved within an HDF5 file. + + Attributes: + document_schema: The schema for a document (structured type with keys and their datatype). Types are inferred from the `data` attribute if not provided. + shape: The shape of the folder. + data: The numpy ndarray or recarray (of structured dtype) of the data, from which `dtype` and `shape` can be inferred. + + attrs: A dictionary of attributes to append to the folder. 
+ + Example: + ```python + schema = dict( + index="int32", + t="float64", + channels=dict(ch1="complex128", ch2="complex128"), + label="str", + ) + dt = np.dtype( + [ + ("index", np.int32), + ("t", np.float64), + ("channels", np.dtype([("ch1", np.complex128), ("ch2", np.complex128)])), + ("label", np.dtype(" np.dtype: return self._numpy_dtype( self.document_schema, str_size=str_size, bytes_size=bytes_size ) @@ -222,7 +259,8 @@ def _handle_data_load(self, data): return data.astype(np_dtype) @classmethod - def cast(cls, data): + def cast(cls, data: np.ndarray) -> Folder: + """Casts data from numpy structured array to Folder.""" if isinstance(data, np.ndarray): if not isinstance(data.dtype.fields, MappingProxyType): raise TypeError("dtype of data must be a structured dtype.") @@ -234,3 +272,4 @@ def cast(cls, data): CastFolder = Annotated[Folder, BeforeValidator(Folder.cast)] +"""Annotated type that automatically executes Folder.cast""" diff --git a/src/oqd_dataschema/group.py b/src/oqd_dataschema/group.py index 2217b77..fcaa558 100644 --- a/src/oqd_dataschema/group.py +++ b/src/oqd_dataschema/group.py @@ -21,6 +21,7 @@ from pydantic import ( BaseModel, Discriminator, + Field, TypeAdapter, ) @@ -46,19 +47,20 @@ class GroupBase(BaseModel, extra="forbid"): """ Schema representation for a group object within an HDF5 file. - Each grouping of data should be defined as a subclass of `Group`, and specify the datasets that it will contain. + Each grouping of data should be defined as a subclass of `GroupBase`, and specify the datasets that it will contain. This base object only has attributes, `attrs`, which are associated to the HDF5 group. Attributes: - attrs: A dictionary of attributes to append to the dataset. + attrs: A dictionary of attributes to append to the group. - Example: - ``` - group = Group(attrs={'version': 2, 'date': '2025-01-01'}) - ``` """ - attrs: Attrs = {} + attrs: Attrs = Field(default_factory=lambda: {}) + + def __new__(cls, *args, **kwargs): + if cls is GroupBase: + raise TypeError(f"only subclasses of '{cls.__name__}' may be instantiated") + return object.__new__(cls, *args, **kwargs) @staticmethod def _is_basic_groupfield_type(v): diff --git a/src/oqd_dataschema/table.py b/src/oqd_dataschema/table.py index fc5a8da..7664c43 100644 --- a/src/oqd_dataschema/table.py +++ b/src/oqd_dataschema/table.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations from types import MappingProxyType from typing import Annotated, Any, List, Literal, Optional, Tuple, Union @@ -46,11 +47,38 @@ class Table(GroupField, extra="forbid"): - columns: List[Column] # type: ignore + """ + Schema representation for a table object to be saved within an HDF5 file. + + Attributes: + columns: The columns in the table accompanied by their datatype. Types are inferred from the `data` attribute if not provided. + shape: The shape of the table (excludes the column index). + data: The numpy ndarray or recarray (of structured dtype) of the data, from which `dtype` and `shape` can be inferred. + + attrs: A dictionary of attributes to append to the table. + + Example: + ```python + dt = np.dtype( + [ + ("index", np.int32), + ("t", np.float64), + ("z", np.complex128), + ("label", np.dtype(" pd.DataFrame: + """Converts flat table to pandas DataFrame.""" if len(self.shape) > 1: raise ValueError( "Conversion to pandas DataFrame only supported on 1D Table." 
@@ -104,7 +133,7 @@ def _pd_to_np(df): @field_validator("data", mode="before") @classmethod - def validate_and_update(cls, value): + def _validate_and_update(cls, value): # check if data exist if value is None: return value @@ -125,7 +154,7 @@ def validate_and_update(cls, value): return value @model_validator(mode="after") - def validate_data_matches_shape_dtype(self): + def _validate_data_matches_shape_dtype(self): """Ensure that `data` matches `dtype` and `shape`.""" # check if data exist @@ -185,7 +214,8 @@ def numpy_dtype(self, *, str_size=64, bytes_size=64): return np.dtype(np_dtype) @classmethod - def cast(cls, data): + def cast(cls, data: np.ndarray | pd.DataFrame) -> Table: + """Casts data from pandas DataFrame or numpy structured array to Table.""" if isinstance(data, pd.DataFrame): data = cls._pd_to_np(data) @@ -229,3 +259,4 @@ def _handle_data_load(self, data): CastTable = Annotated[Table, BeforeValidator(Table.cast)] +"""Annotated type that automatically executes Table.cast""" diff --git a/tests/test_table.py b/tests/test_table.py new file mode 100644 index 0000000..ca3d3b0 --- /dev/null +++ b/tests/test_table.py @@ -0,0 +1,372 @@ +# Copyright 2024-2025 Open Quantum Design + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# %% + +import numpy as np +import pytest + +from oqd_dataschema import Table +from oqd_dataschema.base import DTypes + +######################################################################################## + + +class TestTable: + def test_empty_table(self): + Table(columns=[], shape=(100,)) + + @pytest.mark.xfail(raises=ValueError) + @pytest.mark.parametrize( + ("column"), + [ + [("c1", "bool"), ("c1", "int16")], + [ + ("c1", "str"), + ("c2", "int16"), + ("c3", "float64"), + ("c1", "complex128"), + ], + ], + ) + def test_duplicate_column(self, column): + Table(columns=column, shape=(100,)) + + +class TestTableDType: + @pytest.mark.parametrize( + ("dtype", "np_dtype"), + [ + ("bool", np.dtypes.BoolDType()), + ("int16", np.dtypes.Int16DType()), + ("int32", np.dtypes.Int32DType()), + ("int64", np.dtypes.Int64DType()), + ("uint16", np.dtypes.UInt16DType()), + ("uint32", np.dtypes.UInt32DType()), + ("uint64", np.dtypes.UInt64DType()), + ("float16", np.dtypes.Float16DType()), + ("float32", np.dtypes.Float32DType()), + ("float64", np.dtypes.Float64DType()), + ("complex64", np.dtypes.Complex64DType()), + ("complex128", np.dtypes.Complex128DType()), + ("str", np.dtypes.StrDType(16)), + ("bytes", np.dtypes.BytesDType(16)), + ], + ) + def test_dtypes(self, dtype, np_dtype): + tbl = Table(columns=[("c", dtype)], shape=(100,)) + + data = np.rec.fromarrays( + np.random.rand(1, 100), + dtype=np.dtype( + [ + ("c", np_dtype), + ] + ), + ) + tbl.data = data + + @pytest.mark.parametrize( + ("column", "np_dtype"), + [ + ( + [("c1", "bool"), ("c2", "int16")], + np.dtype( + [("c1", np.dtypes.BoolDType()), ("c2", np.dtypes.Int16DType())] + ), + ), + ( + [ + ("c1", "str"), + ("c2", "int16"), + ("c3", "float64"), + ("c4", "complex128"), + ], + np.dtype( + [ + ("c1", 
np.dtypes.StrDType(16)), + ("c2", np.dtypes.Int16DType()), + ("c3", np.dtypes.Float64DType()), + ("c4", np.dtypes.Complex128DType()), + ] + ), + ), + ], + ) + def test_multi_column_dtypes(self, column, np_dtype): + tbl = Table(columns=column, shape=(100,)) + + data = np.rec.fromarrays(np.random.rand(len(column), 100), dtype=np_dtype) + tbl.data = data + + @pytest.mark.xfail(raises=ValueError) + @pytest.mark.parametrize( + "dtype", + [ + "bool", + "int16", + "int32", + "int64", + "uint16", + "uint32", + "uint64", + "float16", + "float32", + "float64", + "complex64", + "complex128", + "str", + "bytes", + ], + ) + def test_unmatched_dtype_data(self, dtype): + tbl = Table(columns=[("c", dtype)], shape=(100,)) + + data = np.rec.fromarrays( + np.random.rand(1, 100), + dtype=np.dtype( + [ + ("c", "O"), + ] + ), + ) + tbl.data = data + + @pytest.mark.parametrize( + "np_dtype", + [ + np.dtypes.BoolDType(), + np.dtypes.Int16DType(), + np.dtypes.Int32DType(), + np.dtypes.Int64DType(), + np.dtypes.UInt16DType(), + np.dtypes.UInt32DType(), + np.dtypes.UInt64DType(), + np.dtypes.Float16DType(), + np.dtypes.Float32DType(), + np.dtypes.Float64DType(), + np.dtypes.Complex64DType(), + np.dtypes.Complex128DType(), + np.dtypes.StrDType(16), + np.dtypes.BytesDType(16), + ], + ) + def test_flexible_dtype(self, np_dtype): + tbl = Table(columns=[("c", None)], shape=(100,)) + + data = np.rec.fromarrays( + np.random.rand(1, 100), + dtype=np.dtype( + [ + ("c", np_dtype), + ] + ), + ) + tbl.data = data + + assert ( + dict(tbl.columns)["c"] + == DTypes(type(tbl.data.dtype.fields["c"][0])).name.lower() + ) + + def test_dtype_mutation(self): + tbl = Table(columns=[("c", "float32")], shape=(100,)) + + tbl.columns[0] = ("c", "float64") + + data = np.rec.fromarrays( + np.random.rand(1, 100), + dtype=np.dtype( + [ + ("c", "float64"), + ] + ), + ) + tbl.data = data + + +# class TestDatasetShape: +# @pytest.mark.xfail(raises=ValueError) +# @pytest.mark.parametrize( +# ("shape", "data_shape"), +# [ +# ((0,), (100,)), +# ((1,), (100,)), +# ((99,), (100,)), +# ((1, 1), (100,)), +# ((100, None), (100,)), +# ((None, None), (100,)), +# ((None, 100), (100,)), +# ], +# ) +# def test_unmatched_shape_data(self, shape, data_shape): +# ds = Dataset(dtype="float64", shape=shape) + +# data = np.random.rand(*data_shape) +# ds.data = data + +# @pytest.mark.parametrize( +# ("shape", "data_shape"), +# [ +# ((None,), (0,)), +# ((None,), (1,)), +# ((None,), (100,)), +# ((None, 0), (0, 0)), +# ((None, 1), (1, 1)), +# ((None, None), (1, 1)), +# ((None, None), (10, 100)), +# ((None, None, 1), (1, 1, 1)), +# ], +# ) +# def test_flexible_shape(self, shape, data_shape): +# ds = Dataset(dtype="float64", shape=shape) + +# data = np.random.rand(*data_shape) +# ds.data = data + +# assert ds.shape == ds.data.shape + +# def test_shape_mutation(self): +# ds = Dataset(dtype="float64", shape=(1,)) + +# ds.shape = (100,) + +# data = np.random.rand(100) +# ds.data = data + + +# class TestCastDataset: +# @pytest.fixture +# def adapter(self): +# return TypeAdapter(CastDataset) + +# @pytest.mark.parametrize( +# ("data", "dtype", "shape"), +# [ +# (np.random.rand(100), "float64", (100,)), +# (np.random.rand(10).astype("str"), "str", (10,)), +# (np.random.rand(1, 10, 100).astype("bytes"), "bytes", (1, 10, 100)), +# ], +# ) +# def test_cast(self, adapter, data, shape, dtype): +# ds = adapter.validate_python(data) + +# assert ds.shape == shape and ds.dtype == dtype + + +# class TestConstrainedDataset: +# @pytest.mark.parametrize( +# ("cds", "data"), +# [ +# 
(condataset(dtype_constraint="float64"), np.random.rand(10)), +# (condataset(dtype_constraint="str"), np.random.rand(10).astype(str)), +# ( +# condataset(dtype_constraint=("float16", "float32", "float64")), +# np.random.rand(10), +# ), +# ( +# condataset(dtype_constraint=("float16", "float32", "float64")), +# np.random.rand(10).astype("float16"), +# ), +# ( +# condataset(dtype_constraint=("float16", "float32", "float64")), +# np.random.rand(10).astype("float32"), +# ), +# ], +# ) +# def test_constrained_dataset_dtype(self, cds, data): +# adapter = TypeAdapter(cds) + +# adapter.validate_python(data) + +# @pytest.mark.xfail(raises=ValueError) +# @pytest.mark.parametrize( +# ("cds", "data"), +# [ +# (condataset(dtype_constraint="float64"), np.random.rand(10).astype(str)), +# (condataset(dtype_constraint="str"), np.random.rand(10)), +# ( +# condataset(dtype_constraint=("float16", "float32", "float64")), +# np.random.rand(10).astype(str), +# ), +# ], +# ) +# def test_violate_dtype_constraint(self, cds, data): +# adapter = TypeAdapter(cds) + +# adapter.validate_python(data) + +# @pytest.mark.parametrize( +# ("cds", "data"), +# [ +# (condataset(min_dim=1, max_dim=1), np.random.rand(10)), +# (condataset(min_dim=0, max_dim=1), np.random.rand(10)), +# (condataset(max_dim=2), np.random.rand(10)), +# (condataset(max_dim=3), np.random.rand(10, 10, 10)), +# (condataset(min_dim=2), np.random.rand(10, 10)), +# (condataset(min_dim=2), np.random.rand(10, 10, 10, 10, 10)), +# (condataset(min_dim=2, max_dim=4), np.random.rand(10, 10, 10, 10)), +# (condataset(min_dim=2, max_dim=4), np.random.rand(10, 10, 10)), +# (condataset(min_dim=2, max_dim=4), np.random.rand(10, 10)), +# ], +# ) +# def test_constrained_dataset_dimension(self, cds, data): +# adapter = TypeAdapter(cds) + +# adapter.validate_python(data) + +# @pytest.mark.xfail(raises=ValueError) +# @pytest.mark.parametrize( +# ("cds", "data"), +# [ +# (condataset(min_dim=1, max_dim=1), np.random.rand(10, 10)), +# (condataset(min_dim=2, max_dim=3), np.random.rand(10)), +# (condataset(min_dim=2, max_dim=3), np.random.rand(10, 10, 10, 10)), +# ], +# ) +# def test_violate_dimension_constraint(self, cds, data): +# adapter = TypeAdapter(cds) + +# adapter.validate_python(data) + +# @pytest.mark.parametrize( +# ("cds", "data"), +# [ +# (condataset(shape_constraint=(None,)), np.random.rand(10)), +# (condataset(shape_constraint=(10,)), np.random.rand(10)), +# (condataset(shape_constraint=(None, None)), np.random.rand(1, 2)), +# (condataset(shape_constraint=(1, None)), np.random.rand(1, 2)), +# (condataset(shape_constraint=(1, 2)), np.random.rand(1, 2)), +# (condataset(shape_constraint=(1, None, 3)), np.random.rand(1, 10, 3)), +# ], +# ) +# def test_constrained_dataset_shape(self, cds, data): +# adapter = TypeAdapter(cds) + +# adapter.validate_python(data) + +# @pytest.mark.xfail(raises=ValueError) +# @pytest.mark.parametrize( +# ("cds", "data"), +# [ +# (condataset(shape_constraint=(1,)), np.random.rand(10)), +# (condataset(shape_constraint=(None,)), np.random.rand(10, 10)), +# (condataset(shape_constraint=(None, 1)), np.random.rand(10, 10)), +# (condataset(shape_constraint=(None, 1)), np.random.rand(1, 10)), +# ], +# ) +# def test_violate_shape_constraint(self, cds, data): +# adapter = TypeAdapter(cds) + +# adapter.validate_python(data) From 06f7e6f9ee69300fb7f87b192f5d2212c18e4d6c Mon Sep 17 00:00:00 2001 From: yhteoh Date: Tue, 14 Oct 2025 13:08:56 -0400 Subject: [PATCH 38/48] [docs] Updated documentation and docstrings --- docs/api/base.md | 2 +- 
 docs/api/datastore.md             |  2 --
 docs/api/groupfield.md            |  3 +--
 docs/explanation.md               | 22 ++++++++++++++++++++++
 mkdocs.yaml                       |  2 +-
 src/oqd_dataschema/base.py        |  6 ++++--
 src/oqd_dataschema/constrained.py | 24 ++++++++++++------------
 src/oqd_dataschema/group.py       |  5 -----
 8 files changed, 41 insertions(+), 25 deletions(-)

diff --git a/docs/api/base.md b/docs/api/base.md
index a85c83d..3ccf7ae 100644
--- a/docs/api/base.md
+++ b/docs/api/base.md
@@ -1,4 +1,4 @@
-## Base Objects
+## Attribute Types


 ::: oqd_dataschema.base
diff --git a/docs/api/datastore.md b/docs/api/datastore.md
index 63add4c..c797a9c 100644
--- a/docs/api/datastore.md
+++ b/docs/api/datastore.md
@@ -1,5 +1,3 @@
-## Datastore
-
 ::: oqd_dataschema.datastore

     options:
diff --git a/docs/api/groupfield.md b/docs/api/groupfield.md
index dea5d5b..ddc1e35 100644
--- a/docs/api/groupfield.md
+++ b/docs/api/groupfield.md
@@ -1,8 +1,7 @@
-## Group Field
-
 ::: oqd_dataschema.base
     options:
+      filters: []
       heading_level: 3
       members: [
         "GroupField",
diff --git a/docs/explanation.md b/docs/explanation.md
index e69de29..7af1180 100644
--- a/docs/explanation.md
+++ b/docs/explanation.md
@@ -0,0 +1,22 @@
+## Datastore
+
+A [Datastore][oqd_dataschema.datastore.Datastore] represents an HDF5 file with a particular hierarchical structure.
+
+### Hierarchy
+
+```
+/
+├── group1/
+│   └── dataset1
+├── group2/
+│   ├── dataset2
+│   ├── table1
+│   └── folder1
+└── group3/
+    ├── table2
+    └── dataset_dict1/
+        ├── dataset5
+        └── dataset6
+```
+
+The top level of a [Datastore][oqd_dataschema.datastore.Datastore] contains multiple [Groups](api/group.md), each of which holds the group fields (datasets, tables, and folders) shown above.
diff --git a/mkdocs.yaml b/mkdocs.yaml
index ea016d6..6153de2 100644
--- a/mkdocs.yaml
+++ b/mkdocs.yaml
@@ -19,7 +19,7 @@ extra:
 nav:
   - Get Started: index.md
   - Tutorial: tutorial.md
-  # - Explanation: explanation.md
+  - Explanation: explanation.md
   - API Reference:
     - Base: api/base.md
diff --git a/src/oqd_dataschema/base.py b/src/oqd_dataschema/base.py
index 8a1bd8e..07aed69 100644
--- a/src/oqd_dataschema/base.py
+++ b/src/oqd_dataschema/base.py
@@ -100,9 +100,11 @@ def _is_supported_type(cls, type_):
         )

     @abstractmethod
-    def _handle_data_dump(self, data):
+    def _handle_data_dump(self, data: np.ndarray) -> np.ndarray:
+        """Hook into [Datastore.model_dump_hdf5][oqd_dataschema.datastore.Datastore.model_dump_hdf5] for compatibility mapping to HDF5."""
         pass

     @abstractmethod
-    def _handle_data_load(self, data):
+    def _handle_data_load(self, data: np.ndarray) -> np.ndarray:
+        """Hook into [Datastore.model_validate_hdf5][oqd_dataschema.datastore.Datastore.model_validate_hdf5] for reversing compatibility mapping, i.e. mapping data back to original type."""
         pass
diff --git a/src/oqd_dataschema/constrained.py b/src/oqd_dataschema/constrained.py
index b2cf8d6..666b649 100644
--- a/src/oqd_dataschema/constrained.py
+++ b/src/oqd_dataschema/constrained.py
@@ -13,7 +13,7 @@
 # limitations under the License.
-from typing import Annotated, Sequence
+from typing import Annotated, Sequence, TypeAlias

 from pydantic import AfterValidator

@@ -30,7 +30,7 @@


 @_validator_from_condition
-def _constraint_dim(model, *, min_dim=None, max_dim=None):
+def _constrain_dim(model, *, min_dim=None, max_dim=None):
     """Constrains the dimension of a Dataset or Table."""

     if min_dim is not None and max_dim is not None and min_dim > max_dim:
@@ -50,7 +50,7 @@ def _constrain_dim(model, *, min_dim=None, max_dim=None):


 @_validator_from_condition
-def _constraint_shape(model, *, shape_constraint=None):
+def _constrain_shape(model, *, shape_constraint=None):
     """Constrains the shape of a Dataset or Table."""

     # fast escape
@@ -91,13 +91,13 @@ def _constrain_dtype_dataset(dataset, *, dtype_constraint=None):

 def condataset(
     *, shape_constraint=None, dtype_constraint=None, min_dim=None, max_dim=None
-):
+) -> TypeAlias:
     """Implements dtype, dimension and shape constraints on the Dataset."""
     return Annotated[
         CastDataset,
         AfterValidator(_constrain_dtype_dataset(dtype_constraint=dtype_constraint)),
-        AfterValidator(_constraint_dim(min_dim=min_dim, max_dim=max_dim)),
-        AfterValidator(_constraint_shape(shape_constraint=shape_constraint)),
+        AfterValidator(_constrain_dim(min_dim=min_dim, max_dim=max_dim)),
+        AfterValidator(_constrain_shape(shape_constraint=shape_constraint)),
     ]


@@ -157,7 +157,7 @@ def contable(
     shape_constraint=None,
     min_dim=None,
     max_dim=None,
-):
+) -> TypeAlias:
     """Implements field, dtype, dimension and shape constraints on the Table."""
     return Annotated[
         CastTable,
         AfterValidator(
@@ -167,8 +167,8 @@ def contable(
             )
         ),
         AfterValidator(_constrain_dtype_table(dtype_constraint=dtype_constraint)),
-        AfterValidator(_constraint_dim(min_dim=min_dim, max_dim=max_dim)),
-        AfterValidator(_constraint_shape(shape_constraint=shape_constraint)),
+        AfterValidator(_constrain_dim(min_dim=min_dim, max_dim=max_dim)),
+        AfterValidator(_constrain_shape(shape_constraint=shape_constraint)),
     ]


@@ -180,10 +180,10 @@ def confolder(
     shape_constraint=None,
     min_dim=None,
     max_dim=None,
-):
+) -> TypeAlias:
     """Implements dimension and shape constraints on the Folder."""
     return Annotated[
         Folder,
-        AfterValidator(_constraint_dim(min_dim=min_dim, max_dim=max_dim)),
-        AfterValidator(_constraint_shape(shape_constraint=shape_constraint)),
+        AfterValidator(_constrain_dim(min_dim=min_dim, max_dim=max_dim)),
+        AfterValidator(_constrain_shape(shape_constraint=shape_constraint)),
     ]
diff --git a/src/oqd_dataschema/group.py b/src/oqd_dataschema/group.py
index fcaa558..655050d 100644
--- a/src/oqd_dataschema/group.py
+++ b/src/oqd_dataschema/group.py
@@ -57,11 +57,6 @@ class GroupBase(BaseModel, extra="forbid"):

     attrs: Attrs = Field(default_factory=lambda: {})

-    def __new__(cls, *args, **kwargs):
-        if cls is GroupBase:
-            raise TypeError(f"only subclasses of '{cls.__name__}' may be instantiated")
-        return object.__new__(cls, *args, **kwargs)
-
     @staticmethod
     def _is_basic_groupfield_type(v):
         return reduce(

From 0314d1a927f6efe96ceccae0e3b55eceed32ea36 Mon Sep 17 00:00:00 2001
From: yhteoh
Date: Tue, 14 Oct 2025 14:45:13 -0400
Subject: [PATCH 39/48] [clean] remove concrete group implementations

---
 docs/api/group.md              | 12 ----------
 src/oqd_dataschema/__init__.py |  3 ---
 src/oqd_dataschema/group.py    | 42 ----------------------------------
 3 files changed, 57 deletions(-)

diff --git a/docs/api/group.md b/docs/api/group.md
index 9ed4174..1b8ff22 100644
--- a/docs/api/group.md
+++ b/docs/api/group.md
@@ -6,15 +6,3 @@
         "GroupBase",
         "GroupRegistry",
     ]
-
-## Concrete groups
-
-
-::: 
oqd_dataschema.group - options: - heading_level: 3 - members: [ - "SinaraRawDataGroup", - "MeasurementOutcomesDataGroup", - "ExpectationValueDataGroup", - ] diff --git a/src/oqd_dataschema/__init__.py b/src/oqd_dataschema/__init__.py index 23ab782..97bb78c 100644 --- a/src/oqd_dataschema/__init__.py +++ b/src/oqd_dataschema/__init__.py @@ -20,9 +20,6 @@ ExpectationValueDataGroup, GroupBase, GroupRegistry, - MeasurementOutcomesDataGroup, - OQDTestbenchDataGroup, - SinaraRawDataGroup, ) from .table import CastTable, Table from .utils import dict_to_structured, unstructured_to_structured diff --git a/src/oqd_dataschema/group.py b/src/oqd_dataschema/group.py index 655050d..685f596 100644 --- a/src/oqd_dataschema/group.py +++ b/src/oqd_dataschema/group.py @@ -26,17 +26,12 @@ ) from oqd_dataschema.base import Attrs, GroupField -from oqd_dataschema.dataset import CastDataset ######################################################################################## __all__ = [ "GroupBase", "GroupRegistry", - "SinaraRawDataGroup", - "MeasurementOutcomesDataGroup", - "ExpectationValueDataGroup", - "OQDTestbenchDataGroup", ] @@ -170,40 +165,3 @@ class GroupRegistry(metaclass=MetaGroupRegistry): """ pass - - -######################################################################################## - - -class SinaraRawDataGroup(GroupBase): - """ - Example `Group` for raw data from the Sinara real-time control system. - This is a placeholder for demonstration and development. - """ - - camera_images: CastDataset - - -class MeasurementOutcomesDataGroup(GroupBase): - """ - Example `Group` for processed data classifying the readout of the state. - This is a placeholder for demonstration and development. - """ - - outcomes: CastDataset - - -class ExpectationValueDataGroup(GroupBase): - """ - Example `Group` for processed data calculating the expectation values. - This is a placeholder for demonstration and development. 
-    """
-
-    expectation_value: CastDataset
-
-
-class OQDTestbenchDataGroup(GroupBase):
-    """ """
-
-    time: CastDataset
-    voltages: CastDataset

From 358120b8204e2472e5a9a69f51675252b7c71889 Mon Sep 17 00:00:00 2001
From: yhteoh
Date: Tue, 14 Oct 2025 14:50:38 -0400
Subject: [PATCH 40/48] [docs] updated documentation and docstrings

---
 docs/api/base.md                  | 11 ++++++
 docs/api/groupfield.md            |  3 ++
 docs/api/utils.md                 | 10 +++++
 mkdocs.yaml                       |  1 +
 src/oqd_dataschema/__init__.py    | 10 +----
 src/oqd_dataschema/base.py        | 41 ++++++++++++++++++--
 src/oqd_dataschema/constrained.py | 63 +++++++++++++++++++++++++++++--
 src/oqd_dataschema/dataset.py     |  6 +--
 src/oqd_dataschema/folder.py      |  6 +--
 src/oqd_dataschema/table.py       |  6 +--
 src/oqd_dataschema/utils.py       |  5 ++-
 11 files changed, 135 insertions(+), 27 deletions(-)
 create mode 100644 docs/api/utils.md

diff --git a/docs/api/base.md b/docs/api/base.md
index 3ccf7ae..33fba53 100644
--- a/docs/api/base.md
+++ b/docs/api/base.md
@@ -8,3 +8,14 @@
         "AttrKey",
         "Attrs",
     ]
+
+## Data Types
+
+
+::: oqd_dataschema.base
+    options:
+      heading_level: 3
+      members: [
+        "DTypes",
+        "DTypeNames",
+      ]
diff --git a/docs/api/groupfield.md b/docs/api/groupfield.md
index ddc1e35..af1319f 100644
--- a/docs/api/groupfield.md
+++ b/docs/api/groupfield.md
@@ -15,6 +15,7 @@
       heading_level: 3
       members: [
         "Dataset",
+        "CastDataset",
       ]

 ## Table
@@ -25,6 +26,7 @@
       heading_level: 3
       members: [
         "Table",
+        "CastTable",
       ]

 ## Folder
@@ -35,6 +37,7 @@
       heading_level: 3
       members: [
         "Folder",
+        "CastFolder",
       ]

 ## Constrained Group Fields
diff --git a/docs/api/utils.md b/docs/api/utils.md
new file mode 100644
index 0000000..cff7d91
--- /dev/null
+++ b/docs/api/utils.md
@@ -0,0 +1,10 @@
+## Utilities
+
+
+::: oqd_dataschema.utils
+    options:
+      heading_level: 3
+      members: [
+        "dict_to_structured",
+        "unstructured_to_structured",
+      ]
diff --git a/mkdocs.yaml b/mkdocs.yaml
index 6153de2..8981d0c 100644
--- a/mkdocs.yaml
+++ b/mkdocs.yaml
@@ -25,6 +25,7 @@ nav:
     - Group Field: api/groupfield.md
     - Group: api/group.md
     - Datastore: api/datastore.md
+    - Utilities: api/utils.md

 theme:
   name: material
diff --git a/src/oqd_dataschema/__init__.py b/src/oqd_dataschema/__init__.py
index 97bb78c..afc6f77 100644
--- a/src/oqd_dataschema/__init__.py
+++ b/src/oqd_dataschema/__init__.py
@@ -16,11 +16,7 @@
 from .dataset import CastDataset, Dataset
 from .datastore import Datastore
 from .folder import CastFolder, Folder
-from .group import (
-    ExpectationValueDataGroup,
-    GroupBase,
-    GroupRegistry,
-)
+from .group import GroupBase, GroupRegistry
 from .table import CastTable, Table
 from .utils import dict_to_structured, unstructured_to_structured

@@ -30,10 +26,6 @@
     "Datastore",
     "GroupBase",
     "GroupRegistry",
-    "ExpectationValueDataGroup",
-    "MeasurementOutcomesDataGroup",
-    "OQDTestbenchDataGroup",
-    "SinaraRawDataGroup",
     "Dataset",
     "CastDataset",
     "condataset",
diff --git a/src/oqd_dataschema/base.py b/src/oqd_dataschema/base.py
index 07aed69..a25d639 100644
--- a/src/oqd_dataschema/base.py
+++ b/src/oqd_dataschema/base.py
@@ -13,10 +13,12 @@
 # limitations under the License.
# %% +from __future__ import annotations + import typing from abc import ABC, abstractmethod from enum import Enum -from typing import Annotated, Union +from typing import Annotated, Literal, Union import numpy as np from pydantic import ( @@ -27,12 +29,25 @@ ######################################################################################## -__all__ = ["Attrs", "DTypes", "GroupField"] +__all__ = ["Attrs", "DTypes", "DTypeNames", "GroupField"] ######################################################################################## class DTypes(Enum): + """ + Enum for data types supported by oqd-dataschema. + + |Type |Variant| + |-------|-------| + |Boolean|`BOOL` | + |Integer|`INT16`, `INT32`, `INT64` (signed)
`UINT16`, `UINT32`, `UINT64` (unsigned)|
+    |Float  |`FLOAT32`, `FLOAT64`|
+    |Complex|`COMPLEX64`, `COMPLEX128`|
+    |Bytes  |`BYTES`|
+    |String |`STR`, `STRING`|
+    """
+
     BOOL = np.dtypes.BoolDType
     INT16 = np.dtypes.Int16DType
     INT32 = np.dtypes.Int32DType
@@ -50,20 +65,35 @@ class DTypes(Enum):
     STRING = np.dtypes.StringDType

     @classmethod
-    def get(cls, name):
+    def get(cls, name: str) -> DTypes:
+        """
+        Get the [`DTypes`][oqd_dataschema.base.DTypes] enum variant by lowercase name.
+        """
         return cls[name.upper()]

     @classmethod
     def names(cls):
+        """
+        Get the lowercase names of all variants of the [`DTypes`][oqd_dataschema.base.DTypes] enum.
+        """
         return tuple((dtype.name.lower() for dtype in cls))


+DTypeNames = Literal[DTypes.names()]
+"""
+Literal list of lowercase names for [`DTypes`][oqd_dataschema.base.DTypes] variants.
+"""
+
+
 ########################################################################################

 invalid_attrs = ["_datastore_signature", "_group_schema"]


-def _valid_attr_key(value):
+def _valid_attr_key(value: str) -> str:
+    """
+    Validates attribute keys (prevents overwriting of protected attrs).
+    """
     if value in invalid_attrs:
         raise KeyError

@@ -108,3 +138,6 @@ def _handle_data_dump(self, data: np.ndarray) -> np.ndarray:
     def _handle_data_load(self, data: np.ndarray) -> np.ndarray:
         """Hook into [Datastore.model_validate_hdf5][oqd_dataschema.datastore.Datastore.model_validate_hdf5] for reversing compatibility mapping, i.e. mapping data back to original type."""
         pass
+
+
+# %%
diff --git a/src/oqd_dataschema/constrained.py b/src/oqd_dataschema/constrained.py
index 666b649..671060c 100644
--- a/src/oqd_dataschema/constrained.py
+++ b/src/oqd_dataschema/constrained.py
@@ -90,9 +90,33 @@ def _constrain_dtype_dataset(dataset, *, dtype_constraint=None):

 def condataset(
-    *, shape_constraint=None, dtype_constraint=None, min_dim=None, max_dim=None
+    *,
+    shape_constraint=None,
+    dtype_constraint=None,
+    min_dim=None,
+    max_dim=None,
 ) -> TypeAlias:
-    """Implements dtype, dimension and shape constraints on the Dataset."""
+    """Implements dtype, dimension and shape constraints on the Dataset.
+
+    Arguments:
+        shape_constraint (Tuple[Union[None, int],...]): Required shape; a `None` entry leaves that axis unconstrained.
+        dtype_constraint (Tuple[DTypeNames,...]): Allowed dtype name(s).
+        min_dim (int): Minimum number of dimensions.
+        max_dim (int): Maximum number of dimensions.
+
+    Example:
+        ```
+        class CustomGroup(GroupBase):
+            x: condataset(dtype_constraint=("int16", "int32", "int64"))
+            y: condataset(shape_constraint=(100,))
+            z: condataset(min_dim=1, max_dim=1)
+
+        group = CustomGroup(x=..., y=..., z=...)  # succeeds if the values obey the constraints
+
+        group = CustomGroup(x=..., y=..., z=...)  # fails if any value violates a constraint
+        ```
+
+    """
     return Annotated[
         CastDataset,
         AfterValidator(_constrain_dtype_dataset(dtype_constraint=dtype_constraint)),
@@ -158,7 +182,24 @@ def contable(
     min_dim=None,
     max_dim=None,
 ) -> TypeAlias:
-    """Implements field, dtype, dimension and shape constraints on the Table."""
+    """Implements field, dtype, dimension and shape constraints on the Table.
+
+    Example:
+        ```
+        class CustomGroup(GroupBase):
+            x: contable(dtype_constraint=("int16", "int32", "int64"))
+            y: contable(shape_constraint=(100,))
+            z: contable(min_dim=1, max_dim=1)
+            u: contable(required_field=("c1", "c2"))
+            v: contable(required_field=("c1", "c2"), strict_fields=True)
+
+
+        group = CustomGroup(x=..., y=..., z=..., u=..., v=...)  # succeeds if the values obey the constraints
+
+        group = CustomGroup(x=..., y=..., z=..., u=..., v=...)  # fails if any value violates a constraint
+        ```
+
+    """
     return Annotated[
         CastTable,
         AfterValidator(
@@ -181,10 +222,21 @@ def confolder(
     min_dim=None,
     max_dim=None,
 ) -> TypeAlias:
-    """Implements dimension and shape constraints on the Folder."""
+    """Implements dimension and shape constraints on the Folder.
+
+    Example:
+        ```
+        class CustomGroup(GroupBase):
+            x: confolder(shape_constraint=(100,))
+            y: confolder(min_dim=1, max_dim=1)
+
+
+        group = CustomGroup(x=..., y=...)  # succeeds if the values obey the constraints
+
+        group = CustomGroup(x=..., y=...)  # fails if any value violates a constraint
+        ```
+
+    """
     return Annotated[
         Folder,
         AfterValidator(_constrain_dim(min_dim=min_dim, max_dim=max_dim)),
diff --git a/src/oqd_dataschema/dataset.py b/src/oqd_dataschema/dataset.py
index 678ae90..ee9d7f9 100644
--- a/src/oqd_dataschema/dataset.py
+++ b/src/oqd_dataschema/dataset.py
@@ -16,7 +16,7 @@

 from __future__ import annotations

-from typing import Annotated, Any, Literal, Optional, Tuple, Union
+from typing import Annotated, Any, Optional, Tuple, Union

 import numpy as np
 from pydantic import (
@@ -27,7 +27,7 @@
     model_validator,
 )

-from oqd_dataschema.base import Attrs, DTypes, GroupField
+from oqd_dataschema.base import Attrs, DTypeNames, DTypes, GroupField

 from .utils import _flex_shape_equal

@@ -62,7 +62,7 @@ class Dataset(GroupField, extra="forbid"):
     ```
     """

-    dtype: Optional[Literal[DTypes.names()]] = None  # type: ignore
+    dtype: Optional[DTypeNames] = None  # type: ignore
     shape: Optional[Tuple[Union[int, None], ...]] = None

     data: Optional[Any] = Field(default=None, exclude=True)
diff --git a/src/oqd_dataschema/folder.py b/src/oqd_dataschema/folder.py
index c693673..4e8c41e 100644
--- a/src/oqd_dataschema/folder.py
+++ b/src/oqd_dataschema/folder.py
@@ -15,7 +15,7 @@
 from __future__ import annotations

 from types import MappingProxyType
-from typing import Annotated, Any, Dict, Literal, Optional, Tuple, Union
+from typing import Annotated, Any, Dict, Optional, Tuple, Union

 import numpy as np
 from pydantic import (
@@ -27,7 +27,7 @@
 )
 from typing_extensions import TypeAliasType

-from oqd_dataschema.base import Attrs, DTypes, GroupField
+from oqd_dataschema.base import Attrs, DTypeNames, DTypes, GroupField
 from oqd_dataschema.utils import _flex_shape_equal

 ########################################################################################

@@ -38,7 +38,7 @@

 DocumentSchema = TypeAliasType(
     "DocumentSchema",
-    Dict[str, Union["DocumentSchema", Optional[Literal[DTypes.names()]]]],  # type: ignore
+    Dict[str, Union["DocumentSchema", Optional[DTypeNames]]],  # type: ignore
 )
diff --git a/src/oqd_dataschema/table.py b/src/oqd_dataschema/table.py
index 7664c43..da72a57 100644
--- a/src/oqd_dataschema/table.py
+++ b/src/oqd_dataschema/table.py
@@ -15,7 +15,7 @@
 from __future__ import annotations

 from types import MappingProxyType
-from typing import Annotated, Any, List, Literal, Optional, Tuple, Union
+from typing import Annotated, Any, List, Optional, Tuple, Union

 import numpy as np
 import pandas as pd
@@ -27,7 +27,7 @@
     model_validator,
 )

-from oqd_dataschema.base import Attrs, DTypes, GroupField
+from oqd_dataschema.base import Attrs, 
DTypeNames, DTypes, GroupField from oqd_dataschema.utils import ( _flex_shape_equal, _is_list_unique, @@ -43,7 +43,7 @@ ######################################################################################## -Column = Tuple[str, Optional[Literal[DTypes.names()]]] +Column = Tuple[str, Optional[DTypeNames]] class Table(GroupField, extra="forbid"): diff --git a/src/oqd_dataschema/utils.py b/src/oqd_dataschema/utils.py index fc30253..b1153d6 100644 --- a/src/oqd_dataschema/utils.py +++ b/src/oqd_dataschema/utils.py @@ -20,7 +20,10 @@ ######################################################################################## -__all__ = ["_flex_shape_equal", "_validator_from_condition", "_is_list_unique"] +__all__ = [ + "unstructured_to_structured", + "dict_to_structured", +] ######################################################################################## From fd1693e71dfbc24d0743c1b60d0c3fd8edd029bc Mon Sep 17 00:00:00 2001 From: yhteoh Date: Tue, 14 Oct 2025 14:59:38 -0400 Subject: [PATCH 41/48] [fix] edge case where only a single group registered --- src/oqd_dataschema/group.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/oqd_dataschema/group.py b/src/oqd_dataschema/group.py index 685f596..3943bb9 100644 --- a/src/oqd_dataschema/group.py +++ b/src/oqd_dataschema/group.py @@ -149,9 +149,13 @@ def clear(cls): @property def union(cls): """Get the current Union of all registered types""" - return Annotated[ - Union[tuple(cls.groups.values())], Discriminator(discriminator="class_") - ] + + if len(cls.groups) > 1: + return Annotated[ + Union[tuple(cls.groups.values())], Discriminator(discriminator="class_") + ] + else: + return next(iter(cls.groups.values())) @property def adapter(cls): From d507aed5c06d1fd6a5c8a8acff8334f0bf4e041d Mon Sep 17 00:00:00 2001 From: yhteoh Date: Wed, 15 Oct 2025 11:44:22 -0400 Subject: [PATCH 42/48] [feat] Added pipe to Datastore --- src/oqd_dataschema/datastore.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/oqd_dataschema/datastore.py b/src/oqd_dataschema/datastore.py index fd3660d..c7f402b 100644 --- a/src/oqd_dataschema/datastore.py +++ b/src/oqd_dataschema/datastore.py @@ -14,6 +14,8 @@ # %% +from __future__ import annotations + import json import pathlib from typing import Any, Dict, Literal @@ -218,5 +220,8 @@ def add(self, **groups): self.update(**groups) + def pipe(self, func) -> Datastore: + return func(self) + # %% From bf1fba7f88454a5670321a97dbb9e15092c9ef4a Mon Sep 17 00:00:00 2001 From: yhteoh Date: Wed, 15 Oct 2025 11:49:52 -0400 Subject: [PATCH 43/48] [fix] GroupBase allows attrs field --- src/oqd_dataschema/group.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/oqd_dataschema/group.py b/src/oqd_dataschema/group.py index 3943bb9..2d2829d 100644 --- a/src/oqd_dataschema/group.py +++ b/src/oqd_dataschema/group.py @@ -96,12 +96,15 @@ def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) for k, v in cls.__annotations__.items(): - if k in ["class_", "attrs"]: + if k == "class_": + raise AttributeError("`class_` attribute should not be set manually.") + + if k == "attrs" and v is not Attrs: raise AttributeError( - "`class_` and `attrs` attribute should not be set manually." + "`attrs` attribute must have type annotation of Attrs." 
                 )

-            if cls._is_classvar(v):
+            if k == "attrs" or cls._is_classvar(v):
                 continue

             if not cls._is_groupfield_type(v):

From 7d18f6ce1342f134866f360281aafa3804bcd5db Mon Sep 17 00:00:00 2001
From: yhteoh
Date: Wed, 15 Oct 2025 12:03:06 -0400
Subject: [PATCH 44/48] [packaging] update exports

---
 src/oqd_dataschema/__init__.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/src/oqd_dataschema/__init__.py b/src/oqd_dataschema/__init__.py
index afc6f77..ef09b9f 100644
--- a/src/oqd_dataschema/__init__.py
+++ b/src/oqd_dataschema/__init__.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from .base import Attrs, DTypes
 from .constrained import condataset, confolder, contable
 from .dataset import CastDataset, Dataset
 from .datastore import Datastore
@@ -23,6 +24,8 @@
 ########################################################################################

 __all__ = [
+    "Attrs",
+    "DTypes",
     "Datastore",
     "GroupBase",
     "GroupRegistry",

From 5bc133af7916e14c17217077fa3d481f03f11b70 Mon Sep 17 00:00:00 2001
From: yhteoh
Date: Wed, 15 Oct 2025 12:19:07 -0400
Subject: [PATCH 45/48] [docs] update tutorial on docs

---
 docs/tutorial.md | 108 ++++++++++++++++-------------------------------
 1 file changed, 37 insertions(+), 71 deletions(-)

diff --git a/docs/tutorial.md b/docs/tutorial.md
index 943dc1c..ecdc56f 100644
--- a/docs/tutorial.md
+++ b/docs/tutorial.md
@@ -1,108 +1,74 @@
-
 # Tutorial

-```python
-import pathlib
-
-import numpy as np
-from rich.pretty import pprint
-
-from oqd_dataschema.base import Dataset
-from oqd_dataschema.datastore import Datastore
-from oqd_dataschema.groups import (
-    ExpectationValueDataGroup,
-    MeasurementOutcomesDataGroup,
-    SinaraRawDataGroup,
-)
-```
+## Group Definition

 ```python
-raw = SinaraRawDataGroup(
-    camera_images=Dataset(shape=(3, 2, 2), dtype="float32"),
-    attrs={"date": "2025-03-26", "version": 0.1},
-)
-pprint(raw)
-```
+import datetime
+
+from pydantic import Field
+
+from oqd_dataschema import Attrs, Dataset, GroupBase

-
-
-```python
-raw.camera_images.data = np.random.uniform(size=(3, 2, 2)).astype("float32")
-pprint(raw)
+class CustomGroup(GroupBase):
+    attrs: Attrs = Field(
+        default_factory=lambda: dict(
+            timestamp=str(datetime.datetime.now(datetime.timezone.utc))
+        )
+    )
+    t: Dataset
+    x: Dataset
 ```

-
+Defined groups are automatically registered into the [`GroupRegistry`][oqd_dataschema.group.GroupRegistry].
```python -raw.camera_images.data = np.random.uniform(size=(3, 2, 2)).astype("float32") -``` - +from oqd_dataschema import GroupRegistry - -```python -data = Datastore(groups={"raw": raw}) -pprint(data) +GroupRegistry.groups ``` - - +## Initialize Group ```python -def process_raw(raw: SinaraRawDataGroup) -> MeasurementOutcomesDataGroup: - processed = MeasurementOutcomesDataGroup( - outcomes=Dataset( - data=np.round(raw.camera_images.data.mean(axis=(1, 2))), - ) - ) - return processed +t = np.linspace(0, 1, 101).astype(np.float32) +x = np.sin(t).astype(np.complex64) +group = CustomGroup( + t=Dataset(dtype="float32", shape=(101,)), x=Dataset(dtype="complex64", shape=(101,)) +) -processed = process_raw(data.groups["raw"]) -pprint(processed) +group.t.data = t +group.x.data = x ``` - - +## Initialize Datastore ```python -data.groups.update(processed=processed) -pprint(data) +from oqd_datastore import Datastore + +datastore = Datastore(groups={"g1": group}) ``` +## Data pipeline +```python +def process(datastore) -> Datastore: + _g = datastore.get("g1") + g2 = CustomGroup(t=Dataset(data=_g.t.data), x=Dataset(data=_g.x.data + 1j)) -```python -def process_outcomes( - measurements: MeasurementOutcomesDataGroup, -) -> ExpectationValueDataGroup: - expval = ExpectationValueDataGroup( - expectation_value=Dataset( - shape=(), - dtype="float32", - data=measurements.outcomes.data.mean(), - attrs={"date": "20", "input": 10}, - ) - ) - return expval + datastore.add(g2=g2) + return datastore -expval = process_outcomes(processed) -data.groups.update(expval=process_outcomes(data.groups["processed"])) -pprint(expval) +datastore.pipe(process) ``` - +## Save Datastore ```python -filepath = pathlib.Path("test.h5") -data.model_dump_hdf5(filepath) +datastore.model_dump_hdf5(pathlib.Path("datastore.h5"), mode="w") ``` - +## Load Datastore ```python -data_reload = Datastore.model_validate_hdf5(filepath) -pprint(data_reload) +reloaded_datastore = Datastore.model_validate_hdf5(pathlib.Path("datastore.h5")) ``` From db71491899195d0c30c8324f38bf7d5eddcfc9fe Mon Sep 17 00:00:00 2001 From: yhteoh Date: Wed, 15 Oct 2025 12:27:12 -0400 Subject: [PATCH 46/48] [feat] Add validation for func output --- src/oqd_dataschema/datastore.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/oqd_dataschema/datastore.py b/src/oqd_dataschema/datastore.py index c7f402b..8fb5066 100644 --- a/src/oqd_dataschema/datastore.py +++ b/src/oqd_dataschema/datastore.py @@ -18,7 +18,7 @@ import json import pathlib -from typing import Any, Dict, Literal +from typing import Any, Callable, Dict, Literal import h5py from pydantic import ( @@ -220,8 +220,13 @@ def add(self, **groups): self.update(**groups) - def pipe(self, func) -> Datastore: - return func(self) + def pipe(self, func: Callable[[Datastore], None]) -> Datastore: + _result = func(self) + + if _result is not None: + raise ValueError("`func` must return None.") + + return self # %% From 916e71ea60b9140ba9a3268721f32d703a68528d Mon Sep 17 00:00:00 2001 From: yhteoh Date: Wed, 15 Oct 2025 12:37:02 -0400 Subject: [PATCH 47/48] [docs] fix tutorial documentation --- docs/tutorial.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/docs/tutorial.md b/docs/tutorial.md index ecdc56f..c382434 100644 --- a/docs/tutorial.md +++ b/docs/tutorial.md @@ -52,11 +52,10 @@ def process(datastore) -> Datastore: _g = datastore.get("g1") g2 = CustomGroup(t=Dataset(data=_g.t.data), x=Dataset(data=_g.x.data + 1j)) + g2.attrs["_gen_by_pipe"] = "process" 
datastore.add(g2=g2) - return datastore - datastore.pipe(process) ``` From ce9246ab976d54d39b405e927a0164742e431c45 Mon Sep 17 00:00:00 2001 From: yhteoh Date: Wed, 15 Oct 2025 13:56:04 -0400 Subject: [PATCH 48/48] [docs] updated tutorials --- docs/tutorials/advanced.md | 82 ++++++++++++++++++++++++++++++++++++++ docs/tutorials/basic.md | 73 +++++++++++++++++++++++++++++++++ mkdocs.yaml | 4 +- 3 files changed, 158 insertions(+), 1 deletion(-) create mode 100644 docs/tutorials/advanced.md create mode 100644 docs/tutorials/basic.md diff --git a/docs/tutorials/advanced.md b/docs/tutorials/advanced.md new file mode 100644 index 0000000..8d9cf31 --- /dev/null +++ b/docs/tutorials/advanced.md @@ -0,0 +1,82 @@ +# Tutorial + +## Group Definition + +```python +from oqd_dataschema import GroupBase, Attrs + +class CustomGroup(GroupBase): + attrs: Attrs = Field( + default_factory=lambda: dict( + timestamp=str(datetime.datetime.now(datetime.timezone.utc)) + ) + ) + dset: Dataset + tbl: Table + fld: Folder +``` + +Defined groups are automatically registered into the [`GroupRegistry`][oqd_dataschema.group.GroupRegistry]. + +```python +from oqd_dataschema import GroupRegistry + +GroupRegistry.groups +``` + +## Initialize Group + +```python +from oqd_dataschema import Dataset, Table, Folder, unstructured_to_structured + +dset = Dataset(data=np.linspace(0, 1, 101).astype(np.float32)) +tbl = Table( + columns=[("t", "float32"), ("x", "complex128")], + data=unstructured_to_structured( + np.stack([np.linspace(0, 1, 101), np.sin(np.linspace(0, 1, 101))], -1), + dtype=np.dtype([("t", np.float32), ("x", np.complex128)]), + ), +) +fld = Folder( + document_schema={"t": "float32", "signal": {"x": "complex128", "y": "complex128"}}, + data=unstructured_to_structured( + np.stack( + [ + np.linspace(0, 1, 101), + np.sin(np.linspace(0, 1, 101)), + np.cos(np.linspace(0, 1, 101)), + ], + -1, + ), + dtype=np.dtype( + [ + ("t", np.float32), + ("signal", np.dtype([("x", np.complex128), ("y", np.complex128)])), + ] + ), + ), +) + + +group = CustomGroup(dset=dset, tbl=tbl, fld=fld) +``` + +## Initialize Datastore + +```python +from oqd_datastore import Datastore + +datastore = Datastore(groups={"g1": group}) +``` + +## Save Datastore + +```python +datastore.model_dump_hdf5(pathlib.Path("datastore.h5"), mode="w") +``` + +## Load Datastore + +```python +reloaded_datastore = Datastore.model_validate_hdf5(pathlib.Path("datastore.h5")) +``` diff --git a/docs/tutorials/basic.md b/docs/tutorials/basic.md new file mode 100644 index 0000000..c382434 --- /dev/null +++ b/docs/tutorials/basic.md @@ -0,0 +1,73 @@ +# Tutorial + +## Group Definition + +```python +from oqd_dataschema import GroupBase, Attrs + +class CustomGroup(GroupBase): + attrs: Attrs = Field( + default_factory=lambda: dict( + timestamp=str(datetime.datetime.now(datetime.timezone.utc)) + ) + ) + t: Dataset + x: Dataset +``` + +Defined groups are automatically registered into the [`GroupRegistry`][oqd_dataschema.group.GroupRegistry]. 
+
+```python
+from oqd_dataschema import GroupRegistry
+
+GroupRegistry.groups
+```
+
+## Initialize Group
+
+```python
+import numpy as np
+
+t = np.linspace(0, 1, 101).astype(np.float32)
+x = np.sin(t).astype(np.complex64)
+
+group = CustomGroup(
+    t=Dataset(dtype="float32", shape=(101,)), x=Dataset(dtype="complex64", shape=(101,))
+)
+
+group.t.data = t
+group.x.data = x
+```
+
+## Initialize Datastore
+
+```python
+from oqd_dataschema import Datastore
+
+datastore = Datastore(groups={"g1": group})
+```
+
+## Data pipeline
+
+```python
+def process(datastore) -> None:
+    _g = datastore.get("g1")
+
+    g2 = CustomGroup(t=Dataset(data=_g.t.data), x=Dataset(data=_g.x.data + 1j))
+    g2.attrs["_gen_by_pipe"] = "process"
+
+    datastore.add(g2=g2)
+
+
+datastore.pipe(process)
+```
+
+## Save Datastore
+
+```python
+import pathlib
+
+datastore.model_dump_hdf5(pathlib.Path("datastore.h5"), mode="w")
+```
+
+## Load Datastore
+
+```python
+reloaded_datastore = Datastore.model_validate_hdf5(pathlib.Path("datastore.h5"))
+```
diff --git a/mkdocs.yaml b/mkdocs.yaml
index 8981d0c..732c693 100644
--- a/mkdocs.yaml
+++ b/mkdocs.yaml
@@ -18,7 +18,9 @@ extra:

 nav:
   - Get Started: index.md
-  - Tutorial: tutorial.md
+  - Tutorials:
+    - Basics: tutorials/basic.md
+    - Datasets/Tables/Folders: tutorials/advanced.md
   - Explanation: explanation.md
   - API Reference:
     - Base: api/base.md
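
The `pipe` contract introduced in PATCH 42 and tightened in PATCH 46 — the callable mutates the datastore in place, must return `None`, and `pipe` hands back the same `Datastore` so calls can be chained — can be exercised end to end. The following is a minimal editorial sketch using only the public API shown in the tutorials above; `SignalGroup` and `normalize` are illustrative names, not part of the library:

```python
import numpy as np

from oqd_dataschema import Dataset, Datastore, GroupBase


class SignalGroup(GroupBase):
    t: Dataset
    x: Dataset


def normalize(ds: Datastore) -> None:
    # Mutate in place; returning anything other than None makes `pipe` raise ValueError.
    g = ds["g1"]
    ds.add(
        g1_normalized=SignalGroup(
            t=Dataset(data=g.t.data),
            x=Dataset(data=g.x.data / np.abs(g.x.data).max()),
        )
    )


t = np.linspace(0, 1, 101)
ds = Datastore(groups={"g1": SignalGroup(t=Dataset(data=t), x=Dataset(data=np.sin(t)))})
ds.pipe(normalize)  # returns the same Datastore, enabling further chained calls
```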
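The constraint annotations documented in PATCH 40 compose directly with group definitions. A minimal sketch under the `condataset` semantics described in the docstrings above — `ConstrainedGroup` is a hypothetical example, and constraint violations surface as pydantic's `ValidationError`:

```python
import numpy as np
from pydantic import ValidationError

from oqd_dataschema import GroupBase, condataset


class ConstrainedGroup(GroupBase):
    # Raw numpy arrays are cast to Dataset before the constraint validators run.
    samples: condataset(dtype_constraint=("float32", "float64"))
    counts: condataset(shape_constraint=(None, 3), min_dim=2, max_dim=2)


# Obeys the constraints: float64 samples, any leading size with trailing axis 3.
ConstrainedGroup(
    samples=np.random.rand(10),
    counts=np.zeros((5, 3), dtype="int64"),
)

# Violates the dtype constraint on `samples`.
try:
    ConstrainedGroup(
        samples=np.zeros(4, dtype="int32"),
        counts=np.zeros((5, 3), dtype="int64"),
    )
except ValidationError:
    pass
```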
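Finally, the `Attrs` validation threaded through the series (the `invalid_attrs` list in `base.py`) protects the keys that the HDF5 round-trip writes itself. A sketch under that assumption — `LoggedGroup` is an illustrative name, and because `_valid_attr_key` raises a bare `KeyError`, the error may surface either as `KeyError` or wrapped by pydantic depending on version:

```python
import numpy as np
from pydantic import ValidationError

from oqd_dataschema import Dataset, GroupBase


class LoggedGroup(GroupBase):
    d: Dataset


# Arbitrary scalar metadata is accepted.
LoggedGroup(d=Dataset(data=np.arange(3)), attrs={"run": 7, "operator": "alice"})

# The reserved keys "_datastore_signature" / "_group_schema" are rejected.
try:
    LoggedGroup(d=Dataset(data=np.arange(3)), attrs={"_group_schema": "clobbered"})
except (KeyError, ValidationError):
    pass
```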