1
+ import numpy as np
1
2
import torch
2
3
3
4
@@ -12,6 +13,8 @@ def validate(cls, data):
12
13
return torch .tensor (data )
13
14
elif isinstance (data , torch .Tensor ):
14
15
return data
16
+ elif isinstance (data , np .ndarray ):
17
+ return torch .from_numpy (data )
15
18
else :
16
19
return torch .as_tensor (data )
17
20
@@ -28,7 +31,7 @@ def validate(cls, data):
28
31
return InheritTensor
29
32
30
33
@classmethod
31
- def short (cls , dims ):
34
+ def dims (cls , dims ):
32
35
class InheritTensor (cls ):
33
36
@classmethod
34
37
def validate (cls , data ):
@@ -71,12 +74,70 @@ def validate(cls, data):
71
74
72
75
return InheritTensor
73
76
77
+ @classmethod
78
+ def device (cls , device ):
79
+ class InheritTensor (cls ):
80
+ @classmethod
81
+ def validate (cls , data ):
82
+ return super ().validate (data ).to (device )
83
+
84
+ return InheritTensor
85
+
86
+ @classmethod
87
+ def cpu (cls ):
88
+ return cls .device (torch .device ("cpu" ))
89
+
90
+ @classmethod
91
+ def cuda (cls ):
92
+ return cls .device (torch .device ("cuda" ))
93
+
94
+ @classmethod
95
+ def dtype (cls , dtype ):
96
+ class InheritTensor (cls ):
97
+ @classmethod
98
+ def validate (cls , data ):
99
+ data = super ().validate (data )
100
+ new_data = data .type (dtype )
101
+ if not torch .allclose (data .float (), new_data .float (), equal_nan = True ):
102
+ raise ValueError (f"Was unable to cast from { data .dtype } to { dtype } " )
103
+ return new_data
104
+
105
+ return InheritTensor
106
+
107
+ @classmethod
108
+ def float (cls ):
109
+ return cls .dtype (torch .float32 )
110
+
111
+ @classmethod
112
+ def half (cls ):
113
+ return cls .dtype (torch .float16 )
114
+
115
+ @classmethod
116
+ def double (cls ):
117
+ return cls .dtype (torch .float64 )
118
+
119
+ @classmethod
120
+ def int (cls ):
121
+ return cls .dtype (torch .int32 )
122
+
123
+ @classmethod
124
+ def long (cls ):
125
+ return cls .dtype (torch .int64 )
126
+
127
+ @classmethod
128
+ def short (cls ):
129
+ return cls .dtype (torch .int16 )
130
+
131
+ @classmethod
132
+ def uint8 (cls ):
133
+ return cls .dtype (torch .uint8 )
134
+
74
135
75
136
def test_base_model ():
76
137
from pydantic import BaseModel
77
138
78
139
class Test (BaseModel ):
79
- tensor : Tensor .short ( "nchw " )
140
+ tensor : Tensor .dims ( "NCHW " )
80
141
81
142
Test (tensor = torch .ones (10 , 3 , 32 , 32 ))
82
143
@@ -93,8 +154,8 @@ def test_conversion():
93
154
import numpy as np
94
155
95
156
class Test (BaseModel ):
96
- numbers : Tensor .short ("N" )
97
- numbers2 : Tensor .short ("N" )
157
+ numbers : Tensor .dims ("N" )
158
+ numbers2 : Tensor .dims ("N" )
98
159
99
160
Test (
100
161
numbers = [1.1 , 2.1 , 3.1 ],
@@ -106,7 +167,29 @@ def test_chaining():
106
167
from pytest import raises
107
168
108
169
with raises (ValueError ):
109
- Tensor .ndim (4 ).short ("NCH" ).validate (torch .ones (3 , 4 , 5 ))
170
+ Tensor .ndim (4 ).dims ("NCH" ).validate (torch .ones (3 , 4 , 5 ))
171
+
172
+ with raises (ValueError ):
173
+ Tensor .dims ("NCH" ).ndim (4 ).validate (torch .ones (3 , 4 , 5 ))
174
+
175
+
176
+ def test_dtype ():
177
+ from pydantic import BaseModel
178
+ from pytest import raises
179
+
180
+ class Test (BaseModel ):
181
+ numbers : Tensor .uint8 ()
182
+
183
+ Test (numbers = [1 , 2 , 3 ])
110
184
111
185
with raises (ValueError ):
112
- Tensor .short ("NCH" ).ndim (4 ).validate (torch .ones (3 , 4 , 5 ))
186
+ Test (numbers = [1.5 , 2.2 , 3.2 ])
187
+
188
+
189
+ def test_device ():
190
+ from pydantic import BaseModel
191
+
192
+ class Test (BaseModel ):
193
+ numbers : Tensor .float ().cpu ()
194
+
195
+ Test (numbers = [1 , 2 , 3 ])
0 commit comments