2323
2424
2525def apply_tensor_contraints (op_name : str , index : int ) -> list [object ]:
26+ # Constraint to limit tensor size product to < 4000
27+ max_size_constraint = cp .Size .Le (lambda deps , r , d : max (1 , int ((3999 ) ** (1 / r ))))
28+
2629 tensor_constraints = (
2730 [
2831 cp .Dtype .In (
@@ -39,7 +42,7 @@ def apply_tensor_contraints(op_name: str, index: int) -> list[object]:
3942 cp .Value .Le (lambda deps , dtype , struct : 2 ** 4 ),
4043 cp .Rank .Ge (lambda deps : 1 ),
4144 cp .Size .Ge (lambda deps , r , d : 1 ),
42- cp . Size . Le ( lambda deps , r , d : 2 ** 9 ) ,
45+ max_size_constraint ,
4346 cp .Rank .Le (lambda deps : 2 ** 3 ),
4447 ]
4548 if op_name
@@ -62,7 +65,7 @@ def apply_tensor_contraints(op_name: str, index: int) -> list[object]:
6265 cp .Value .Le (lambda deps , dtype , struct : 2 ** 4 ),
6366 cp .Rank .Ge (lambda deps : 1 ),
6467 cp .Size .Ge (lambda deps , r , d : 1 ),
65- cp . Size . Le ( lambda deps , r , d : 2 ** 9 ) ,
68+ max_size_constraint ,
6669 cp .Rank .Le (lambda deps : 2 ** 3 ),
6770 ]
6871 )
@@ -76,7 +79,7 @@ def apply_tensor_contraints(op_name: str, index: int) -> list[object]:
7679 cp .Value .Le (lambda deps , dtype , struct : 2 ** 4 ),
7780 cp .Rank .Ge (lambda deps : 1 ),
7881 cp .Size .Ge (lambda deps , r , d : 1 ),
79- cp . Size . Le ( lambda deps , r , d : 2 ** 9 ) ,
82+ max_size_constraint ,
8083 ]
8184 else :
8285 tensor_constraints = [
@@ -94,7 +97,7 @@ def apply_tensor_contraints(op_name: str, index: int) -> list[object]:
9497 cp .Value .Le (lambda deps , dtype , struct : 2 ** 4 ),
9598 cp .Rank .Ge (lambda deps : 1 ),
9699 cp .Size .Ge (lambda deps , r , d : 1 ),
97- cp . Size . Le ( lambda deps , r , d : 2 ** 9 ) ,
100+ max_size_constraint ,
98101 ]
99102 case "embedding.default" :
100103 tensor_constraints = [
@@ -104,7 +107,7 @@ def apply_tensor_contraints(op_name: str, index: int) -> list[object]:
104107 cp .Value .Le (lambda deps , dtype , struct : 2 ** 4 ),
105108 cp .Rank .Ge (lambda deps : 1 ),
106109 cp .Size .Ge (lambda deps , r , d : 1 ),
107- cp . Size . Le ( lambda deps , r , d : 2 ** 9 ) ,
110+ max_size_constraint ,
108111 ]
109112 case "sigmoid.default" :
110113 tensor_constraints .extend (
0 commit comments