23
23
24
24
25
25
def apply_tensor_contraints (op_name : str , index : int ) -> list [object ]:
26
- tensor_constraints = [
27
- cp .Dtype .In (
28
- lambda deps : [
29
- torch .int8 ,
30
- torch .int16 ,
31
- torch .uint8 ,
32
- torch .uint16 ,
33
- torch .float32 ,
34
- ]
35
- ),
36
- cp .Value .Ge (lambda deps , dtype , struct : - (2 ** 4 )),
37
- cp .Value .Le (lambda deps , dtype , struct : 2 ** 4 ),
38
- cp .Rank .Ge (lambda deps : 1 ),
39
- cp .Size .Ge (lambda deps , r , d : 1 ),
40
- cp .Size .Le (lambda deps , r , d : 2 ** 9 ),
41
- cp .Rank .Le (lambda deps : 2 ** 3 ),
42
- ]
26
+ tensor_constraints = (
27
+ [
28
+ cp .Dtype .In (
29
+ lambda deps : [
30
+ torch .int8 ,
31
+ torch .int16 ,
32
+ torch .uint8 ,
33
+ torch .uint16 ,
34
+ torch .int32 ,
35
+ torch .float32 ,
36
+ ]
37
+ ),
38
+ cp .Value .Ge (lambda deps , dtype , struct : - (2 ** 4 )),
39
+ cp .Value .Le (lambda deps , dtype , struct : 2 ** 4 ),
40
+ cp .Rank .Ge (lambda deps : 1 ),
41
+ cp .Size .Ge (lambda deps , r , d : 1 ),
42
+ cp .Size .Le (lambda deps , r , d : 2 ** 9 ),
43
+ cp .Rank .Le (lambda deps : 2 ** 3 ),
44
+ ]
45
+ if op_name
46
+ not in (
47
+ "slice_copy.Tensor" ,
48
+ "add.Scalar" ,
49
+ "sub.Scalar" ,
50
+ "mul.Scalar" ,
51
+ "div.Tensor" ,
52
+ "neg.default" ,
53
+ )
54
+ else [
55
+ cp .Dtype .In (
56
+ lambda deps : [
57
+ torch .int32 ,
58
+ torch .float32 ,
59
+ ]
60
+ ),
61
+ cp .Value .Ge (lambda deps , dtype , struct : - (2 ** 4 )),
62
+ cp .Value .Le (lambda deps , dtype , struct : 2 ** 4 ),
63
+ cp .Rank .Ge (lambda deps : 1 ),
64
+ cp .Size .Ge (lambda deps , r , d : 1 ),
65
+ cp .Size .Le (lambda deps , r , d : 2 ** 9 ),
66
+ cp .Rank .Le (lambda deps : 2 ** 3 ),
67
+ ]
68
+ )
43
69
44
70
match op_name :
45
71
case "where.self" :
@@ -60,6 +86,7 @@ def apply_tensor_contraints(op_name: str, index: int) -> list[object]:
60
86
torch .int16 ,
61
87
torch .uint8 ,
62
88
torch .uint16 ,
89
+ torch .int32 ,
63
90
torch .float32 ,
64
91
]
65
92
),
@@ -143,6 +170,9 @@ def apply_tensor_contraints(op_name: str, index: int) -> list[object]:
143
170
tensor_constraints .extend (
144
171
[
145
172
cp .Value .Ne (lambda deps , dtype , struct : 0 ),
173
+ cp .Value .Le (lambda deps , dtype , struct : 2 ** 3 ),
174
+ cp .Size .Le (lambda deps , r , d : 2 ** 3 ),
175
+ cp .Rank .Le (lambda deps : 2 ** 2 ),
146
176
]
147
177
)
148
178
case "div.Tensor_mode" | "minimum.default" :
0 commit comments