Skip to content

Commit c1b7ec5

Browse files
authored
limited facto tensor size to be less than 4000 bytes
Differential Revision: D82476715 Pull Request resolved: #14313
1 parent ed482bd commit c1b7ec5

File tree

1 file changed

+8
-5
lines changed

1 file changed

+8
-5
lines changed

backends/cadence/utils/facto_util.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@
2323

2424

2525
def apply_tensor_contraints(op_name: str, index: int) -> list[object]:
26+
# Constraint to limit tensor size product to < 4000
27+
max_size_constraint = cp.Size.Le(lambda deps, r, d: max(1, int((3999) ** (1 / r))))
28+
2629
tensor_constraints = (
2730
[
2831
cp.Dtype.In(
@@ -39,7 +42,7 @@ def apply_tensor_contraints(op_name: str, index: int) -> list[object]:
3942
cp.Value.Le(lambda deps, dtype, struct: 2**4),
4043
cp.Rank.Ge(lambda deps: 1),
4144
cp.Size.Ge(lambda deps, r, d: 1),
42-
cp.Size.Le(lambda deps, r, d: 2**9),
45+
max_size_constraint,
4346
cp.Rank.Le(lambda deps: 2**3),
4447
]
4548
if op_name
@@ -62,7 +65,7 @@ def apply_tensor_contraints(op_name: str, index: int) -> list[object]:
6265
cp.Value.Le(lambda deps, dtype, struct: 2**4),
6366
cp.Rank.Ge(lambda deps: 1),
6467
cp.Size.Ge(lambda deps, r, d: 1),
65-
cp.Size.Le(lambda deps, r, d: 2**9),
68+
max_size_constraint,
6669
cp.Rank.Le(lambda deps: 2**3),
6770
]
6871
)
@@ -76,7 +79,7 @@ def apply_tensor_contraints(op_name: str, index: int) -> list[object]:
7679
cp.Value.Le(lambda deps, dtype, struct: 2**4),
7780
cp.Rank.Ge(lambda deps: 1),
7881
cp.Size.Ge(lambda deps, r, d: 1),
79-
cp.Size.Le(lambda deps, r, d: 2**9),
82+
max_size_constraint,
8083
]
8184
else:
8285
tensor_constraints = [
@@ -94,7 +97,7 @@ def apply_tensor_contraints(op_name: str, index: int) -> list[object]:
9497
cp.Value.Le(lambda deps, dtype, struct: 2**4),
9598
cp.Rank.Ge(lambda deps: 1),
9699
cp.Size.Ge(lambda deps, r, d: 1),
97-
cp.Size.Le(lambda deps, r, d: 2**9),
100+
max_size_constraint,
98101
]
99102
case "embedding.default":
100103
tensor_constraints = [
@@ -104,7 +107,7 @@ def apply_tensor_contraints(op_name: str, index: int) -> list[object]:
104107
cp.Value.Le(lambda deps, dtype, struct: 2**4),
105108
cp.Rank.Ge(lambda deps: 1),
106109
cp.Size.Ge(lambda deps, r, d: 1),
107-
cp.Size.Le(lambda deps, r, d: 2**9),
110+
max_size_constraint,
108111
]
109112
case "sigmoid.default":
110113
tensor_constraints.extend(

0 commit comments

Comments
 (0)