2020MAX_CASES = 50
2121
2222
23- def apply_tensor_contraints (op_name : str , tensor_constraints : list [ object ] ) -> None :
24- additional_tensor_constraints = [
23+ def apply_tensor_contraints (op_name : str , index : int ) -> list [ object ] :
24+ tensor_constraints = [
2525 cp .Dtype .In (lambda deps : [torch .int , torch .float ]),
2626 cp .Dtype .NotIn (lambda deps : [torch .int64 , torch .float64 ]),
2727 cp .Value .Ge (lambda deps , dtype , struct : - (2 ** 4 )),
@@ -33,17 +33,28 @@ def apply_tensor_contraints(op_name: str, tensor_constraints: list[object]) -> N
3333
3434 match op_name :
3535 case "where.self" :
36- additional_tensor_constraints = [
37- cp .Dtype .In (lambda deps : [torch .float , torch .int , torch .bool ]),
38- cp .Dtype .NotIn (lambda deps : [torch .int64 , torch .float64 ]),
39- cp .Value .Ge (lambda deps , dtype , struct : - (2 ** 4 )),
40- cp .Value .Le (lambda deps , dtype , struct : 2 ** 4 ),
41- cp .Rank .Ge (lambda deps : 1 ),
42- cp .Size .Ge (lambda deps , r , d : 1 ),
43- cp .Size .Le (lambda deps , r , d : 2 ** 9 ),
44- ]
36+ if index == 0 : # condition
37+ tensor_constraints = [
38+ cp .Dtype .In (lambda deps : [torch .bool ]),
39+ cp .Dtype .NotIn (lambda deps : [torch .int64 , torch .float64 ]),
40+ cp .Value .Ge (lambda deps , dtype , struct : - (2 ** 4 )),
41+ cp .Value .Le (lambda deps , dtype , struct : 2 ** 4 ),
42+ cp .Rank .Ge (lambda deps : 1 ),
43+ cp .Size .Ge (lambda deps , r , d : 1 ),
44+ cp .Size .Le (lambda deps , r , d : 2 ** 9 ),
45+ ]
46+ else :
47+ tensor_constraints = [
48+ cp .Dtype .In (lambda deps : [torch .float , torch .int ]),
49+ cp .Dtype .NotIn (lambda deps : [torch .int64 , torch .float64 ]),
50+ cp .Value .Ge (lambda deps , dtype , struct : - (2 ** 4 )),
51+ cp .Value .Le (lambda deps , dtype , struct : 2 ** 4 ),
52+ cp .Rank .Ge (lambda deps : 1 ),
53+ cp .Size .Ge (lambda deps , r , d : 1 ),
54+ cp .Size .Le (lambda deps , r , d : 2 ** 9 ),
55+ ]
4556 case "sigmoid.default" :
46- additional_tensor_constraints .extend (
57+ tensor_constraints .extend (
4758 [
4859 cp .Dtype .In (lambda deps : [torch .float ]),
4960 cp .Rank .Le (lambda deps : 2 ** 2 ),
@@ -52,7 +63,7 @@ def apply_tensor_contraints(op_name: str, tensor_constraints: list[object]) -> N
5263 ]
5364 )
5465 case "rsqrt.default" :
55- additional_tensor_constraints .extend (
66+ tensor_constraints .extend (
5667 [
5768 cp .Dtype .In (lambda deps : [torch .float ]),
5869 cp .Rank .Le (lambda deps : 2 ** 2 ),
@@ -63,35 +74,35 @@ def apply_tensor_contraints(op_name: str, tensor_constraints: list[object]) -> N
6374 ]
6475 )
6576 case "mean.dim" :
66- additional_tensor_constraints .extend (
77+ tensor_constraints .extend (
6778 [
6879 cp .Dtype .In (lambda deps : [torch .float ]),
6980 cp .Rank .Le (lambda deps : 2 ** 2 ),
7081 ]
7182 )
7283 case "exp.default" :
73- additional_tensor_constraints .extend (
84+ tensor_constraints .extend (
7485 [
7586 cp .Rank .Le (lambda deps : 2 ** 3 ),
7687 cp .Value .Ge (lambda deps , dtype , struct : - (2 ** 2 )),
7788 cp .Value .Le (lambda deps , dtype , struct : 2 ** 2 ),
7889 ]
7990 )
8091 case "slice_copy.Tensor" :
81- additional_tensor_constraints .extend (
92+ tensor_constraints .extend (
8293 [
8394 cp .Rank .Le (lambda deps : 2 ),
8495 cp .Value .Ge (lambda deps , dtype , struct : 1 ),
8596 cp .Value .Le (lambda deps , dtype , struct : 2 ),
8697 ]
8798 )
8899 case _:
89- additional_tensor_constraints .extend (
100+ tensor_constraints .extend (
90101 [
91102 cp .Rank .Le (lambda deps : 2 ** 2 ),
92103 ]
93104 )
94- tensor_constraints . extend ( additional_tensor_constraints )
105+ return tensor_constraints
95106
96107
97108def apply_scalar_contraints (op_name : str ) -> list [ScalarDtype ]:
@@ -107,9 +118,6 @@ def apply_scalar_contraints(op_name: str) -> list[ScalarDtype]:
107118def facto_testcase_gen (op_name : str ) -> List [Tuple [List [str ], OrderedDict [str , str ]]]:
108119 # minimal example to test add.Tensor using FACTO
109120 spec = SpecDictDB [op_name ]
110- tensor_constraints = []
111- # common tensor constraints
112- apply_tensor_contraints (op_name , tensor_constraints )
113121
114122 for index , in_spec in enumerate (copy .deepcopy (spec .inspec )):
115123 if in_spec .type .is_scalar ():
@@ -142,7 +150,9 @@ def facto_testcase_gen(op_name: str) -> List[Tuple[List[str], OrderedDict[str, s
142150 ]
143151 )
144152 elif in_spec .type .is_tensor ():
145- spec .inspec [index ].constraints .extend (tensor_constraints )
153+ spec .inspec [index ].constraints .extend (
154+ apply_tensor_contraints (op_name , index )
155+ )
146156 elif in_spec .type .is_dim_list ():
147157 spec .inspec [index ].constraints .extend (
148158 [
0 commit comments