10
10
from functools import lru_cache
11
11
from typing import List , OrderedDict , Tuple
12
12
13
+ import facto .specdb .function as fn
14
+
13
15
import torch
14
16
from facto .inputgen .argtuple .gen import ArgumentTupleGenerator
15
17
from facto .inputgen .specs .model import ConstraintProducer as cp
22
24
23
25
def apply_tensor_contraints (op_name : str , index : int ) -> list [object ]:
24
26
tensor_constraints = [
25
- cp .Dtype .In (lambda deps : [torch .int , torch .float ]),
26
- cp .Dtype .NotIn (lambda deps : [torch .int64 , torch .float64 ]),
27
+ cp .Dtype .In (
28
+ lambda deps : [
29
+ torch .int8 ,
30
+ torch .int16 ,
31
+ torch .uint8 ,
32
+ torch .uint16 ,
33
+ torch .float32 ,
34
+ ]
35
+ ),
27
36
cp .Value .Ge (lambda deps , dtype , struct : - (2 ** 4 )),
28
37
cp .Value .Le (lambda deps , dtype , struct : 2 ** 4 ),
29
38
cp .Rank .Ge (lambda deps : 1 ),
30
39
cp .Size .Ge (lambda deps , r , d : 1 ),
31
40
cp .Size .Le (lambda deps , r , d : 2 ** 9 ),
41
+ cp .Rank .Le (lambda deps : 2 ** 3 ),
32
42
]
33
43
34
44
match op_name :
35
45
case "where.self" :
36
46
if index == 0 : # condition
37
47
tensor_constraints = [
38
48
cp .Dtype .In (lambda deps : [torch .bool ]),
39
- cp .Dtype .NotIn (lambda deps : [torch .int64 , torch .float64 ]),
40
49
cp .Value .Ge (lambda deps , dtype , struct : - (2 ** 4 )),
41
50
cp .Value .Le (lambda deps , dtype , struct : 2 ** 4 ),
42
51
cp .Rank .Ge (lambda deps : 1 ),
@@ -45,28 +54,43 @@ def apply_tensor_contraints(op_name: str, index: int) -> list[object]:
45
54
]
46
55
else :
47
56
tensor_constraints = [
48
- cp .Dtype .In (lambda deps : [torch .float , torch .int ]),
49
- cp .Dtype .NotIn (lambda deps : [torch .int64 , torch .float64 ]),
57
+ cp .Dtype .In (
58
+ lambda deps : [
59
+ torch .int8 ,
60
+ torch .int16 ,
61
+ torch .uint8 ,
62
+ torch .uint16 ,
63
+ torch .float32 ,
64
+ ]
65
+ ),
50
66
cp .Value .Ge (lambda deps , dtype , struct : - (2 ** 4 )),
51
67
cp .Value .Le (lambda deps , dtype , struct : 2 ** 4 ),
52
68
cp .Rank .Ge (lambda deps : 1 ),
53
69
cp .Size .Ge (lambda deps , r , d : 1 ),
54
70
cp .Size .Le (lambda deps , r , d : 2 ** 9 ),
55
71
]
72
+ case "embedding.default" :
73
+ tensor_constraints = [
74
+ cp .Dtype .In (lambda deps : [torch .float , torch .int ]),
75
+ cp .Dtype .NotIn (lambda deps : [torch .int64 , torch .float64 ]),
76
+ cp .Value .Ge (lambda deps , dtype , struct : - (2 ** 4 )),
77
+ cp .Value .Le (lambda deps , dtype , struct : 2 ** 4 ),
78
+ cp .Rank .Ge (lambda deps : 1 ),
79
+ cp .Size .Ge (lambda deps , r , d : 1 ),
80
+ cp .Size .Le (lambda deps , r , d : 2 ** 9 ),
81
+ ]
56
82
case "sigmoid.default" :
57
83
tensor_constraints .extend (
58
84
[
59
- cp .Dtype .In (lambda deps : [torch .float ]),
60
- cp .Rank .Le (lambda deps : 2 ** 2 ),
85
+ cp .Dtype .In (lambda deps : [torch .float32 ]),
61
86
cp .Value .Ge (lambda deps , dtype , struct : - 2 ),
62
87
cp .Value .Le (lambda deps , dtype , struct : 2 ),
63
88
]
64
89
)
65
90
case "rsqrt.default" :
66
91
tensor_constraints .extend (
67
92
[
68
- cp .Dtype .In (lambda deps : [torch .float ]),
69
- cp .Rank .Le (lambda deps : 2 ** 2 ),
93
+ cp .Dtype .In (lambda deps : [torch .float32 ]),
70
94
cp .Value .Gt (
71
95
lambda deps , dtype , struct : 0
72
96
), # only generate real numbers
@@ -76,14 +100,12 @@ def apply_tensor_contraints(op_name: str, index: int) -> list[object]:
76
100
case "mean.dim" :
77
101
tensor_constraints .extend (
78
102
[
79
- cp .Dtype .In (lambda deps : [torch .float ]),
80
- cp .Rank .Le (lambda deps : 2 ** 2 ),
103
+ cp .Dtype .In (lambda deps : [torch .float32 ]),
81
104
]
82
105
)
83
106
case "exp.default" :
84
107
tensor_constraints .extend (
85
108
[
86
- cp .Rank .Le (lambda deps : 2 ** 3 ),
87
109
cp .Value .Ge (lambda deps , dtype , struct : - (2 ** 2 )),
88
110
cp .Value .Le (lambda deps , dtype , struct : 2 ** 2 ),
89
111
]
@@ -96,26 +118,96 @@ def apply_tensor_contraints(op_name: str, index: int) -> list[object]:
96
118
cp .Value .Le (lambda deps , dtype , struct : 2 ),
97
119
]
98
120
)
99
- case _ :
121
+ case "constant_pad_nd.default" :
100
122
tensor_constraints .extend (
101
123
[
102
- cp .Rank .Le (lambda deps : 2 ** 2 ),
124
+ cp .Dtype .In (lambda deps : [torch .float32 ]),
125
+ cp .Size .Le (lambda deps , r , d : 2 ** 2 ),
126
+ ]
127
+ )
128
+ case "avg_pool2d.default" :
129
+ tensor_constraints .extend (
130
+ [
131
+ cp .Rank .Eq (lambda deps : 4 ),
132
+ ]
133
+ )
134
+ case "bmm.default" | "addmm.default" | "mm.default" :
135
+ tensor_constraints .extend (
136
+ [
137
+ cp .Dtype .Eq (lambda deps : torch .float ),
138
+ cp .Size .Le (lambda deps , r , d : 2 ** 2 ),
139
+ cp .Value .Le (lambda deps , dtype , struct : 2 ** 4 ),
103
140
]
104
141
)
142
+ case "div.Tensor" :
143
+ tensor_constraints .extend (
144
+ [
145
+ cp .Value .Ne (lambda deps , dtype , struct : 0 ),
146
+ ]
147
+ )
148
+ case "div.Tensor_mode" | "minimum.default" :
149
+ if index == 0 :
150
+ tensor_constraints = [
151
+ cp .Dtype .In (lambda deps : [torch .int64 , torch .int32 , torch .float32 ]),
152
+ cp .Value .Ge (lambda deps , dtype , struct : - (2 ** 4 )),
153
+ cp .Value .Le (lambda deps , dtype , struct : 2 ** 4 ),
154
+ cp .Rank .Ge (lambda deps : 1 ),
155
+ cp .Size .Ge (lambda deps , r , d : 1 ),
156
+ cp .Size .Le (lambda deps , r , d : 2 ** 2 ),
157
+ ]
158
+ else :
159
+ tensor_constraints = [
160
+ cp .Dtype .In (lambda deps : [torch .int64 , torch .int32 , torch .float32 ]),
161
+ cp .Value .Ge (lambda deps , dtype , struct : - (2 ** 4 )),
162
+ cp .Value .Le (lambda deps , dtype , struct : 2 ** 4 ),
163
+ cp .Rank .Ge (lambda deps : 1 ),
164
+ cp .Rank .Eq (lambda deps : deps [0 ].dim ()),
165
+ cp .Size .Eq (lambda deps , r , d : fn .safe_size (deps [0 ], d )),
166
+ ]
167
+ case "_native_batch_norm_legit_no_training.default" :
168
+ tensor_constraints .extend (
169
+ [
170
+ cp .Rank .Le (lambda deps : 3 ),
171
+ ],
172
+ )
173
+ case "reciprocal.default" :
174
+ tensor_constraints = [
175
+ cp .Value .Ge (lambda deps , dtype , struct : - (2 ** 2 )),
176
+ cp .Value .Le (lambda deps , dtype , struct : 2 ** 2 ),
177
+ cp .Size .Le (lambda deps , r , d : 2 ** 3 ),
178
+ ]
179
+ case "_softmax.default" :
180
+ tensor_constraints .extend (
181
+ [
182
+ cp .Dtype .Eq (lambda deps : torch .float32 ),
183
+ cp .Size .Le (lambda deps , r , d : 2 ** 2 ),
184
+ ]
185
+ )
186
+ case _:
187
+ pass
105
188
return tensor_constraints
106
189
107
190
108
191
def apply_scalar_contraints (op_name : str ) -> list [ScalarDtype ]:
109
192
match op_name :
110
- case "add.Scalar" | "sub.Scalar" | "mul.Scalar" | "div.Scalar" :
193
+ case (
194
+ "add.Scalar"
195
+ | "sub.Scalar"
196
+ | "mul.Scalar"
197
+ | "div.Scalar"
198
+ | "constant_pad_nd.default"
199
+ ):
200
+ return [ScalarDtype .int ]
201
+ case "full.default" :
111
202
return [ScalarDtype .int ]
112
-
113
203
case _:
114
204
return [ScalarDtype .float , ScalarDtype .int ]
115
205
116
206
117
207
@lru_cache (maxsize = None )
118
- def facto_testcase_gen (op_name : str ) -> List [Tuple [List [str ], OrderedDict [str , str ]]]:
208
+ def facto_testcase_gen ( # noqa: C901
209
+ op_name : str ,
210
+ ) -> List [Tuple [List [str ], OrderedDict [str , str ]]]:
119
211
# minimal example to test add.Tensor using FACTO
120
212
spec = SpecDictDB [op_name ]
121
213
@@ -149,6 +241,12 @@ def facto_testcase_gen(op_name: str) -> List[Tuple[List[str], OrderedDict[str, s
149
241
cp .Dtype .In (lambda deps : apply_scalar_contraints (op_name )),
150
242
]
151
243
)
244
+ if in_spec .name == "dtype" : # full.default
245
+ spec .inspec [index ].constraints .extend (
246
+ [
247
+ cp .Dtype .In (lambda deps : [torch .long , torch .float ]),
248
+ ]
249
+ )
152
250
elif in_spec .type .is_tensor ():
153
251
spec .inspec [index ].constraints .extend (
154
252
apply_tensor_contraints (op_name , index )
@@ -166,6 +264,29 @@ def facto_testcase_gen(op_name: str) -> List[Tuple[List[str], OrderedDict[str, s
166
264
cp .Dtype .In (lambda deps : [torch .bool ]),
167
265
]
168
266
)
267
+ elif in_spec .type .is_length_list ():
268
+ spec .inspec [index ].constraints .extend (
269
+ [
270
+ cp .Value .Ge (lambda deps , dtype , struct : 0 ),
271
+ ]
272
+ )
273
+ if op_name == "avg_pool2d.default" :
274
+ spec .inspec [index ].constraints .extend (
275
+ [
276
+ cp .Length .Eq (lambda deps : 2 ),
277
+ ]
278
+ )
279
+ elif in_spec .type .is_shape ():
280
+ spec .inspec [index ].constraints .extend (
281
+ [
282
+ cp .Rank .Ge (lambda deps : 1 ),
283
+ cp .Rank .Le (lambda deps : 2 ** 2 ),
284
+ cp .Value .Gt (lambda deps , dtype , struct : 0 ),
285
+ cp .Value .Le (lambda deps , dtype , struct : 2 ** 2 ),
286
+ cp .Size .Ge (lambda deps , r , d : 1 ),
287
+ cp .Size .Le (lambda deps , r , d : 2 ** 2 ),
288
+ ]
289
+ )
169
290
170
291
return [
171
292
(posargs , inkwargs )
0 commit comments