File tree Expand file tree Collapse file tree 1 file changed +24
-0
lines changed
Expand file tree Collapse file tree 1 file changed +24
-0
lines changed Original file line number Diff line number Diff line change @@ -30,6 +30,7 @@ def apply_tensor_contraints(op_name: str, index: int) -> list[object]:
3030 torch .int16 ,
3131 torch .uint8 ,
3232 torch .uint16 ,
33+ torch .int32 ,
3334 torch .float32 ,
3435 ]
3536 ),
@@ -42,6 +43,28 @@ def apply_tensor_contraints(op_name: str, index: int) -> list[object]:
4243 ]
4344
4445 match op_name :
46+ case (
47+ "slice_copy.Tensor"
48+ | "add.Scalar"
49+ | "sub.Scalar"
50+ | "mul.Scalar"
51+ | "div.Tensor"
52+ | "neg.default"
53+ ):
54+ tensor_constraints = [
55+ cp .Dtype .In (
56+ lambda deps : [
57+ torch .int32 ,
58+ torch .float32 ,
59+ ]
60+ ),
61+ cp .Value .Ge (lambda deps , dtype , struct : - (2 ** 4 )),
62+ cp .Value .Le (lambda deps , dtype , struct : 2 ** 4 ),
63+ cp .Rank .Ge (lambda deps : 1 ),
64+ cp .Size .Ge (lambda deps , r , d : 1 ),
65+ cp .Size .Le (lambda deps , r , d : 2 ** 9 ),
66+ cp .Rank .Le (lambda deps : 2 ** 3 ),
67+ ]
4568 case "where.self" :
4669 if index == 0 : # condition
4770 tensor_constraints = [
@@ -60,6 +83,7 @@ def apply_tensor_contraints(op_name: str, index: int) -> list[object]:
6083 torch .int16 ,
6184 torch .uint8 ,
6285 torch .uint16 ,
86+ torch .int32 ,
6387 torch .float32 ,
6488 ]
6589 ),
You can’t perform that action at this time.
0 commit comments