Skip to content

Commit 2261848

Browse files
authored
fix facto contraints to avoid char,byte in full
Differential Revision: D81705758 Pull Request resolved: #13966
1 parent 364f493 commit 2261848

File tree

1 file changed

+47
-17
lines changed

1 file changed

+47
-17
lines changed

backends/cadence/utils/facto_util.py

Lines changed: 47 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -23,23 +23,49 @@
2323

2424

2525
def apply_tensor_contraints(op_name: str, index: int) -> list[object]:
26-
tensor_constraints = [
27-
cp.Dtype.In(
28-
lambda deps: [
29-
torch.int8,
30-
torch.int16,
31-
torch.uint8,
32-
torch.uint16,
33-
torch.float32,
34-
]
35-
),
36-
cp.Value.Ge(lambda deps, dtype, struct: -(2**4)),
37-
cp.Value.Le(lambda deps, dtype, struct: 2**4),
38-
cp.Rank.Ge(lambda deps: 1),
39-
cp.Size.Ge(lambda deps, r, d: 1),
40-
cp.Size.Le(lambda deps, r, d: 2**9),
41-
cp.Rank.Le(lambda deps: 2**3),
42-
]
26+
tensor_constraints = (
27+
[
28+
cp.Dtype.In(
29+
lambda deps: [
30+
torch.int8,
31+
torch.int16,
32+
torch.uint8,
33+
torch.uint16,
34+
torch.int32,
35+
torch.float32,
36+
]
37+
),
38+
cp.Value.Ge(lambda deps, dtype, struct: -(2**4)),
39+
cp.Value.Le(lambda deps, dtype, struct: 2**4),
40+
cp.Rank.Ge(lambda deps: 1),
41+
cp.Size.Ge(lambda deps, r, d: 1),
42+
cp.Size.Le(lambda deps, r, d: 2**9),
43+
cp.Rank.Le(lambda deps: 2**3),
44+
]
45+
if op_name
46+
not in (
47+
"slice_copy.Tensor",
48+
"add.Scalar",
49+
"sub.Scalar",
50+
"mul.Scalar",
51+
"div.Tensor",
52+
"neg.default",
53+
)
54+
else [
55+
cp.Dtype.In(
56+
lambda deps: [
57+
torch.int32,
58+
torch.float32,
59+
]
60+
),
61+
cp.Value.Ge(lambda deps, dtype, struct: -(2**4)),
62+
cp.Value.Le(lambda deps, dtype, struct: 2**4),
63+
cp.Rank.Ge(lambda deps: 1),
64+
cp.Size.Ge(lambda deps, r, d: 1),
65+
cp.Size.Le(lambda deps, r, d: 2**9),
66+
cp.Rank.Le(lambda deps: 2**3),
67+
]
68+
)
4369

4470
match op_name:
4571
case "where.self":
@@ -60,6 +86,7 @@ def apply_tensor_contraints(op_name: str, index: int) -> list[object]:
6086
torch.int16,
6187
torch.uint8,
6288
torch.uint16,
89+
torch.int32,
6390
torch.float32,
6491
]
6592
),
@@ -143,6 +170,9 @@ def apply_tensor_contraints(op_name: str, index: int) -> list[object]:
143170
tensor_constraints.extend(
144171
[
145172
cp.Value.Ne(lambda deps, dtype, struct: 0),
173+
cp.Value.Le(lambda deps, dtype, struct: 2**3),
174+
cp.Size.Le(lambda deps, r, d: 2**3),
175+
cp.Rank.Le(lambda deps: 2**2),
146176
]
147177
)
148178
case "div.Tensor_mode" | "minimum.default":

0 commit comments

Comments
 (0)