1
1
from __future__ import annotations
2
+
2
3
import numpy as np
3
4
import torch
5
+ import torch ._C
6
+
7
+
8
+ class MetaTensor (torch ._C ._TensorMeta ):
9
+ def __getitem__ (self , validate : Tensor ) -> Tensor :
10
+ return validate
4
11
5
12
6
- class Tensor (torch .Tensor ):
13
+ class Tensor (torch .Tensor , metaclass = MetaTensor ):
14
+ @classmethod
15
+ def Validate (cls ) -> Tensor :
16
+ return cls
17
+
7
18
@classmethod
8
19
def __get_validators__ (cls ):
9
20
yield cls .validate
@@ -234,6 +245,9 @@ def bool(cls) -> Tensor:
234
245
return cls .dtype (torch .bool )
235
246
236
247
248
+ Validate = Tensor
249
+
250
+
237
251
def test_base_model ():
238
252
from pydantic import BaseModel
239
253
@@ -251,8 +265,8 @@ def test_validate():
251
265
252
266
253
267
def test_conversion ():
254
- from pydantic import BaseModel
255
268
import numpy as np
269
+ from pydantic import BaseModel
256
270
257
271
class Test (BaseModel ):
258
272
numbers : Tensor .dims ("N" )
@@ -341,3 +355,16 @@ class Test(BaseModel):
341
355
342
356
with raises (ValueError ):
343
357
Test (numbers = [1 , 2.2 , 3.2 ])
358
+
359
+
360
+ def test_alternative_syntax ():
361
+ from pydantic import BaseModel
362
+ from pytest import raises
363
+
364
+ class Test (BaseModel ):
365
+ numbers : Tensor [Validate .ne (1 )]
366
+
367
+ Test (numbers = [1.5 , 2.2 , 3.2 ])
368
+
369
+ with raises (ValueError ):
370
+ Test (numbers = [1 , 2.2 , 3.2 ])
0 commit comments