Skip to content

Commit 6d66226

Browse files
committed
[feat] Added str, bytes and bool support for Datasets
1 parent 0a5445c commit 6d66226

File tree

1 file changed

+19
-14
lines changed

1 file changed

+19
-14
lines changed

src/oqd_dataschema/base.py

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
def _valid_attr_key(value):
4242
if value in invalid_attrs:
4343
raise KeyError
44+
4445
return value
4546

4647

@@ -53,15 +54,19 @@ def _valid_attr_key(value):
5354

5455

5556
# %%
56-
mapping = bidict(
57+
dtype_map = bidict(
5758
{
58-
"int32": np.dtype("int32"),
59-
"int64": np.dtype("int64"),
60-
"float32": np.dtype("float32"),
61-
"float64": np.dtype("float64"),
62-
"complex64": np.dtype("complex64"),
63-
"complex128": np.dtype("complex128"),
64-
# 'string': np.type
59+
"int16": np.dtypes.Int16DType,
60+
"int32": np.dtypes.Int32DType,
61+
"int64": np.dtypes.Int64DType,
62+
"float16": np.dtypes.Float16DType,
63+
"float32": np.dtypes.Float32DType,
64+
"float64": np.dtypes.Float64DType,
65+
"complex64": np.dtypes.Complex64DType,
66+
"complex128": np.dtypes.Complex128DType,
67+
"string": np.dtypes.StrDType,
68+
"bytes": np.dtypes.BytesDType,
69+
"bool": np.dtypes.BoolDType,
6570
}
6671
)
6772

@@ -127,7 +132,7 @@ class Dataset(BaseModel, extra="forbid"):
127132
```
128133
"""
129134

130-
dtype: Optional[Literal[tuple(mapping.keys())]] = None
135+
dtype: Optional[Literal[tuple(dtype_map.keys())]] = None
131136
shape: Optional[tuple[int, ...]] = None
132137
data: Optional[Any] = Field(default=None, exclude=True)
133138

@@ -149,12 +154,12 @@ def validate_and_update(cls, values: dict):
149154
if not isinstance(data, np.ndarray):
150155
raise TypeError("`data` must be a numpy.ndarray.")
151156

152-
if data.dtype not in mapping.values():
157+
if type(data.dtype) not in dtype_map.values():
153158
raise TypeError(
154-
f"`data` must be a numpy array of dtype in {tuple(mapping.keys())}."
159+
f"`data` must be a numpy array of dtype in {tuple(dtype_map.keys())}."
155160
)
156161

157-
values["dtype"] = mapping.inverse[data.dtype]
162+
values["dtype"] = dtype_map.inverse[type(data.dtype)]
158163
values["shape"] = data.shape
159164

160165
return values
@@ -163,8 +168,8 @@ def validate_and_update(cls, values: dict):
163168
def validate_data_matches_shape_dtype(self):
164169
"""Ensure that `data` matches `dtype` and `shape`."""
165170
if self.data is not None:
166-
expected_dtype = mapping[self.dtype]
167-
if self.data.dtype != expected_dtype:
171+
expected_dtype = dtype_map[self.dtype]
172+
if type(self.data.dtype) is not expected_dtype:
168173
raise ValueError(
169174
f"Expected data dtype `{self.dtype}`, but got `{self.data.dtype.name}`."
170175
)

0 commit comments

Comments
 (0)