Skip to content

Commit 3329762

Browse files
committed
fix: handle extra arguments to validate in internal functions
1 parent 265389c commit 3329762

File tree

2 files changed

+21
-21
lines changed

2 files changed

+21
-21
lines changed

lantern/numpy.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def validate(cls, data, config=None, field=None) -> np.ndarray:
2424
def ndim(cls, ndim) -> Numpy:
2525
class InheritNumpy(cls):
2626
@classmethod
27-
def validate(cls, data):
27+
def validate(cls, data, config=None, field=None):
2828
data = super().validate(data)
2929
if data.ndim != ndim:
3030
raise ValueError(f"Expected {ndim} dims, got {data.ndim}")
@@ -36,7 +36,7 @@ def validate(cls, data):
3636
def dims(cls, dims) -> Numpy:
3737
class InheritNumpy(cls):
3838
@classmethod
39-
def validate(cls, data):
39+
def validate(cls, data, config=None, field=None):
4040
data = super().validate(data)
4141
if data.ndim != len(dims):
4242
raise ValueError(
@@ -50,7 +50,7 @@ def validate(cls, data):
5050
def shape(cls, *sizes) -> Numpy:
5151
class InheritNumpy(cls):
5252
@classmethod
53-
def validate(cls, data):
53+
def validate(cls, data, config=None, field=None):
5454
data = super().validate(data)
5555
for data_size, size in zip(data.shape, sizes):
5656
if size != -1 and data_size != size:
@@ -63,7 +63,7 @@ def validate(cls, data):
6363
def between(cls, ge, le) -> Numpy:
6464
class InheritNumpy(cls):
6565
@classmethod
66-
def validate(cls, data):
66+
def validate(cls, data, config=None, field=None):
6767
data = super().validate(data)
6868

6969
if data.min() < ge:
@@ -83,7 +83,7 @@ def validate(cls, data):
8383
def ge(cls, ge) -> Numpy:
8484
class InheritTensor(cls):
8585
@classmethod
86-
def validate(cls, data):
86+
def validate(cls, data, config=None, field=None):
8787
data = super().validate(data)
8888
if data.min() < ge:
8989
raise ValueError(
@@ -96,7 +96,7 @@ def validate(cls, data):
9696
def le(cls, le) -> Numpy:
9797
class InheritTensor(cls):
9898
@classmethod
99-
def validate(cls, data):
99+
def validate(cls, data, config=None, field=None):
100100
data = super().validate(data)
101101

102102
if data.max() > le:
@@ -111,7 +111,7 @@ def validate(cls, data):
111111
def gt(cls, gt) -> Numpy:
112112
class InheritTensor(cls):
113113
@classmethod
114-
def validate(cls, data):
114+
def validate(cls, data, config=None, field=None):
115115
data = super().validate(data)
116116

117117
if data.min() <= gt:
@@ -123,7 +123,7 @@ def validate(cls, data):
123123
def lt(cls, lt) -> Numpy:
124124
class InheritTensor(cls):
125125
@classmethod
126-
def validate(cls, data):
126+
def validate(cls, data, config=None, field=None):
127127
data = super().validate(data)
128128

129129
if data.max() >= lt:
@@ -136,7 +136,7 @@ def validate(cls, data):
136136
def ne(cls, ne) -> Numpy:
137137
class InheritTensor(cls):
138138
@classmethod
139-
def validate(cls, data):
139+
def validate(cls, data, config=None, field=None):
140140
data = super().validate(data)
141141

142142
if (data == ne).any():
@@ -149,7 +149,7 @@ def validate(cls, data):
149149
def dtype(cls, dtype) -> Numpy:
150150
class InheritNumpy(cls):
151151
@classmethod
152-
def validate(cls, data):
152+
def validate(cls, data, config=None, field=None):
153153
data = super().validate(data)
154154
if data.dtype == dtype:
155155
return data

lantern/tensor.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def validate(cls, data, config=None, field=None) -> torch.Tensor:
3434
def ndim(cls, ndim) -> Tensor:
3535
class InheritTensor(cls):
3636
@classmethod
37-
def validate(cls, data):
37+
def validate(cls, data, config=None, field=None):
3838
data = super().validate(data)
3939
if data.ndim != ndim:
4040
raise ValueError(f"Expected {ndim} dims, got {data.ndim}")
@@ -46,7 +46,7 @@ def validate(cls, data):
4646
def dims(cls, dims) -> Tensor:
4747
class InheritTensor(cls):
4848
@classmethod
49-
def validate(cls, data):
49+
def validate(cls, data, config=None, field=None):
5050
data = super().validate(data)
5151
if data.ndim != len(dims):
5252
raise ValueError(
@@ -60,7 +60,7 @@ def validate(cls, data):
6060
def shape(cls, *sizes) -> Tensor:
6161
class InheritTensor(cls):
6262
@classmethod
63-
def validate(cls, data):
63+
def validate(cls, data, config=None, field=None):
6464
data = super().validate(data)
6565
for data_size, size in zip(data.shape, sizes):
6666
if size != -1 and data_size != size:
@@ -73,7 +73,7 @@ def validate(cls, data):
7373
def between(cls, ge, le) -> Tensor:
7474
class InheritTensor(cls):
7575
@classmethod
76-
def validate(cls, data):
76+
def validate(cls, data, config=None, field=None):
7777
data = super().validate(data)
7878
if data.min() < ge:
7979
raise ValueError(
@@ -92,7 +92,7 @@ def validate(cls, data):
9292
def ge(cls, ge) -> Tensor:
9393
class InheritTensor(cls):
9494
@classmethod
95-
def validate(cls, data):
95+
def validate(cls, data, config=None, field=None):
9696
data = super().validate(data)
9797
if data.min() < ge:
9898
raise ValueError(
@@ -105,7 +105,7 @@ def validate(cls, data):
105105
def le(cls, le) -> Tensor:
106106
class InheritTensor(cls):
107107
@classmethod
108-
def validate(cls, data):
108+
def validate(cls, data, config=None, field=None):
109109
data = super().validate(data)
110110

111111
if data.max() > le:
@@ -120,7 +120,7 @@ def validate(cls, data):
120120
def gt(cls, gt) -> Tensor:
121121
class InheritTensor(cls):
122122
@classmethod
123-
def validate(cls, data):
123+
def validate(cls, data, config=None, field=None):
124124
data = super().validate(data)
125125

126126
if data.min() <= gt:
@@ -132,7 +132,7 @@ def validate(cls, data):
132132
def lt(cls, lt) -> Tensor:
133133
class InheritTensor(cls):
134134
@classmethod
135-
def validate(cls, data):
135+
def validate(cls, data, config=None, field=None):
136136
data = super().validate(data)
137137

138138
if data.max() >= lt:
@@ -145,7 +145,7 @@ def validate(cls, data):
145145
def ne(cls, ne) -> Tensor:
146146
class InheritTensor(cls):
147147
@classmethod
148-
def validate(cls, data):
148+
def validate(cls, data, config=None, field=None):
149149
data = super().validate(data)
150150

151151
if (data == ne).any():
@@ -158,7 +158,7 @@ def validate(cls, data):
158158
def device(cls, device) -> Tensor:
159159
class InheritTensor(cls):
160160
@classmethod
161-
def validate(cls, data):
161+
def validate(cls, data, config=None, field=None):
162162
return super().validate(data).to(device)
163163

164164
return InheritTensor
@@ -175,7 +175,7 @@ def cuda(cls) -> Tensor:
175175
def dtype(cls, dtype) -> Tensor:
176176
class InheritTensor(cls):
177177
@classmethod
178-
def validate(cls, data):
178+
def validate(cls, data, config=None, field=None):
179179
data = super().validate(data)
180180
if data.dtype == dtype:
181181
return data

0 commit comments

Comments
 (0)