2323
2424
2525def apply_tensor_contraints (op_name : str , index : int ) -> list [object ]:
26- tensor_constraints = [
27- cp .Dtype .In (
28- lambda deps : [
29- torch .int8 ,
30- torch .int16 ,
31- torch .uint8 ,
32- torch .uint16 ,
33- torch .float32 ,
34- ]
35- ),
36- cp .Value .Ge (lambda deps , dtype , struct : - (2 ** 4 )),
37- cp .Value .Le (lambda deps , dtype , struct : 2 ** 4 ),
38- cp .Rank .Ge (lambda deps : 1 ),
39- cp .Size .Ge (lambda deps , r , d : 1 ),
40- cp .Size .Le (lambda deps , r , d : 2 ** 9 ),
41- cp .Rank .Le (lambda deps : 2 ** 3 ),
42- ]
26+ tensor_constraints = (
27+ [
28+ cp .Dtype .In (
29+ lambda deps : [
30+ torch .int8 ,
31+ torch .int16 ,
32+ torch .uint8 ,
33+ torch .uint16 ,
34+ torch .int32 ,
35+ torch .float32 ,
36+ ]
37+ ),
38+ cp .Value .Ge (lambda deps , dtype , struct : - (2 ** 4 )),
39+ cp .Value .Le (lambda deps , dtype , struct : 2 ** 4 ),
40+ cp .Rank .Ge (lambda deps : 1 ),
41+ cp .Size .Ge (lambda deps , r , d : 1 ),
42+ cp .Size .Le (lambda deps , r , d : 2 ** 9 ),
43+ cp .Rank .Le (lambda deps : 2 ** 3 ),
44+ ]
45+ if op_name
46+ not in (
47+ "slice_copy.Tensor" ,
48+ "add.Scalar" ,
49+ "sub.Scalar" ,
50+ "mul.Scalar" ,
51+ "div.Tensor" ,
52+ "neg.default" ,
53+ )
54+ else [
55+ cp .Dtype .In (
56+ lambda deps : [
57+ torch .int32 ,
58+ torch .float32 ,
59+ ]
60+ ),
61+ cp .Value .Ge (lambda deps , dtype , struct : - (2 ** 4 )),
62+ cp .Value .Le (lambda deps , dtype , struct : 2 ** 4 ),
63+ cp .Rank .Ge (lambda deps : 1 ),
64+ cp .Size .Ge (lambda deps , r , d : 1 ),
65+ cp .Size .Le (lambda deps , r , d : 2 ** 9 ),
66+ cp .Rank .Le (lambda deps : 2 ** 3 ),
67+ ]
68+ )
4369
4470 match op_name :
4571 case "where.self" :
@@ -60,6 +86,7 @@ def apply_tensor_contraints(op_name: str, index: int) -> list[object]:
6086 torch .int16 ,
6187 torch .uint8 ,
6288 torch .uint16 ,
89+ torch .int32 ,
6390 torch .float32 ,
6491 ]
6592 ),
@@ -143,6 +170,9 @@ def apply_tensor_contraints(op_name: str, index: int) -> list[object]:
143170 tensor_constraints .extend (
144171 [
145172 cp .Value .Ne (lambda deps , dtype , struct : 0 ),
173+ cp .Value .Le (lambda deps , dtype , struct : 2 ** 3 ),
174+ cp .Size .Le (lambda deps , r , d : 2 ** 3 ),
175+ cp .Rank .Le (lambda deps : 2 ** 2 ),
146176 ]
147177 )
148178 case "div.Tensor_mode" | "minimum.default" :
0 commit comments