Skip to content

Commit ca24380

Browse files
committed
improve: fix mnist download and numpy type
1 parent 2c85e19 commit ca24380

File tree

8 files changed

+484
-376
lines changed

8 files changed

+484
-376
lines changed

lantern/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
from lantern import functional
22

3-
from lantern.numpy.figure_to_numpy import figure_to_numpy
4-
from lantern.numpy.numpy_seed import numpy_seed
3+
from lantern.numpy_from_matplotlib_figure import numpy_from_matplotlib_figure
4+
from lantern.numpy_seed import numpy_seed
55

66
from lantern.functional_base import FunctionalBase
77
from lantern.tensor import Tensor
8+
from lantern.numpy import Numpy
89
from lantern.epochs import Epochs
910
from lantern.metric import Metric
1011
from lantern.metric_table import MetricTable

lantern/numpy.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
import numpy as np
2+
3+
4+
class Numpy(np.ndarray):
5+
@classmethod
6+
def __get_validators__(cls):
7+
yield cls.validate
8+
9+
@classmethod
10+
def validate(cls, data, values, field, config):
11+
return data
12+
13+
@classmethod
14+
def ndim(cls, ndim):
15+
class InheritNumpy(Numpy):
16+
@classmethod
17+
def validate(cls, data):
18+
if data.ndim != ndim:
19+
raise ValueError(f"Expected {ndim} dims, got {data.ndim}")
20+
return data
21+
22+
return InheritNumpy
23+
24+
@classmethod
25+
def short(cls, dims):
26+
class InheritNumpy(Numpy):
27+
@classmethod
28+
def validate(cls, data):
29+
if data.ndim != len(dims):
30+
raise ValueError(
31+
f"Unexpected number of dims {data.ndim} for {dims}"
32+
)
33+
return data
34+
35+
return InheritNumpy
36+
37+
@classmethod
38+
def shape(cls, *sizes):
39+
class InheritNumpy(Numpy):
40+
@classmethod
41+
def validate(cls, data):
42+
for data_size, size in zip(data.shape, sizes):
43+
if size != -1 and data_size != size:
44+
raise ValueError(f"Expected size {size}, got {data_size}")
45+
return data
46+
47+
return InheritNumpy
48+
49+
@classmethod
50+
def between(cls, geq, leq):
51+
class InheritNumpy(Numpy):
52+
@classmethod
53+
def validate(cls, data):
54+
data_min = data.min()
55+
if data_min < geq:
56+
raise ValueError(f"Expected min value {geq}, got {data_min}")
57+
58+
data_max = data.min()
59+
if data_max > leq:
60+
raise ValueError(f"Expected max value {leq}, got {data_max}")
61+
return data
62+
63+
return InheritNumpy
64+
65+
66+
def test_base_model():
67+
from pydantic import BaseModel
68+
69+
class Test(BaseModel):
70+
images: Numpy.short("nchw")
71+
72+
Test(images=np.ones((10, 3, 32, 32)))
73+
74+
75+
def test_validate():
76+
from pytest import raises
77+
78+
with raises(ValueError):
79+
Numpy.ndim(4).validate(np.ones((3, 4, 5)))

lantern/numpy/__init__.py

Lines changed: 0 additions & 2 deletions
This file was deleted.

lantern/numpy/figure_to_numpy.py renamed to lantern/numpy_from_matplotlib_figure.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import cv2
44

55

6-
def figure_to_numpy(figure):
6+
def numpy_from_matplotlib_figure(figure):
77
buffer = io.BytesIO()
88
figure.savefig(buffer, format="png", dpi=90, bbox_inches="tight")
99
buffer.seek(0)
File renamed without changes.

0 commit comments

Comments
 (0)