Skip to content

Commit 2ce8737

Browse files
committed
revert: add back numpy and tensor types
1 parent 8b2209b commit 2ce8737

File tree

5 files changed

+683
-38
lines changed

5 files changed

+683
-38
lines changed

lantern/__init__.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,25 @@
1-
from lantern.numpy_seed import numpy_seed
2-
from lantern.star import star
3-
from lantern.functional_base import FunctionalBase
4-
from pytorch_types import Tensor
5-
from pytorch_types import Numpy
6-
from lantern.epochs import Epochs
7-
from lantern.metric import Metric
8-
from lantern.metric_table import MetricTable
9-
from lantern.lambda_module import Lambda
10-
from lantern.module_device import module_device
11-
from lantern.module_train import module_train, module_eval
12-
from lantern.requires_grad import requires_grad, requires_nograd
13-
from lantern.set_learning_rate import set_learning_rate
14-
from lantern.set_seeds import set_seeds
15-
from lantern.worker_init_fn import worker_init_fn
16-
from lantern.progress_bar import ProgressBar
1+
from .numpy_seed import numpy_seed
2+
from .star import star
3+
from .functional_base import FunctionalBase
4+
from .tensor import Tensor
5+
from .numpy import Numpy
6+
from .epochs import Epochs
7+
from .metric import Metric
8+
from .metric_table import MetricTable
9+
from .lambda_module import Lambda
10+
from .module_device import module_device
11+
from .module_train import module_train, module_eval
12+
from .requires_grad import requires_grad, requires_nograd
13+
from .set_learning_rate import set_learning_rate
14+
from .set_seeds import set_seeds
15+
from .worker_init_fn import worker_init_fn
16+
from .progress_bar import ProgressBar
1717

1818
try:
19-
from lantern.early_stopping import EarlyStopping
19+
from .early_stopping import EarlyStopping
2020
except ImportError:
2121
pass
22-
from lantern.git_info import git_info
22+
from .git_info import git_info
2323

2424
from pkg_resources import get_distribution, DistributionNotFound
2525

lantern/numpy.py

Lines changed: 320 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,320 @@
1+
from __future__ import annotations
2+
import numpy as np
3+
import torch
4+
5+
6+
class Numpy(np.ndarray):
7+
@classmethod
8+
def __get_validators__(cls):
9+
yield cls.validate
10+
11+
@classmethod
12+
def validate(cls, data) -> np.ndarray:
13+
if isinstance(data, cls):
14+
return data.view(np.ndarray)
15+
elif isinstance(data, np.ndarray):
16+
return data
17+
elif isinstance(data, torch.Tensor):
18+
return data.numpy()
19+
else:
20+
return np.array(data)
21+
22+
@classmethod
23+
def ndim(cls, ndim) -> Numpy:
24+
class InheritNumpy(cls):
25+
@classmethod
26+
def validate(cls, data):
27+
data = super().validate(data)
28+
if data.ndim != ndim:
29+
raise ValueError(f"Expected {ndim} dims, got {data.ndim}")
30+
return data
31+
32+
return InheritNumpy
33+
34+
@classmethod
35+
def dims(cls, dims) -> Numpy:
36+
class InheritNumpy(cls):
37+
@classmethod
38+
def validate(cls, data):
39+
data = super().validate(data)
40+
if data.ndim != len(dims):
41+
raise ValueError(
42+
f"Unexpected number of dims {data.ndim} for {dims}"
43+
)
44+
return data
45+
46+
return InheritNumpy
47+
48+
@classmethod
49+
def shape(cls, *sizes) -> Numpy:
50+
class InheritNumpy(cls):
51+
@classmethod
52+
def validate(cls, data):
53+
data = super().validate(data)
54+
for data_size, size in zip(data.shape, sizes):
55+
if size != -1 and data_size != size:
56+
raise ValueError(f"Expected size {size}, got {data_size}")
57+
return data
58+
59+
return InheritNumpy
60+
61+
@classmethod
62+
def between(cls, ge, le) -> Numpy:
63+
class InheritNumpy(cls):
64+
@classmethod
65+
def validate(cls, data):
66+
data = super().validate(data)
67+
68+
if data.min() < ge:
69+
raise ValueError(
70+
f"Expected greater than or equal to {ge}, got {data.min()}"
71+
)
72+
73+
if data.max() > le:
74+
raise ValueError(
75+
f"Expected less than or equal to {le}, got {data.max()}"
76+
)
77+
return data
78+
79+
return InheritNumpy
80+
81+
@classmethod
82+
def ge(cls, ge) -> Numpy:
83+
class InheritTensor(cls):
84+
@classmethod
85+
def validate(cls, data):
86+
data = super().validate(data)
87+
if data.min() < ge:
88+
raise ValueError(
89+
f"Expected greater than or equal to {ge}, got {data.min()}"
90+
)
91+
92+
return InheritTensor
93+
94+
@classmethod
95+
def le(cls, le) -> Numpy:
96+
class InheritTensor(cls):
97+
@classmethod
98+
def validate(cls, data):
99+
data = super().validate(data)
100+
101+
if data.max() > le:
102+
raise ValueError(
103+
f"Expected less than or equal to {le}, got {data.max()}"
104+
)
105+
return data
106+
107+
return InheritTensor
108+
109+
@classmethod
110+
def gt(cls, gt) -> Numpy:
111+
class InheritTensor(cls):
112+
@classmethod
113+
def validate(cls, data):
114+
data = super().validate(data)
115+
116+
if data.min() <= gt:
117+
raise ValueError(f"Expected greater than {gt}, got {data.min()}")
118+
119+
return InheritTensor
120+
121+
@classmethod
122+
def lt(cls, lt) -> Numpy:
123+
class InheritTensor(cls):
124+
@classmethod
125+
def validate(cls, data):
126+
data = super().validate(data)
127+
128+
if data.max() >= lt:
129+
raise ValueError(f"Expected less than {lt}, got {data.max()}")
130+
return data
131+
132+
return InheritTensor
133+
134+
@classmethod
135+
def ne(cls, ne) -> Numpy:
136+
class InheritTensor(cls):
137+
@classmethod
138+
def validate(cls, data):
139+
data = super().validate(data)
140+
141+
if (data == ne).any():
142+
raise ValueError(f"Unexpected value {ne}")
143+
return data
144+
145+
return InheritTensor
146+
147+
@classmethod
148+
def dtype(cls, dtype) -> Numpy:
149+
class InheritNumpy(cls):
150+
@classmethod
151+
def validate(cls, data):
152+
data = super().validate(data)
153+
new_data = data.astype(dtype)
154+
if not np.allclose(data, new_data, equal_nan=True):
155+
raise ValueError(f"Was unable to cast from {data.dtype} to {dtype}")
156+
return new_data
157+
158+
return InheritNumpy
159+
160+
@classmethod
161+
def float(cls) -> Numpy:
162+
return cls.dtype(np.float32)
163+
164+
@classmethod
165+
def float32(cls) -> Numpy:
166+
return cls.dtype(np.float32)
167+
168+
@classmethod
169+
def half(cls) -> Numpy:
170+
return cls.dtype(np.float16)
171+
172+
@classmethod
173+
def float16(cls):
174+
return cls.dtype(np.float16)
175+
176+
@classmethod
177+
def double(cls) -> Numpy:
178+
return cls.dtype(np.float64)
179+
180+
@classmethod
181+
def float64(cls) -> Numpy:
182+
return cls.dtype(np.float64)
183+
184+
@classmethod
185+
def int(cls) -> Numpy:
186+
return cls.dtype(np.int32)
187+
188+
@classmethod
189+
def int32(cls) -> Numpy:
190+
return cls.dtype(np.int32)
191+
192+
@classmethod
193+
def long(cls) -> Numpy:
194+
return cls.dtype(np.int64)
195+
196+
@classmethod
197+
def int64(cls) -> Numpy:
198+
return cls.dtype(np.int64)
199+
200+
@classmethod
201+
def short(cls) -> Numpy:
202+
return cls.dtype(np.int16)
203+
204+
@classmethod
205+
def int16(cls) -> Numpy:
206+
return cls.dtype(np.int16)
207+
208+
@classmethod
209+
def byte(cls) -> Numpy:
210+
return cls.dtype(np.uint8)
211+
212+
@classmethod
213+
def uint8(cls) -> Numpy:
214+
return cls.dtype(np.uint8)
215+
216+
@classmethod
217+
def bool(cls) -> Numpy:
218+
return cls.dtype(bool)
219+
220+
221+
def test_base_model():
222+
from pydantic import BaseModel
223+
from pytest import raises
224+
225+
class Test(BaseModel):
226+
images: Numpy.dims("NCHW")
227+
228+
Test(images=np.ones((10, 3, 32, 32)))
229+
230+
with raises(ValueError):
231+
Test(images=np.ones((10, 3, 32)))
232+
233+
234+
def test_validate():
235+
from pytest import raises
236+
237+
with raises(ValueError):
238+
Numpy.ndim(4).validate(np.ones((3, 4, 5)))
239+
240+
241+
def test_conversion():
242+
from pydantic import BaseModel
243+
import torch
244+
245+
class Test(BaseModel):
246+
numbers: Numpy.dims("N")
247+
248+
Test(numbers=[1.1, 2.1, 3.1])
249+
Test(numbers=torch.tensor([1.1, 2.1, 3.1]))
250+
251+
252+
def test_chaining():
253+
from pytest import raises
254+
255+
with raises(ValueError):
256+
Numpy.ndim(4).dims("NCH").validate(np.ones((3, 4, 5)))
257+
258+
with raises(ValueError):
259+
Numpy.dims("NCH").ndim(4).validate(np.ones((3, 4, 5)))
260+
261+
262+
def test_dtype():
263+
from pydantic import BaseModel
264+
from pytest import raises
265+
266+
class Test(BaseModel):
267+
numbers: Numpy.uint8()
268+
269+
Test(numbers=[1, 2, 3])
270+
271+
with raises(ValueError):
272+
Test(numbers=[1.5, 2.2, 3.2])
273+
274+
class TestBool(BaseModel):
275+
flags: Numpy.bool()
276+
277+
TestBool(flags=[True, False, True])
278+
279+
with raises(ValueError):
280+
TestBool(numbers=[1.5, 2.2, 3.2])
281+
282+
283+
def test_from_torch():
284+
import torch
285+
from pydantic import BaseModel
286+
287+
class Test(BaseModel):
288+
numbers: Numpy
289+
290+
numbers = torch.tensor([1, 2, 3])
291+
numpy_numbers = Test(numbers=numbers).numbers
292+
293+
assert type(numpy_numbers) == np.ndarray
294+
assert np.allclose(torch.from_numpy(numpy_numbers), numbers)
295+
296+
297+
def test_between():
298+
from pydantic import BaseModel
299+
from pytest import raises
300+
301+
class Test(BaseModel):
302+
numbers: Numpy.between(1, 3.5)
303+
304+
Test(numbers=[1.5, 2.2, 3.2])
305+
306+
with raises(ValueError):
307+
Test(numbers=[-1.5, 2.2, 3.2])
308+
309+
310+
def test_gt():
311+
from pydantic import BaseModel
312+
from pytest import raises
313+
314+
class Test(BaseModel):
315+
numbers: Numpy.gt(1)
316+
317+
Test(numbers=[1.5, 2.2, 3.2])
318+
319+
with raises(ValueError):
320+
Test(numbers=[1, 2.2, 3.2])

0 commit comments

Comments
 (0)