4141def _valid_attr_key (value ):
4242 if value in invalid_attrs :
4343 raise KeyError
44+
4445 return value
4546
4647
@@ -53,15 +54,19 @@ def _valid_attr_key(value):
5354
5455
5556# %%
56- mapping = bidict (
57+ dtype_map = bidict (
5758 {
58- "int32" : np .dtype ("int32" ),
59- "int64" : np .dtype ("int64" ),
60- "float32" : np .dtype ("float32" ),
61- "float64" : np .dtype ("float64" ),
62- "complex64" : np .dtype ("complex64" ),
63- "complex128" : np .dtype ("complex128" ),
64- # 'string': np.type
59+ "int16" : np .dtypes .Int16DType ,
60+ "int32" : np .dtypes .Int32DType ,
61+ "int64" : np .dtypes .Int64DType ,
62+ "float16" : np .dtypes .Float16DType ,
63+ "float32" : np .dtypes .Float32DType ,
64+ "float64" : np .dtypes .Float64DType ,
65+ "complex64" : np .dtypes .Complex64DType ,
66+ "complex128" : np .dtypes .Complex128DType ,
67+ "string" : np .dtypes .StrDType ,
68+ "bytes" : np .dtypes .BytesDType ,
69+ "bool" : np .dtypes .BoolDType ,
6570 }
6671)
6772
@@ -127,7 +132,7 @@ class Dataset(BaseModel, extra="forbid"):
127132 ```
128133 """
129134
130- dtype : Optional [Literal [tuple (mapping .keys ())]] = None
135+ dtype : Optional [Literal [tuple (dtype_map .keys ())]] = None
131136 shape : Optional [tuple [int , ...]] = None
132137 data : Optional [Any ] = Field (default = None , exclude = True )
133138
@@ -149,12 +154,12 @@ def validate_and_update(cls, values: dict):
149154 if not isinstance (data , np .ndarray ):
150155 raise TypeError ("`data` must be a numpy.ndarray." )
151156
152- if data .dtype not in mapping .values ():
157+ if type ( data .dtype ) not in dtype_map .values ():
153158 raise TypeError (
154- f"`data` must be a numpy array of dtype in { tuple (mapping .keys ())} ."
159+ f"`data` must be a numpy array of dtype in { tuple (dtype_map .keys ())} ."
155160 )
156161
157- values ["dtype" ] = mapping .inverse [data .dtype ]
162+ values ["dtype" ] = dtype_map .inverse [type ( data .dtype ) ]
158163 values ["shape" ] = data .shape
159164
160165 return values
@@ -163,8 +168,8 @@ def validate_and_update(cls, values: dict):
163168 def validate_data_matches_shape_dtype (self ):
164169 """Ensure that `data` matches `dtype` and `shape`."""
165170 if self .data is not None :
166- expected_dtype = mapping [self .dtype ]
167- if self .data .dtype != expected_dtype :
171+ expected_dtype = dtype_map [self .dtype ]
172+ if type ( self .data .dtype ) is not expected_dtype :
168173 raise ValueError (
169174 f"Expected data dtype `{ self .dtype } `, but got `{ self .data .dtype .name } `."
170175 )
0 commit comments