@@ -34,7 +34,7 @@ def validate(cls, data, config=None, field=None) -> torch.Tensor:
34
34
def ndim (cls , ndim ) -> Tensor :
35
35
class InheritTensor (cls ):
36
36
@classmethod
37
- def validate (cls , data ):
37
+ def validate (cls , data , config = None , field = None ):
38
38
data = super ().validate (data )
39
39
if data .ndim != ndim :
40
40
raise ValueError (f"Expected { ndim } dims, got { data .ndim } " )
@@ -46,7 +46,7 @@ def validate(cls, data):
46
46
def dims (cls , dims ) -> Tensor :
47
47
class InheritTensor (cls ):
48
48
@classmethod
49
- def validate (cls , data ):
49
+ def validate (cls , data , config = None , field = None ):
50
50
data = super ().validate (data )
51
51
if data .ndim != len (dims ):
52
52
raise ValueError (
@@ -60,7 +60,7 @@ def validate(cls, data):
60
60
def shape (cls , * sizes ) -> Tensor :
61
61
class InheritTensor (cls ):
62
62
@classmethod
63
- def validate (cls , data ):
63
+ def validate (cls , data , config = None , field = None ):
64
64
data = super ().validate (data )
65
65
for data_size , size in zip (data .shape , sizes ):
66
66
if size != - 1 and data_size != size :
@@ -73,7 +73,7 @@ def validate(cls, data):
73
73
def between (cls , ge , le ) -> Tensor :
74
74
class InheritTensor (cls ):
75
75
@classmethod
76
- def validate (cls , data ):
76
+ def validate (cls , data , config = None , field = None ):
77
77
data = super ().validate (data )
78
78
if data .min () < ge :
79
79
raise ValueError (
@@ -92,7 +92,7 @@ def validate(cls, data):
92
92
def ge (cls , ge ) -> Tensor :
93
93
class InheritTensor (cls ):
94
94
@classmethod
95
- def validate (cls , data ):
95
+ def validate (cls , data , config = None , field = None ):
96
96
data = super ().validate (data )
97
97
if data .min () < ge :
98
98
raise ValueError (
@@ -105,7 +105,7 @@ def validate(cls, data):
105
105
def le (cls , le ) -> Tensor :
106
106
class InheritTensor (cls ):
107
107
@classmethod
108
- def validate (cls , data ):
108
+ def validate (cls , data , config = None , field = None ):
109
109
data = super ().validate (data )
110
110
111
111
if data .max () > le :
@@ -120,7 +120,7 @@ def validate(cls, data):
120
120
def gt (cls , gt ) -> Tensor :
121
121
class InheritTensor (cls ):
122
122
@classmethod
123
- def validate (cls , data ):
123
+ def validate (cls , data , config = None , field = None ):
124
124
data = super ().validate (data )
125
125
126
126
if data .min () <= gt :
@@ -132,7 +132,7 @@ def validate(cls, data):
132
132
def lt (cls , lt ) -> Tensor :
133
133
class InheritTensor (cls ):
134
134
@classmethod
135
- def validate (cls , data ):
135
+ def validate (cls , data , config = None , field = None ):
136
136
data = super ().validate (data )
137
137
138
138
if data .max () >= lt :
@@ -145,7 +145,7 @@ def validate(cls, data):
145
145
def ne (cls , ne ) -> Tensor :
146
146
class InheritTensor (cls ):
147
147
@classmethod
148
- def validate (cls , data ):
148
+ def validate (cls , data , config = None , field = None ):
149
149
data = super ().validate (data )
150
150
151
151
if (data == ne ).any ():
@@ -158,7 +158,7 @@ def validate(cls, data):
158
158
def device (cls , device ) -> Tensor :
159
159
class InheritTensor (cls ):
160
160
@classmethod
161
- def validate (cls , data ):
161
+ def validate (cls , data , config = None , field = None ):
162
162
return super ().validate (data ).to (device )
163
163
164
164
return InheritTensor
@@ -175,7 +175,7 @@ def cuda(cls) -> Tensor:
175
175
def dtype (cls , dtype ) -> Tensor :
176
176
class InheritTensor (cls ):
177
177
@classmethod
178
- def validate (cls , data ):
178
+ def validate (cls , data , config = None , field = None ):
179
179
data = super ().validate (data )
180
180
if data .dtype == dtype :
181
181
return data
0 commit comments