Skip to content

Commit e1ea74f

Browse files
authored
Enforce tensor a dtype == tensor b dtype for where.out in facto
Differential Revision: D82577515 Pull Request resolved: #14352
1 parent ed179c0 commit e1ea74f

File tree

1 file changed

+20
-1
lines changed

1 file changed

+20
-1
lines changed

backends/cadence/utils/facto_util.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,25 @@ def random_size_constraint(deps: object, r: int, d: int) -> int:
167167
cp.Size.Ge(lambda deps, r, d: 1),
168168
max_size_constraint,
169169
]
170-
else:
170+
elif index == 1: # input tensor(a)
171+
tensor_constraints = [
172+
cp.Dtype.In(
173+
lambda deps: [
174+
torch.int8,
175+
torch.int16,
176+
torch.uint8,
177+
torch.uint16,
178+
torch.int32,
179+
torch.float32,
180+
]
181+
),
182+
cp.Value.Ge(lambda deps, dtype, struct: -(2**4)),
183+
cp.Value.Le(lambda deps, dtype, struct: 2**4),
184+
cp.Rank.Ge(lambda deps: 1),
185+
cp.Size.Ge(lambda deps, r, d: 1),
186+
max_size_constraint,
187+
]
188+
else: # input tensor(b)
171189
tensor_constraints = [
172190
cp.Dtype.In(
173191
lambda deps: [
@@ -179,6 +197,7 @@ def random_size_constraint(deps: object, r: int, d: int) -> int:
179197
torch.float32,
180198
]
181199
),
200+
cp.Dtype.Eq(lambda deps: deps[1].dtype),
182201
cp.Value.Ge(lambda deps, dtype, struct: -(2**4)),
183202
cp.Value.Le(lambda deps, dtype, struct: 2**4),
184203
cp.Rank.Ge(lambda deps: 1),

0 commit comments

Comments
 (0)