
Commit bc9b1c2

improve!: new features in numpy/tensor types
BREAKING CHANGE
1 parent b2618b5 commit bc9b1c2

2 files changed: +124 −11 lines

lantern/numpy.py

Lines changed: 35 additions & 5 deletions
@@ -40,7 +40,7 @@ def validate(cls, data):
         return InheritNumpy
 
     @classmethod
-    def short(cls, dims):
+    def dims(cls, dims):
         class InheritNumpy(cls):
             @classmethod
             def validate(cls, data):
@@ -83,15 +83,32 @@ def validate(cls, data):
 
         return InheritNumpy
 
+    @classmethod
+    def dtype(cls, dtype):
+        class InheritNumpy(cls):
+            @classmethod
+            def validate(cls, data):
+                data = super().validate(data)
+                new_data = data.astype(dtype)
+                if not np.allclose(data, new_data, equal_nan=True):
+                    raise ValueError(f"Was unable to cast from {data.dtype} to {dtype}")
+                return new_data
+
+        return InheritNumpy
+
 
 def test_base_model():
     from pydantic import BaseModel
+    from pytest import raises
 
     class Test(BaseModel):
-        images: Numpy.short("nchw")
+        images: Numpy.dims("NCHW")
 
     Test(images=np.ones((10, 3, 32, 32)))
 
+    with raises(ValueError):
+        Test(images=np.ones((10, 3, 32)))
+
 
 def test_validate():
     from pytest import raises
@@ -105,7 +122,7 @@ def test_conversion():
     import torch
 
     class Test(BaseModel):
-        numbers: Numpy.short("N")
+        numbers: Numpy.dims("N")
 
     Test(numbers=[1.1, 2.1, 3.1])
     Test(numbers=torch.tensor([1.1, 2.1, 3.1]))
@@ -115,7 +132,20 @@ def test_chaining():
     from pytest import raises
 
     with raises(ValueError):
-        Numpy.ndim(4).short("NCH").validate(np.ones((3, 4, 5)))
+        Numpy.ndim(4).dims("NCH").validate(np.ones((3, 4, 5)))
+
+    with raises(ValueError):
+        Numpy.dims("NCH").ndim(4).validate(np.ones((3, 4, 5)))
+
+
+def test_dtype():
+    from pydantic import BaseModel
+    from pytest import raises
+
+    class Test(BaseModel):
+        numbers: Numpy.dtype(np.uint8)
+
+    Test(numbers=[1, 2, 3])
 
     with raises(ValueError):
-        Numpy.short("NCH").ndim(4).validate(np.ones((3, 4, 5)))
+        Test(numbers=[1.5, 2.2, 3.2])
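
Usage sketch (not part of the diff): how the renamed Numpy.dims and the new Numpy.dtype
validators might be declared as pydantic field types, mirroring the tests above. The
import path and the Batch model are assumptions for illustration.

    import numpy as np
    from pydantic import BaseModel
    from lantern import Numpy  # assumed import path; adjust to wherever Numpy is exposed


    class Batch(BaseModel):
        images: Numpy.dims("NCHW")     # four named dimensions, checked against data.ndim
        labels: Numpy.dtype(np.uint8)  # must cast to uint8 without changing values


    Batch(images=np.ones((8, 3, 32, 32)), labels=np.array([0, 1, 2]))  # validates

    # Both of the following raise ValueError:
    # Batch(images=np.ones((8, 3, 32)), labels=np.array([0, 1, 2]))       # wrong number of dims
    # Batch(images=np.ones((8, 3, 32, 32)), labels=np.array([0.5, 1.2]))  # lossy cast to uint8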

lantern/tensor.py

Lines changed: 89 additions & 6 deletions
@@ -1,3 +1,4 @@
+import numpy as np
 import torch
 
 
@@ -12,6 +13,8 @@ def validate(cls, data):
             return torch.tensor(data)
         elif isinstance(data, torch.Tensor):
             return data
+        elif isinstance(data, np.ndarray):
+            return torch.from_numpy(data)
         else:
             return torch.as_tensor(data)
 
@@ -28,7 +31,7 @@ def validate(cls, data):
         return InheritTensor
 
     @classmethod
-    def short(cls, dims):
+    def dims(cls, dims):
         class InheritTensor(cls):
             @classmethod
             def validate(cls, data):
@@ -71,12 +74,70 @@ def validate(cls, data):
 
         return InheritTensor
 
+    @classmethod
+    def device(cls, device):
+        class InheritTensor(cls):
+            @classmethod
+            def validate(cls, data):
+                return super().validate(data).to(device)
+
+        return InheritTensor
+
+    @classmethod
+    def cpu(cls):
+        return cls.device(torch.device("cpu"))
+
+    @classmethod
+    def cuda(cls):
+        return cls.device(torch.device("cuda"))
+
+    @classmethod
+    def dtype(cls, dtype):
+        class InheritTensor(cls):
+            @classmethod
+            def validate(cls, data):
+                data = super().validate(data)
+                new_data = data.type(dtype)
+                if not torch.allclose(data.float(), new_data.float(), equal_nan=True):
+                    raise ValueError(f"Was unable to cast from {data.dtype} to {dtype}")
+                return new_data
+
+        return InheritTensor
+
+    @classmethod
+    def float(cls):
+        return cls.dtype(torch.float32)
+
+    @classmethod
+    def half(cls):
+        return cls.dtype(torch.float16)
+
+    @classmethod
+    def double(cls):
+        return cls.dtype(torch.float64)
+
+    @classmethod
+    def int(cls):
+        return cls.dtype(torch.int32)
+
+    @classmethod
+    def long(cls):
+        return cls.dtype(torch.int64)
+
+    @classmethod
+    def short(cls):
+        return cls.dtype(torch.int16)
+
+    @classmethod
+    def uint8(cls):
+        return cls.dtype(torch.uint8)
+
 
 def test_base_model():
     from pydantic import BaseModel
 
     class Test(BaseModel):
-        tensor: Tensor.short("nchw")
+        tensor: Tensor.dims("NCHW")
 
     Test(tensor=torch.ones(10, 3, 32, 32))
 
@@ -93,8 +154,8 @@ def test_conversion():
     import numpy as np
 
     class Test(BaseModel):
-        numbers: Tensor.short("N")
-        numbers2: Tensor.short("N")
+        numbers: Tensor.dims("N")
+        numbers2: Tensor.dims("N")
 
     Test(
         numbers=[1.1, 2.1, 3.1],
@@ -106,7 +167,29 @@ def test_chaining():
     from pytest import raises
 
     with raises(ValueError):
-        Tensor.ndim(4).short("NCH").validate(torch.ones(3, 4, 5))
+        Tensor.ndim(4).dims("NCH").validate(torch.ones(3, 4, 5))
+
+    with raises(ValueError):
+        Tensor.dims("NCH").ndim(4).validate(torch.ones(3, 4, 5))
+
+
+def test_dtype():
+    from pydantic import BaseModel
+    from pytest import raises
+
+    class Test(BaseModel):
+        numbers: Tensor.uint8()
+
+    Test(numbers=[1, 2, 3])
 
     with raises(ValueError):
-        Tensor.short("NCH").ndim(4).validate(torch.ones(3, 4, 5))
+        Test(numbers=[1.5, 2.2, 3.2])
+
+
+def test_device():
+    from pydantic import BaseModel
+
+    class Test(BaseModel):
+        numbers: Tensor.float().cpu()
+
+    Test(numbers=[1, 2, 3])
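
Usage sketch (not part of the diff): the renamed Tensor.dims validator combined with the
new dtype and device shortcuts, mirroring test_dtype and test_device above. The Batch
model, the import path, and chaining dims with float()/cpu() are assumptions for
illustration.

    import numpy as np
    import torch
    from pydantic import BaseModel
    from lantern import Tensor  # assumed import path; adjust to wherever Tensor is exposed


    class Batch(BaseModel):
        # four named dimensions, cast to float32, placed on the cpu device
        images: Tensor.dims("NCHW").float().cpu()


    Batch(images=torch.ones(10, 3, 32, 32))  # validates
    Batch(images=np.ones((10, 3, 32, 32)))   # numpy input is now converted via torch.from_numpy

    # Raises ValueError (wrong number of dims):
    # Batch(images=torch.ones(10, 3, 32))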
