99import copy
1010from functools import lru_cache
1111from typing import List , OrderedDict , Tuple
12+ import facto .specdb .function as fn
1213
1314import torch
1415from facto .inputgen .argtuple .gen import ArgumentTupleGenerator
1920# seed to generate identical cases every run to reproduce from bisect
2021MAX_CASES = 50
2122
22-
2323def apply_tensor_contraints (op_name : str , index : int ) -> list [object ]:
2424 tensor_constraints = [
25- cp .Dtype .In (lambda deps : [torch .int , torch .float ]),
26- cp .Dtype .NotIn (lambda deps : [torch .int64 , torch .float64 ]),
25+ cp .Dtype .In (lambda deps : [torch .int8 , torch .int16 , torch .uint8 , torch .uint16 , torch .float32 ]),
2726 cp .Value .Ge (lambda deps , dtype , struct : - (2 ** 4 )),
2827 cp .Value .Le (lambda deps , dtype , struct : 2 ** 4 ),
2928 cp .Rank .Ge (lambda deps : 1 ),
3029 cp .Size .Ge (lambda deps , r , d : 1 ),
3130 cp .Size .Le (lambda deps , r , d : 2 ** 9 ),
31+ cp .Rank .Le (lambda deps : 2 ** 3 ),
3232 ]
3333
3434 match op_name :
3535 case "where.self" :
3636 if index == 0 : # condition
3737 tensor_constraints = [
3838 cp .Dtype .In (lambda deps : [torch .bool ]),
39- cp .Dtype .NotIn (lambda deps : [torch .int64 , torch .float64 ]),
4039 cp .Value .Ge (lambda deps , dtype , struct : - (2 ** 4 )),
4140 cp .Value .Le (lambda deps , dtype , struct : 2 ** 4 ),
4241 cp .Rank .Ge (lambda deps : 1 ),
@@ -45,28 +44,35 @@ def apply_tensor_contraints(op_name: str, index: int) -> list[object]:
4544 ]
4645 else :
4746 tensor_constraints = [
47+ cp .Dtype .In (lambda deps : [torch .int8 , torch .int16 , torch .uint8 , torch .uint16 , torch .float32 ]),
48+ cp .Value .Ge (lambda deps , dtype , struct : - (2 ** 4 )),
49+ cp .Value .Le (lambda deps , dtype , struct : 2 ** 4 ),
50+ cp .Rank .Ge (lambda deps : 1 ),
51+ cp .Size .Ge (lambda deps , r , d : 1 ),
52+ cp .Size .Le (lambda deps , r , d : 2 ** 9 ),
53+ ]
54+ case "embedding.default" :
55+ tensor_constraints = [
4856 cp .Dtype .In (lambda deps : [torch .float , torch .int ]),
4957 cp .Dtype .NotIn (lambda deps : [torch .int64 , torch .float64 ]),
5058 cp .Value .Ge (lambda deps , dtype , struct : - (2 ** 4 )),
5159 cp .Value .Le (lambda deps , dtype , struct : 2 ** 4 ),
5260 cp .Rank .Ge (lambda deps : 1 ),
5361 cp .Size .Ge (lambda deps , r , d : 1 ),
5462 cp .Size .Le (lambda deps , r , d : 2 ** 9 ),
55- ]
63+ ]
5664 case "sigmoid.default" :
5765 tensor_constraints .extend (
5866 [
59- cp .Dtype .In (lambda deps : [torch .float ]),
60- cp .Rank .Le (lambda deps : 2 ** 2 ),
67+ cp .Dtype .In (lambda deps : [torch .float32 ]),
6168 cp .Value .Ge (lambda deps , dtype , struct : - 2 ),
6269 cp .Value .Le (lambda deps , dtype , struct : 2 ),
6370 ]
6471 )
6572 case "rsqrt.default" :
6673 tensor_constraints .extend (
6774 [
68- cp .Dtype .In (lambda deps : [torch .float ]),
69- cp .Rank .Le (lambda deps : 2 ** 2 ),
75+ cp .Dtype .In (lambda deps : [torch .float32 ]),
7076 cp .Value .Gt (
7177 lambda deps , dtype , struct : 0
7278 ), # only generate real numbers
@@ -76,14 +82,12 @@ def apply_tensor_contraints(op_name: str, index: int) -> list[object]:
7682 case "mean.dim" :
7783 tensor_constraints .extend (
7884 [
79- cp .Dtype .In (lambda deps : [torch .float ]),
80- cp .Rank .Le (lambda deps : 2 ** 2 ),
85+ cp .Dtype .In (lambda deps : [torch .float32 ]),
8186 ]
8287 )
8388 case "exp.default" :
8489 tensor_constraints .extend (
8590 [
86- cp .Rank .Le (lambda deps : 2 ** 3 ),
8791 cp .Value .Ge (lambda deps , dtype , struct : - (2 ** 2 )),
8892 cp .Value .Le (lambda deps , dtype , struct : 2 ** 2 ),
8993 ]
@@ -96,20 +100,82 @@ def apply_tensor_contraints(op_name: str, index: int) -> list[object]:
96100 cp .Value .Le (lambda deps , dtype , struct : 2 ),
97101 ]
98102 )
99- case _ :
103+ case "constant_pad_nd.default" :
100104 tensor_constraints .extend (
101105 [
102- cp .Rank .Le (lambda deps : 2 ** 2 ),
106+ cp .Dtype .In (lambda deps : [torch .float32 ]),
107+ cp .Size .Le (lambda deps , r , d : 2 ** 2 ),
108+ ]
109+ )
110+ case "avg_pool2d.default" :
111+ tensor_constraints .extend (
112+ [
113+ cp .Rank .Eq (lambda deps : 4 ),
114+ ]
115+ )
116+ case "bmm.default" | "addmm.default" | "mm.default" :
117+ tensor_constraints .extend (
118+ [
119+ cp .Dtype .Eq (lambda deps : torch .float32 ),
120+ cp .Size .Le (lambda deps , r , d : 2 ** 2 ),
121+ cp .Value .Le (lambda deps , dtype , struct : 2 ** 4 ),
122+ ]
123+ )
124+ case "div.Tensor" :
125+ tensor_constraints .extend (
126+ [
127+ cp .Value .Ne (lambda deps , dtype , struct : 0 ),
128+ ]
129+ )
130+ case "div.Tensor_mode" | "minimum.default" :
131+ if index == 0 :
132+ tensor_constraints = [
133+ cp .Dtype .In (lambda deps : [torch .int64 , torch .int32 , torch .float32 ]),
134+ cp .Value .Ge (lambda deps , dtype , struct : - (2 ** 4 )),
135+ cp .Value .Le (lambda deps , dtype , struct : 2 ** 4 ),
136+ cp .Rank .Ge (lambda deps : 1 ),
137+ cp .Size .Ge (lambda deps , r , d : 1 ),
138+ cp .Size .Le (lambda deps , r , d : 2 ** 2 ),
139+ ]
140+ else :
141+ tensor_constraints = [
142+ cp .Dtype .In (lambda deps : [torch .int64 , torch .int32 , torch .float32 ]),
143+ cp .Value .Ge (lambda deps , dtype , struct : - (2 ** 4 )),
144+ cp .Value .Le (lambda deps , dtype , struct : 2 ** 4 ),
145+ cp .Rank .Ge (lambda deps : 1 ),
146+ cp .Rank .Eq (lambda deps : deps [0 ].dim ()),
147+ cp .Size .Eq (lambda deps , r , d : fn .safe_size (deps [0 ], d )),
148+ ]
149+ case "_native_batch_norm_legit_no_training.default" :
150+ tensor_constraints .extend (
151+ [
152+ cp .Rank .Le (lambda deps : 3 ),
153+ ],
154+ )
155+ case "reciprocal.default" :
156+ tensor_constraints = [
157+ cp .Value .Ge (lambda deps , dtype , struct : - (2 ** 2 )),
158+ cp .Value .Le (lambda deps , dtype , struct : 2 ** 2 ),
159+ cp .Size .Le (lambda deps , r , d : 2 ** 3 ),
160+ ]
161+ case "_softmax.default" :
162+ tensor_constraints .extend (
163+ [
164+ cp .Dtype .Eq (lambda deps : torch .float32 ),
165+ cp .Size .Le (lambda deps , r , d : 2 ** 2 ),
103166 ]
104167 )
168+ case _:
169+ pass
105170 return tensor_constraints
106171
107172
108173def apply_scalar_contraints (op_name : str ) -> list [ScalarDtype ]:
109174 match op_name :
110- case "add.Scalar" | "sub.Scalar" | "mul.Scalar" | "div.Scalar" :
175+ case "add.Scalar" | "sub.Scalar" | "mul.Scalar" | "div.Scalar" | "constant_pad_nd.default" :
176+ return [ScalarDtype .int ]
177+ case "full.default" :
111178 return [ScalarDtype .int ]
112-
113179 case _:
114180 return [ScalarDtype .float , ScalarDtype .int ]
115181
@@ -149,6 +215,12 @@ def facto_testcase_gen(op_name: str) -> List[Tuple[List[str], OrderedDict[str, s
149215 cp .Dtype .In (lambda deps : apply_scalar_contraints (op_name )),
150216 ]
151217 )
218+ if in_spec .name == "dtype" : # full.default
219+ spec .inspec [index ].constraints .extend (
220+ [
221+ cp .Dtype .In (lambda deps : [torch .long , torch .float ]),
222+ ]
223+ )
152224 elif in_spec .type .is_tensor ():
153225 spec .inspec [index ].constraints .extend (
154226 apply_tensor_contraints (op_name , index )
@@ -166,6 +238,29 @@ def facto_testcase_gen(op_name: str) -> List[Tuple[List[str], OrderedDict[str, s
166238 cp .Dtype .In (lambda deps : [torch .bool ]),
167239 ]
168240 )
241+ elif in_spec .type .is_length_list ():
242+ spec .inspec [index ].constraints .extend (
243+ [
244+ cp .Value .Ge (lambda deps , dtype , struct : 0 ),
245+ ]
246+ )
247+ if op_name == "avg_pool2d.default" :
248+ spec .inspec [index ].constraints .extend (
249+ [
250+ cp .Length .Eq (lambda deps : 2 ),
251+ ]
252+ )
253+ elif in_spec .type .is_shape ():
254+ spec .inspec [index ].constraints .extend (
255+ [
256+ cp .Rank .Ge (lambda deps : 1 ),
257+ cp .Rank .Le (lambda deps : 2 ** 2 ),
258+ cp .Value .Gt (lambda deps , dtype , struct : 0 ),
259+ cp .Value .Le (lambda deps , dtype , struct : 2 ** 2 ),
260+ cp .Size .Ge (lambda deps , r , d : 1 ),
261+ cp .Size .Le (lambda deps , r , d : 2 ** 2 ),
262+ ]
263+ )
169264
170265 return [
171266 (posargs , inkwargs )
0 commit comments