1616import typing
1717import warnings
1818from enum import Enum
19- from functools import partial , reduce
19+ from types import NoneType
2020from typing import Annotated , Any , ClassVar , Literal , Optional , Sequence , Tuple , Union
2121
2222import numpy as np
2828 Discriminator ,
2929 Field ,
3030 TypeAdapter ,
31+ field_validator ,
3132 model_validator ,
3233)
3334
35+ from .utils import _flex_shape_equal , _validator_from_condition
36+
3437########################################################################################
3538
3639__all__ = ["GroupBase" , "Dataset" , "GroupRegistry" , "condataset" , "CastDataset" ]
@@ -108,7 +111,7 @@ class Dataset(BaseModel, extra="forbid"):
108111 """
109112
    # Name of the dtype, restricted to the names registered in `DTypes`;
    # filled in from `data` by the after-validator when left as None.
    dtype: Optional[Literal[DTypes.names()]] = None  # type: ignore
    # Declared shape; `None` entries are wildcard ("flexible") axes matched
    # via `_flex_shape_equal`, and the whole field is replaced by the concrete
    # array shape once `data` is validated.
    shape: Optional[Tuple[Union[int, None], ...]] = None
    # The wrapped numpy array; excluded from model serialization.
    data: Optional[Any] = Field(default=None, exclude=True)

    # Free-form attributes attached to the dataset.
    attrs: Attrs = {}
@@ -117,51 +120,58 @@ class Dataset(BaseModel, extra="forbid"):
117120 use_enum_values = False , arbitrary_types_allowed = True , validate_assignment = True
118121 )
119122
123+ @field_validator ("data" , mode = "before" )
120124 @classmethod
121- def cast (cls , data ):
122- if isinstance ( data , np . ndarray ):
123- return cls ( data = data )
124- return data
125+ def validate_and_update (cls , value ):
126+ # check if data exist
127+ if value is None :
128+ return value
125129
126- @model_validator (mode = "before" )
127- @classmethod
128- def validate_and_update (cls , values : dict ):
129- data = values .get ("data" )
130- dtype = values .get ("dtype" )
131- shape = values .get ("shape" )
130+ # check if data is a numpy array
131+ if not isinstance (value , np .ndarray ):
132+ raise TypeError ("`data` must be a numpy.ndarray." )
132133
133- if data is None and (dtype is not None and shape is not None ):
134- return values
134+ return value
135135
136- elif data is not None and ( dtype is None and shape is None ):
137- if not isinstance ( data , np . ndarray ):
138- raise TypeError ( " `data` must be a numpy.ndarray." )
136+ @ model_validator ( mode = "after" )
137+ def validate_data_matches_shape_dtype ( self ):
138+ """Ensure that `data` matches `dtype` and `shape`."""
139139
140- if type (data .dtype ) not in DTypes :
141- raise TypeError (
142- f"`data` must be a numpy array of dtype in { tuple (DTypes .names ())} ."
143- )
140+ # check if data exist
141+ if self .data is None :
142+ return self
143+
144+ # check if dtype matches data
145+ if (
146+ self .dtype is not None
147+ and type (self .data .dtype ) is not DTypes .get (self .dtype ).value
148+ ):
149+ raise ValueError (
150+ f"Expected data dtype `{ self .dtype } `, but got `{ self .data .dtype .name } `."
151+ )
144152
145- values ["dtype" ] = DTypes (type (data .dtype )).name .lower ()
146- values ["shape" ] = data .shape
153+ # check if shape mataches data
154+ if self .shape is not None and not _flex_shape_equal (
155+ self .data .shape , self .shape
156+ ):
157+ raise ValueError (f"Expected shape { self .shape } , but got { self .data .shape } ." )
147158
148- return values
159+ # reassign dtype if it is None
160+ if self .dtype != DTypes (type (self .data .dtype )).name .lower ():
161+ self .dtype = DTypes (type (self .data .dtype )).name .lower ()
162+
163+ # resassign shape to concrete value if it is None or a flexible shape
164+ if self .shape != self .data .shape :
165+ self .shape = self .data .shape
149166
150- @model_validator (mode = "after" )
151- def validate_data_matches_shape_dtype (self ):
152- """Ensure that `data` matches `dtype` and `shape`."""
153- if self .data is not None :
154- expected_dtype = DTypes .get (self .dtype ).value
155- if type (self .data .dtype ) is not expected_dtype :
156- raise ValueError (
157- f"Expected data dtype `{ self .dtype } `, but got `{ self .data .dtype .name } `."
158- )
159- if self .data .shape != self .shape :
160- raise ValueError (
161- f"Expected shape { self .shape } , but got { self .data .shape } ."
162- )
163167 return self
164168
169+ @classmethod
170+ def cast (cls , data ):
171+ if isinstance (data , np .ndarray ):
172+ return cls (data = data )
173+ return data
174+
165175 def __getitem__ (self , idx ):
166176 return self .data [idx ]
167177
@@ -172,6 +182,12 @@ def _is_dataset_type(cls, type_):
172182 )
173183
174184
# A `Dataset` annotation that additionally accepts a bare numpy array,
# wrapping it via `Dataset.cast` before regular validation runs.
CastDataset = Annotated[Dataset, BeforeValidator(Dataset.cast)]

########################################################################################
190+ @_validator_from_condition
175191def _constrain_dtype (dataset , * , dtype_constraint = None ):
176192 if (not isinstance (dtype_constraint , str )) and isinstance (
177193 dtype_constraint , Sequence
@@ -185,10 +201,12 @@ def _constrain_dtype(dataset, *, dtype_constraint=None):
185201 f"Expected dtype to be of type one of { dtype_constraint } , but got { dataset .dtype } ."
186202 )
187203
188- return dataset
189-
190204
205+ @_validator_from_condition
191206def _constraint_dim (dataset , * , min_dim = None , max_dim = None ):
207+ if min_dim is not None and max_dim is not None and min_dim > max_dim :
208+ raise ValueError ("Impossible to satisfy dimension constraints on dataset." )
209+
192210 min_dim = 0 if min_dim is None else min_dim
193211
194212 dims = len (dataset .shape )
@@ -198,42 +216,26 @@ def _constraint_dim(dataset, *, min_dim=None, max_dim=None):
198216 f"Expected { min_dim } <= dimension of shape{ f' <= { max_dim } ' } , but got shape = { dataset .shape } ."
199217 )
200218
201- return dataset
202-
203219
@_validator_from_condition
def _constraint_shape(dataset, *, shape_constraint=None):
    """Check the dataset's shape against a (possibly flexible) constraint.

    A falsy `shape_constraint` disables the check; `None` entries inside the
    constraint act as wildcards (see `_flex_shape_equal`).
    """
    if not shape_constraint:
        return
    if _flex_shape_equal(shape_constraint, dataset.shape):
        return
    raise ValueError(
        f"Expected shape to be {shape_constraint}, but got {dataset.shape}."
    )
221227
def condataset(
    *, shape_constraint=None, dtype_constraint=None, min_dim=None, max_dim=None
):
    """Build an annotated `Dataset` type constrained on dtype, dims, and shape.

    Each keyword defaults to `None`, which disables the corresponding check.
    """
    dtype_check = AfterValidator(_constrain_dtype(dtype_constraint=dtype_constraint))
    dim_check = AfterValidator(_constraint_dim(min_dim=min_dim, max_dim=max_dim))
    shape_check = AfterValidator(_constraint_shape(shape_constraint=shape_constraint))
    return Annotated[CastDataset, dtype_check, dim_check, shape_check]
232237
233238
234- CastDataset = Annotated [Dataset , BeforeValidator (Dataset .cast )]
235-
236-
237239########################################################################################
238240
239241
@@ -255,27 +257,46 @@ class GroupBase(BaseModel, extra="forbid"):
255257
256258 attrs : Attrs = {}
257259
260+ @classmethod
261+ def _is_allowed_field_type (cls , v ):
262+ is_dataset = Dataset ._is_dataset_type (v )
263+
264+ is_annotated_dataset = typing .get_origin (
265+ v
266+ ) is Annotated and Dataset ._is_dataset_type (v .__origin__ )
267+
268+ is_optional_dataset = typing .get_origin (v ) is Union and (
269+ (v .__args__ [0 ] == NoneType and Dataset ._is_dataset_type (v .__args__ [1 ]))
270+ or (v .__args__ [1 ] == NoneType and Dataset ._is_dataset_type (v .__args__ [0 ]))
271+ )
272+
273+ is_dict_dataset = (
274+ typing .get_origin (v ) is dict
275+ and v .__args__ [0 ] is str
276+ and Dataset ._is_dataset_type (v .__args__ [1 ])
277+ )
278+
279+ return (
280+ is_dataset or is_annotated_dataset or is_optional_dataset or is_dict_dataset
281+ )
282+
283+ @classmethod
284+ def _is_classvar (cls , v ):
285+ return v is ClassVar or typing .get_origin (v ) is ClassVar
286+
258287 def __init_subclass__ (cls , ** kwargs ):
259288 super ().__init_subclass__ (** kwargs )
260289
261290 for k , v in cls .__annotations__ .items ():
262- if k == "class_" :
263- raise AttributeError ("`class_` attribute should not be set manually." )
264-
265- if k == "attrs" and k is not Attrs :
266- raise TypeError ("`attrs` should be of type `Attrs`" )
267-
268- if (
269- k not in ["class_" , "attrs" ]
270- and v is not ClassVar
271- and not Dataset ._is_dataset_type (v )
272- and not (typing .get_origin (v ) is Annotated and v .__origin__ is Dataset )
273- and not (
274- typing .get_origin (v ) is dict
275- and v .__args__ [0 ] is str
276- and Dataset ._is_dataset_type (v .__args__ [1 ])
291+ if k in ["class_" , "attrs" ]:
292+ raise AttributeError (
293+ "`class_` and `attrs` attribute should not be set manually."
277294 )
278- ):
295+
296+ if cls ._is_classvar (v ):
297+ continue
298+
299+ if not cls ._is_allowed_field_type (v ):
279300 raise TypeError (
280301 "All fields of `GroupBase` have to be of type `Dataset`."
281302 )
0 commit comments