55import copy
66from typing import List , OrderedDict , Tuple
77
8+ import facto .specdb .function as fn
9+
810import torch
911from facto .inputgen .argtuple .gen import ArgumentTupleGenerator
1012from facto .inputgen .specs .model import ConstraintProducer as cp
@@ -22,7 +24,16 @@ def apply_tensor_contraints(op_name: str, tensor_constraints: list[object]) -> N
2224 tensor_constraints .extend (
2325 [
2426 cp .Dtype .In (lambda deps : [torch .float ]),
25- cp .Rank .Le (lambda deps : 2 ** 3 ),
27+ cp .Rank .Le (lambda deps : 2 ** 2 ),
28+ cp .Value .Ge (lambda deps , dtype , struct : - 2 ),
29+ cp .Value .Le (lambda deps , dtype , struct : 2 ),
30+ ]
31+ )
32+ case "mean.dim" :
33+ tensor_constraints .extend (
34+ [
35+ cp .Dtype .In (lambda deps : [torch .float ]),
36+ cp .Rank .Le (lambda deps : 2 ** 2 ),
2637 ]
2738 )
2839 case "exp.default" :
@@ -86,8 +97,27 @@ def facto_testcase_gen(op_name: str) -> List[Tuple[List[str], OrderedDict[str, s
8697 cp .Value .Le (lambda deps , dtype : 2 ),
8798 ]
8899 )
100+ elif in_spec .type .is_scalar_type ():
101+ spec .inspec [index ].constraints .extend (
102+ [
103+ cp .Dtype .In (lambda deps : apply_scalar_contraints (op_name )),
104+ ]
105+ )
89106 elif in_spec .type .is_tensor ():
90107 spec .inspec [index ].constraints .extend (tensor_constraints )
108+ elif in_spec .type .is_dim_list ():
109+ spec .inspec [index ].constraints .extend (
110+ [
111+ cp .Length .Ge (lambda deps : 1 ),
112+ cp .Optional .Eq (lambda deps : False ),
113+ ]
114+ )
115+ elif in_spec .type .is_bool ():
116+ spec .inspec [index ].constraints .extend (
117+ [
118+ cp .Dtype .In (lambda deps : [torch .bool ]),
119+ ]
120+ )
91121
92122 return [
93123 (posargs , inkwargs )
0 commit comments