1010from functools import lru_cache
1111from typing import List , OrderedDict , Tuple
1212
13+ import facto .specdb .function as fn
14+
1315import torch
1416from facto .inputgen .argtuple .gen import ArgumentTupleGenerator
1517from facto .inputgen .specs .model import ConstraintProducer as cp
2224
2325def apply_tensor_contraints (op_name : str , index : int ) -> list [object ]:
2426 tensor_constraints = [
25- cp .Dtype .In (lambda deps : [torch .int , torch .float ]),
26- cp .Dtype .NotIn (lambda deps : [torch .int64 , torch .float64 ]),
27+ cp .Dtype .In (
28+ lambda deps : [
29+ torch .int8 ,
30+ torch .int16 ,
31+ torch .uint8 ,
32+ torch .uint16 ,
33+ torch .float32 ,
34+ ]
35+ ),
2736 cp .Value .Ge (lambda deps , dtype , struct : - (2 ** 4 )),
2837 cp .Value .Le (lambda deps , dtype , struct : 2 ** 4 ),
2938 cp .Rank .Ge (lambda deps : 1 ),
3039 cp .Size .Ge (lambda deps , r , d : 1 ),
3140 cp .Size .Le (lambda deps , r , d : 2 ** 9 ),
41+ cp .Rank .Le (lambda deps : 2 ** 3 ),
3242 ]
3343
3444 match op_name :
3545 case "where.self" :
3646 if index == 0 : # condition
3747 tensor_constraints = [
3848 cp .Dtype .In (lambda deps : [torch .bool ]),
39- cp .Dtype .NotIn (lambda deps : [torch .int64 , torch .float64 ]),
4049 cp .Value .Ge (lambda deps , dtype , struct : - (2 ** 4 )),
4150 cp .Value .Le (lambda deps , dtype , struct : 2 ** 4 ),
4251 cp .Rank .Ge (lambda deps : 1 ),
@@ -45,28 +54,43 @@ def apply_tensor_contraints(op_name: str, index: int) -> list[object]:
4554 ]
4655 else :
4756 tensor_constraints = [
48- cp .Dtype .In (lambda deps : [torch .float , torch .int ]),
49- cp .Dtype .NotIn (lambda deps : [torch .int64 , torch .float64 ]),
57+ cp .Dtype .In (
58+ lambda deps : [
59+ torch .int8 ,
60+ torch .int16 ,
61+ torch .uint8 ,
62+ torch .uint16 ,
63+ torch .float32 ,
64+ ]
65+ ),
5066 cp .Value .Ge (lambda deps , dtype , struct : - (2 ** 4 )),
5167 cp .Value .Le (lambda deps , dtype , struct : 2 ** 4 ),
5268 cp .Rank .Ge (lambda deps : 1 ),
5369 cp .Size .Ge (lambda deps , r , d : 1 ),
5470 cp .Size .Le (lambda deps , r , d : 2 ** 9 ),
5571 ]
72+ case "embedding.default" :
73+ tensor_constraints = [
74+ cp .Dtype .In (lambda deps : [torch .float , torch .int ]),
75+ cp .Dtype .NotIn (lambda deps : [torch .int64 , torch .float64 ]),
76+ cp .Value .Ge (lambda deps , dtype , struct : - (2 ** 4 )),
77+ cp .Value .Le (lambda deps , dtype , struct : 2 ** 4 ),
78+ cp .Rank .Ge (lambda deps : 1 ),
79+ cp .Size .Ge (lambda deps , r , d : 1 ),
80+ cp .Size .Le (lambda deps , r , d : 2 ** 9 ),
81+ ]
5682 case "sigmoid.default" :
5783 tensor_constraints .extend (
5884 [
59- cp .Dtype .In (lambda deps : [torch .float ]),
60- cp .Rank .Le (lambda deps : 2 ** 2 ),
85+ cp .Dtype .In (lambda deps : [torch .float32 ]),
6186 cp .Value .Ge (lambda deps , dtype , struct : - 2 ),
6287 cp .Value .Le (lambda deps , dtype , struct : 2 ),
6388 ]
6489 )
6590 case "rsqrt.default" :
6691 tensor_constraints .extend (
6792 [
68- cp .Dtype .In (lambda deps : [torch .float ]),
69- cp .Rank .Le (lambda deps : 2 ** 2 ),
93+ cp .Dtype .In (lambda deps : [torch .float32 ]),
7094 cp .Value .Gt (
7195 lambda deps , dtype , struct : 0
7296 ), # only generate real numbers
@@ -76,14 +100,12 @@ def apply_tensor_contraints(op_name: str, index: int) -> list[object]:
76100 case "mean.dim" :
77101 tensor_constraints .extend (
78102 [
79- cp .Dtype .In (lambda deps : [torch .float ]),
80- cp .Rank .Le (lambda deps : 2 ** 2 ),
103+ cp .Dtype .In (lambda deps : [torch .float32 ]),
81104 ]
82105 )
83106 case "exp.default" :
84107 tensor_constraints .extend (
85108 [
86- cp .Rank .Le (lambda deps : 2 ** 3 ),
87109 cp .Value .Ge (lambda deps , dtype , struct : - (2 ** 2 )),
88110 cp .Value .Le (lambda deps , dtype , struct : 2 ** 2 ),
89111 ]
@@ -96,26 +118,96 @@ def apply_tensor_contraints(op_name: str, index: int) -> list[object]:
96118 cp .Value .Le (lambda deps , dtype , struct : 2 ),
97119 ]
98120 )
99- case _ :
121+ case "constant_pad_nd.default" :
100122 tensor_constraints .extend (
101123 [
102- cp .Rank .Le (lambda deps : 2 ** 2 ),
124+ cp .Dtype .In (lambda deps : [torch .float32 ]),
125+ cp .Size .Le (lambda deps , r , d : 2 ** 2 ),
126+ ]
127+ )
128+ case "avg_pool2d.default" :
129+ tensor_constraints .extend (
130+ [
131+ cp .Rank .Eq (lambda deps : 4 ),
132+ ]
133+ )
134+ case "bmm.default" | "addmm.default" | "mm.default" :
135+ tensor_constraints .extend (
136+ [
137+ cp .Dtype .Eq (lambda deps : torch .float ),
138+ cp .Size .Le (lambda deps , r , d : 2 ** 2 ),
139+ cp .Value .Le (lambda deps , dtype , struct : 2 ** 4 ),
103140 ]
104141 )
142+ case "div.Tensor" :
143+ tensor_constraints .extend (
144+ [
145+ cp .Value .Ne (lambda deps , dtype , struct : 0 ),
146+ ]
147+ )
148+ case "div.Tensor_mode" | "minimum.default" :
149+ if index == 0 :
150+ tensor_constraints = [
151+ cp .Dtype .In (lambda deps : [torch .int64 , torch .int32 , torch .float32 ]),
152+ cp .Value .Ge (lambda deps , dtype , struct : - (2 ** 4 )),
153+ cp .Value .Le (lambda deps , dtype , struct : 2 ** 4 ),
154+ cp .Rank .Ge (lambda deps : 1 ),
155+ cp .Size .Ge (lambda deps , r , d : 1 ),
156+ cp .Size .Le (lambda deps , r , d : 2 ** 2 ),
157+ ]
158+ else :
159+ tensor_constraints = [
160+ cp .Dtype .In (lambda deps : [torch .int64 , torch .int32 , torch .float32 ]),
161+ cp .Value .Ge (lambda deps , dtype , struct : - (2 ** 4 )),
162+ cp .Value .Le (lambda deps , dtype , struct : 2 ** 4 ),
163+ cp .Rank .Ge (lambda deps : 1 ),
164+ cp .Rank .Eq (lambda deps : deps [0 ].dim ()),
165+ cp .Size .Eq (lambda deps , r , d : fn .safe_size (deps [0 ], d )),
166+ ]
167+ case "_native_batch_norm_legit_no_training.default" :
168+ tensor_constraints .extend (
169+ [
170+ cp .Rank .Le (lambda deps : 3 ),
171+ ],
172+ )
173+ case "reciprocal.default" :
174+ tensor_constraints = [
175+ cp .Value .Ge (lambda deps , dtype , struct : - (2 ** 2 )),
176+ cp .Value .Le (lambda deps , dtype , struct : 2 ** 2 ),
177+ cp .Size .Le (lambda deps , r , d : 2 ** 3 ),
178+ ]
179+ case "_softmax.default" :
180+ tensor_constraints .extend (
181+ [
182+ cp .Dtype .Eq (lambda deps : torch .float32 ),
183+ cp .Size .Le (lambda deps , r , d : 2 ** 2 ),
184+ ]
185+ )
186+ case _:
187+ pass
105188 return tensor_constraints
106189
107190
108191def apply_scalar_contraints (op_name : str ) -> list [ScalarDtype ]:
109192 match op_name :
110- case "add.Scalar" | "sub.Scalar" | "mul.Scalar" | "div.Scalar" :
193+ case (
194+ "add.Scalar"
195+ | "sub.Scalar"
196+ | "mul.Scalar"
197+ | "div.Scalar"
198+ | "constant_pad_nd.default"
199+ ):
200+ return [ScalarDtype .int ]
201+ case "full.default" :
111202 return [ScalarDtype .int ]
112-
113203 case _:
114204 return [ScalarDtype .float , ScalarDtype .int ]
115205
116206
117207@lru_cache (maxsize = None )
118- def facto_testcase_gen (op_name : str ) -> List [Tuple [List [str ], OrderedDict [str , str ]]]:
208+ def facto_testcase_gen ( # noqa: C901
209+ op_name : str ,
210+ ) -> List [Tuple [List [str ], OrderedDict [str , str ]]]:
119211 # minimal example to test add.Tensor using FACTO
120212 spec = SpecDictDB [op_name ]
121213
@@ -149,6 +241,12 @@ def facto_testcase_gen(op_name: str) -> List[Tuple[List[str], OrderedDict[str, s
149241 cp .Dtype .In (lambda deps : apply_scalar_contraints (op_name )),
150242 ]
151243 )
244+ if in_spec .name == "dtype" : # full.default
245+ spec .inspec [index ].constraints .extend (
246+ [
247+ cp .Dtype .In (lambda deps : [torch .long , torch .float ]),
248+ ]
249+ )
152250 elif in_spec .type .is_tensor ():
153251 spec .inspec [index ].constraints .extend (
154252 apply_tensor_contraints (op_name , index )
@@ -166,6 +264,29 @@ def facto_testcase_gen(op_name: str) -> List[Tuple[List[str], OrderedDict[str, s
166264 cp .Dtype .In (lambda deps : [torch .bool ]),
167265 ]
168266 )
267+ elif in_spec .type .is_length_list ():
268+ spec .inspec [index ].constraints .extend (
269+ [
270+ cp .Value .Ge (lambda deps , dtype , struct : 0 ),
271+ ]
272+ )
273+ if op_name == "avg_pool2d.default" :
274+ spec .inspec [index ].constraints .extend (
275+ [
276+ cp .Length .Eq (lambda deps : 2 ),
277+ ]
278+ )
279+ elif in_spec .type .is_shape ():
280+ spec .inspec [index ].constraints .extend (
281+ [
282+ cp .Rank .Ge (lambda deps : 1 ),
283+ cp .Rank .Le (lambda deps : 2 ** 2 ),
284+ cp .Value .Gt (lambda deps , dtype , struct : 0 ),
285+ cp .Value .Le (lambda deps , dtype , struct : 2 ** 2 ),
286+ cp .Size .Ge (lambda deps , r , d : 1 ),
287+ cp .Size .Le (lambda deps , r , d : 2 ** 2 ),
288+ ]
289+ )
169290
170291 return [
171292 (posargs , inkwargs )
0 commit comments