Skip to content

Commit de63e9b

Browse files
committed
improve: more features in types tensor and numpy
1 parent d94c1aa commit de63e9b

File tree

4 files changed

+1060
-1040
lines changed

4 files changed

+1060
-1040
lines changed

lantern/numpy.py

Lines changed: 193 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,25 @@
11
import numpy as np
2-
import io
3-
import cv2
2+
import torch
43

54

65
class Numpy(np.ndarray):
7-
@staticmethod
8-
def from_matplotlib_figure(figure):
9-
buffer = io.BytesIO()
10-
figure.savefig(buffer, format="png", dpi=90, bbox_inches="tight")
11-
buffer.seek(0)
12-
image = np.frombuffer(buffer.getvalue(), dtype=np.uint8)
13-
buffer.close()
14-
image = cv2.imdecode(image, 1)
15-
return image
16-
176
@classmethod
187
def __get_validators__(cls):
198
yield cls.validate
209

2110
@classmethod
22-
def validate(cls, data):
11+
def validate(cls, data) -> np.ndarray:
2312
if isinstance(data, cls):
2413
return data.view(np.ndarray)
2514
elif isinstance(data, np.ndarray):
2615
return data
16+
elif isinstance(data, torch.Tensor):
17+
return data.numpy()
2718
else:
2819
return np.array(data)
2920

3021
@classmethod
31-
def ndim(cls, ndim):
22+
def ndim(cls, ndim) -> "Numpy":
3223
class InheritNumpy(cls):
3324
@classmethod
3425
def validate(cls, data):
@@ -40,7 +31,7 @@ def validate(cls, data):
4031
return InheritNumpy
4132

4233
@classmethod
43-
def dims(cls, dims):
34+
def dims(cls, dims) -> "Numpy":
4435
class InheritNumpy(cls):
4536
@classmethod
4637
def validate(cls, data):
@@ -54,7 +45,7 @@ def validate(cls, data):
5445
return InheritNumpy
5546

5647
@classmethod
57-
def shape(cls, *sizes):
48+
def shape(cls, *sizes) -> "Numpy":
5849
class InheritNumpy(cls):
5950
@classmethod
6051
def validate(cls, data):
@@ -67,24 +58,93 @@ def validate(cls, data):
6758
return InheritNumpy
6859

6960
@classmethod
70-
def between(cls, geq, leq):
61+
def between(cls, ge, le) -> "Numpy":
7162
class InheritNumpy(cls):
7263
@classmethod
7364
def validate(cls, data):
7465
data = super().validate(data)
75-
data_min = data.min()
76-
if data_min < geq:
77-
raise ValueError(f"Expected min value {geq}, got {data_min}")
7866

79-
data_max = data.min()
80-
if data_max > leq:
81-
raise ValueError(f"Expected max value {leq}, got {data_max}")
67+
if data.min() < ge:
68+
raise ValueError(
69+
f"Expected greater than or equal to {ge}, got {data.min()}"
70+
)
71+
72+
if data.max() > le:
73+
raise ValueError(
74+
f"Expected less than or equal to {le}, got {data.max()}"
75+
)
8276
return data
8377

8478
return InheritNumpy
8579

8680
@classmethod
87-
def dtype(cls, dtype):
81+
def ge(cls, ge) -> "Numpy":
82+
class InheritTensor(cls):
83+
@classmethod
84+
def validate(cls, data):
85+
data = super().validate(data)
86+
if data.min() < ge:
87+
raise ValueError(
88+
f"Expected greater than or equal to {ge}, got {data.min()}"
89+
)
90+
91+
return InheritTensor
92+
93+
@classmethod
94+
def le(cls, le) -> "Numpy":
95+
class InheritTensor(cls):
96+
@classmethod
97+
def validate(cls, data):
98+
data = super().validate(data)
99+
100+
if data.max() > le:
101+
raise ValueError(
102+
f"Expected less than or equal to {le}, got {data.max()}"
103+
)
104+
return data
105+
106+
return InheritTensor
107+
108+
@classmethod
109+
def gt(cls, gt) -> "Numpy":
110+
class InheritTensor(cls):
111+
@classmethod
112+
def validate(cls, data):
113+
data = super().validate(data)
114+
115+
if data.min() <= gt:
116+
raise ValueError(f"Expected greater than {gt}, got {data.min()}")
117+
118+
return InheritTensor
119+
120+
@classmethod
121+
def lt(cls, lt) -> "Numpy":
122+
class InheritTensor(cls):
123+
@classmethod
124+
def validate(cls, data):
125+
data = super().validate(data)
126+
127+
if data.max() >= lt:
128+
raise ValueError(f"Expected less than {lt}, got {data.max()}")
129+
return data
130+
131+
return InheritTensor
132+
133+
@classmethod
134+
def ne(cls, ne) -> "Numpy":
135+
class InheritTensor(cls):
136+
@classmethod
137+
def validate(cls, data):
138+
data = super().validate(data)
139+
140+
if (data == ne).any():
141+
raise ValueError(f"Unexpected value {ne}")
142+
return data
143+
144+
return InheritTensor
145+
146+
@classmethod
147+
def dtype(cls, dtype) -> "Numpy":
88148
class InheritNumpy(cls):
89149
@classmethod
90150
def validate(cls, data):
@@ -96,6 +156,66 @@ def validate(cls, data):
96156

97157
return InheritNumpy
98158

159+
@classmethod
160+
def float(cls) -> "Numpy":
161+
return cls.dtype(np.float32)
162+
163+
@classmethod
164+
def float32(cls) -> "Numpy":
165+
return cls.dtype(np.float32)
166+
167+
@classmethod
168+
def half(cls) -> "Numpy":
169+
return cls.dtype(np.float16)
170+
171+
@classmethod
172+
def float16(cls):
173+
return cls.dtype(np.float16)
174+
175+
@classmethod
176+
def double(cls) -> "Numpy":
177+
return cls.dtype(np.float64)
178+
179+
@classmethod
180+
def float64(cls) -> "Numpy":
181+
return cls.dtype(np.float64)
182+
183+
@classmethod
184+
def int(cls) -> "Numpy":
185+
return cls.dtype(np.int32)
186+
187+
@classmethod
188+
def int32(cls) -> "Numpy":
189+
return cls.dtype(np.int32)
190+
191+
@classmethod
192+
def long(cls) -> "Numpy":
193+
return cls.dtype(np.int64)
194+
195+
@classmethod
196+
def int64(cls) -> "Numpy":
197+
return cls.dtype(np.int64)
198+
199+
@classmethod
200+
def short(cls) -> "Numpy":
201+
return cls.dtype(np.int16)
202+
203+
@classmethod
204+
def int16(cls) -> "Numpy":
205+
return cls.dtype(np.int16)
206+
207+
@classmethod
208+
def byte(cls) -> "Numpy":
209+
return cls.dtype(np.uint8)
210+
211+
@classmethod
212+
def uint8(cls) -> "Numpy":
213+
return cls.dtype(np.uint8)
214+
215+
@classmethod
216+
def bool(cls) -> "Numpy":
217+
return cls.dtype(bool)
218+
99219

100220
def test_base_model():
101221
from pydantic import BaseModel
@@ -143,9 +263,57 @@ def test_dtype():
143263
from pytest import raises
144264

145265
class Test(BaseModel):
146-
numbers: Numpy.dtype(np.uint8)
266+
numbers: Numpy.uint8()
147267

148268
Test(numbers=[1, 2, 3])
149269

150270
with raises(ValueError):
151271
Test(numbers=[1.5, 2.2, 3.2])
272+
273+
class TestBool(BaseModel):
274+
flags: Numpy.bool()
275+
276+
TestBool(flags=[True, False, True])
277+
278+
with raises(ValueError):
279+
TestBool(numbers=[1.5, 2.2, 3.2])
280+
281+
282+
def test_from_torch():
283+
import torch
284+
from pydantic import BaseModel
285+
286+
class Test(BaseModel):
287+
numbers: Numpy
288+
289+
numbers = torch.tensor([1, 2, 3])
290+
numpy_numbers = Test(numbers=numbers).numbers
291+
292+
assert type(numpy_numbers) == np.ndarray
293+
assert np.allclose(torch.from_numpy(numpy_numbers), numbers)
294+
295+
296+
def test_between():
297+
from pydantic import BaseModel
298+
from pytest import raises
299+
300+
class Test(BaseModel):
301+
numbers: Numpy.between(1, 3.5)
302+
303+
Test(numbers=[1.5, 2.2, 3.2])
304+
305+
with raises(ValueError):
306+
Test(numbers=[-1.5, 2.2, 3.2])
307+
308+
309+
def test_gt():
310+
from pydantic import BaseModel
311+
from pytest import raises
312+
313+
class Test(BaseModel):
314+
numbers: Numpy.gt(1)
315+
316+
Test(numbers=[1.5, 2.2, 3.2])
317+
318+
with raises(ValueError):
319+
Test(numbers=[1, 2.2, 3.2])

0 commit comments

Comments
 (0)