|
33 | 33 |
|
34 | 34 | ######################################################################################## |
35 | 35 |
|
36 | | -__all__ = ["GroupBase", "Dataset", "GroupRegistry", "condataset"] |
| 36 | +__all__ = ["GroupBase", "Dataset", "GroupRegistry", "condataset", "CastDataset"] |
37 | 37 |
|
38 | 38 | ######################################################################################## |
39 | 39 |
|
@@ -117,6 +117,12 @@ class Dataset(BaseModel, extra="forbid"): |
117 | 117 | use_enum_values=False, arbitrary_types_allowed=True, validate_assignment=True |
118 | 118 | ) |
119 | 119 |
|
| 120 | + @classmethod |
| 121 | + def cast(cls, data): |
| 122 | + if isinstance(data, np.ndarray): |
| 123 | + return cls(data=data) |
| 124 | + return data |
| 125 | + |
120 | 126 | @model_validator(mode="before") |
121 | 127 | @classmethod |
122 | 128 | def validate_and_update(cls, values: dict): |
@@ -212,12 +218,16 @@ def condataset( |
212 | 218 | ): |
213 | 219 | return Annotated[ |
214 | 220 | Dataset, |
| 221 | + BeforeValidator(Dataset.cast), |
215 | 222 | AfterValidator(partial(_constrain_dtype, dtype_constraint=dtype_constraint)), |
216 | 223 | AfterValidator(partial(_constraint_dim, min_dim=min_dim, max_dim=max_dim)), |
217 | 224 | AfterValidator(partial(_constraint_shape, shape_constraint=shape_constraint)), |
218 | 225 | ] |
219 | 226 |
|
220 | 227 |
|
| 228 | +CastDataset = Annotated[Dataset, BeforeValidator(Dataset.cast)] |
| 229 | + |
| 230 | + |
221 | 231 | ######################################################################################## |
222 | 232 |
|
223 | 233 |
|
|
0 commit comments