1
1
import numpy as np
2
- import io
3
- import cv2
2
+ import torch
4
3
5
4
6
5
class Numpy (np .ndarray ):
7
- @staticmethod
8
- def from_matplotlib_figure (figure ):
9
- buffer = io .BytesIO ()
10
- figure .savefig (buffer , format = "png" , dpi = 90 , bbox_inches = "tight" )
11
- buffer .seek (0 )
12
- image = np .frombuffer (buffer .getvalue (), dtype = np .uint8 )
13
- buffer .close ()
14
- image = cv2 .imdecode (image , 1 )
15
- return image
16
-
17
6
@classmethod
18
7
def __get_validators__ (cls ):
19
8
yield cls .validate
20
9
21
10
@classmethod
22
- def validate (cls , data ):
11
+ def validate (cls , data ) -> np . ndarray :
23
12
if isinstance (data , cls ):
24
13
return data .view (np .ndarray )
25
14
elif isinstance (data , np .ndarray ):
26
15
return data
16
+ elif isinstance (data , torch .Tensor ):
17
+ return data .numpy ()
27
18
else :
28
19
return np .array (data )
29
20
30
21
@classmethod
31
- def ndim (cls , ndim ):
22
+ def ndim (cls , ndim ) -> "Numpy" :
32
23
class InheritNumpy (cls ):
33
24
@classmethod
34
25
def validate (cls , data ):
@@ -40,7 +31,7 @@ def validate(cls, data):
40
31
return InheritNumpy
41
32
42
33
@classmethod
43
- def dims (cls , dims ):
34
+ def dims (cls , dims ) -> "Numpy" :
44
35
class InheritNumpy (cls ):
45
36
@classmethod
46
37
def validate (cls , data ):
@@ -54,7 +45,7 @@ def validate(cls, data):
54
45
return InheritNumpy
55
46
56
47
@classmethod
57
- def shape (cls , * sizes ):
48
+ def shape (cls , * sizes ) -> "Numpy" :
58
49
class InheritNumpy (cls ):
59
50
@classmethod
60
51
def validate (cls , data ):
@@ -67,24 +58,93 @@ def validate(cls, data):
67
58
return InheritNumpy
68
59
69
60
@classmethod
70
- def between (cls , geq , leq ) :
61
+ def between (cls , ge , le ) -> "Numpy" :
71
62
class InheritNumpy (cls ):
72
63
@classmethod
73
64
def validate (cls , data ):
74
65
data = super ().validate (data )
75
- data_min = data .min ()
76
- if data_min < geq :
77
- raise ValueError (f"Expected min value { geq } , got { data_min } " )
78
66
79
- data_max = data .min ()
80
- if data_max > leq :
81
- raise ValueError (f"Expected max value { leq } , got { data_max } " )
67
+ if data .min () < ge :
68
+ raise ValueError (
69
+ f"Expected greater than or equal to { ge } , got { data .min ()} "
70
+ )
71
+
72
+ if data .max () > le :
73
+ raise ValueError (
74
+ f"Expected less than or equal to { le } , got { data .max ()} "
75
+ )
82
76
return data
83
77
84
78
return InheritNumpy
85
79
86
80
@classmethod
87
- def dtype (cls , dtype ):
81
+ def ge (cls , ge ) -> "Numpy" :
82
+ class InheritTensor (cls ):
83
+ @classmethod
84
+ def validate (cls , data ):
85
+ data = super ().validate (data )
86
+ if data .min () < ge :
87
+ raise ValueError (
88
+ f"Expected greater than or equal to { ge } , got { data .min ()} "
89
+ )
90
+
91
+ return InheritTensor
92
+
93
+ @classmethod
94
+ def le (cls , le ) -> "Numpy" :
95
+ class InheritTensor (cls ):
96
+ @classmethod
97
+ def validate (cls , data ):
98
+ data = super ().validate (data )
99
+
100
+ if data .max () > le :
101
+ raise ValueError (
102
+ f"Expected less than or equal to { le } , got { data .max ()} "
103
+ )
104
+ return data
105
+
106
+ return InheritTensor
107
+
108
+ @classmethod
109
+ def gt (cls , gt ) -> "Numpy" :
110
+ class InheritTensor (cls ):
111
+ @classmethod
112
+ def validate (cls , data ):
113
+ data = super ().validate (data )
114
+
115
+ if data .min () <= gt :
116
+ raise ValueError (f"Expected greater than { gt } , got { data .min ()} " )
117
+
118
+ return InheritTensor
119
+
120
+ @classmethod
121
+ def lt (cls , lt ) -> "Numpy" :
122
+ class InheritTensor (cls ):
123
+ @classmethod
124
+ def validate (cls , data ):
125
+ data = super ().validate (data )
126
+
127
+ if data .max () >= lt :
128
+ raise ValueError (f"Expected less than { lt } , got { data .max ()} " )
129
+ return data
130
+
131
+ return InheritTensor
132
+
133
+ @classmethod
134
+ def ne (cls , ne ) -> "Numpy" :
135
+ class InheritTensor (cls ):
136
+ @classmethod
137
+ def validate (cls , data ):
138
+ data = super ().validate (data )
139
+
140
+ if (data == ne ).any ():
141
+ raise ValueError (f"Unexpected value { ne } " )
142
+ return data
143
+
144
+ return InheritTensor
145
+
146
+ @classmethod
147
+ def dtype (cls , dtype ) -> "Numpy" :
88
148
class InheritNumpy (cls ):
89
149
@classmethod
90
150
def validate (cls , data ):
@@ -96,6 +156,66 @@ def validate(cls, data):
96
156
97
157
return InheritNumpy
98
158
159
+ @classmethod
160
+ def float (cls ) -> "Numpy" :
161
+ return cls .dtype (np .float32 )
162
+
163
+ @classmethod
164
+ def float32 (cls ) -> "Numpy" :
165
+ return cls .dtype (np .float32 )
166
+
167
+ @classmethod
168
+ def half (cls ) -> "Numpy" :
169
+ return cls .dtype (np .float16 )
170
+
171
+ @classmethod
172
+ def float16 (cls ):
173
+ return cls .dtype (np .float16 )
174
+
175
+ @classmethod
176
+ def double (cls ) -> "Numpy" :
177
+ return cls .dtype (np .float64 )
178
+
179
+ @classmethod
180
+ def float64 (cls ) -> "Numpy" :
181
+ return cls .dtype (np .float64 )
182
+
183
+ @classmethod
184
+ def int (cls ) -> "Numpy" :
185
+ return cls .dtype (np .int32 )
186
+
187
+ @classmethod
188
+ def int32 (cls ) -> "Numpy" :
189
+ return cls .dtype (np .int32 )
190
+
191
+ @classmethod
192
+ def long (cls ) -> "Numpy" :
193
+ return cls .dtype (np .int64 )
194
+
195
+ @classmethod
196
+ def int64 (cls ) -> "Numpy" :
197
+ return cls .dtype (np .int64 )
198
+
199
+ @classmethod
200
+ def short (cls ) -> "Numpy" :
201
+ return cls .dtype (np .int16 )
202
+
203
+ @classmethod
204
+ def int16 (cls ) -> "Numpy" :
205
+ return cls .dtype (np .int16 )
206
+
207
+ @classmethod
208
+ def byte (cls ) -> "Numpy" :
209
+ return cls .dtype (np .uint8 )
210
+
211
+ @classmethod
212
+ def uint8 (cls ) -> "Numpy" :
213
+ return cls .dtype (np .uint8 )
214
+
215
+ @classmethod
216
+ def bool (cls ) -> "Numpy" :
217
+ return cls .dtype (bool )
218
+
99
219
100
220
def test_base_model ():
101
221
from pydantic import BaseModel
@@ -143,9 +263,57 @@ def test_dtype():
143
263
from pytest import raises
144
264
145
265
class Test (BaseModel ):
146
- numbers : Numpy .dtype ( np . uint8 )
266
+ numbers : Numpy .uint8 ( )
147
267
148
268
Test (numbers = [1 , 2 , 3 ])
149
269
150
270
with raises (ValueError ):
151
271
Test (numbers = [1.5 , 2.2 , 3.2 ])
272
+
273
+ class TestBool (BaseModel ):
274
+ flags : Numpy .bool ()
275
+
276
+ TestBool (flags = [True , False , True ])
277
+
278
+ with raises (ValueError ):
279
+ TestBool (numbers = [1.5 , 2.2 , 3.2 ])
280
+
281
+
282
+ def test_from_torch ():
283
+ import torch
284
+ from pydantic import BaseModel
285
+
286
+ class Test (BaseModel ):
287
+ numbers : Numpy
288
+
289
+ numbers = torch .tensor ([1 , 2 , 3 ])
290
+ numpy_numbers = Test (numbers = numbers ).numbers
291
+
292
+ assert type (numpy_numbers ) == np .ndarray
293
+ assert np .allclose (torch .from_numpy (numpy_numbers ), numbers )
294
+
295
+
296
+ def test_between ():
297
+ from pydantic import BaseModel
298
+ from pytest import raises
299
+
300
+ class Test (BaseModel ):
301
+ numbers : Numpy .between (1 , 3.5 )
302
+
303
+ Test (numbers = [1.5 , 2.2 , 3.2 ])
304
+
305
+ with raises (ValueError ):
306
+ Test (numbers = [- 1.5 , 2.2 , 3.2 ])
307
+
308
+
309
+ def test_gt ():
310
+ from pydantic import BaseModel
311
+ from pytest import raises
312
+
313
+ class Test (BaseModel ):
314
+ numbers : Numpy .gt (1 )
315
+
316
+ Test (numbers = [1.5 , 2.2 , 3.2 ])
317
+
318
+ with raises (ValueError ):
319
+ Test (numbers = [1 , 2.2 , 3.2 ])
0 commit comments