Skip to content

Commit 32cfabf

Browse files
authored
[Dispatch] Fix error in FuseMultiUseElementwiseProducerPass (#19977)
Changes the logic that finds a fusable consumer so that it does not fuse when the producer has another use inside the consumer's body. Closes #19947. Signed-off-by: Ian Wood <[email protected]>
1 parent ecfe2b0 commit 32cfabf

File tree

2 files changed

+46
-1
lines changed

2 files changed

+46
-1
lines changed

compiler/src/iree/compiler/DispatchCreation/FuseMultiUseElementwiseProducer.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,9 @@ static std::optional<OpOperand *> getFusableUse(Operation *op,
5454
bool dominatesAllUsers = true;
5555
for (OpOperand &target : uses) {
5656
Operation *targetOp = target.getOwner();
57-
if (!dominanceInfo.dominates(sourceOp, targetOp)) {
57+
if (sourceOp != targetOp &&
58+
!dominanceInfo.properlyDominates(sourceOp, targetOp,
59+
/*enclosingOpOk=*/false)) {
5860
dominatesAllUsers = false;
5961
break;
6062
}

compiler/src/iree/compiler/DispatchCreation/test/fuse_multiuse_elementwise_producer.mlir

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,3 +139,46 @@ util.func public @math_sin() {
139139
// CHECK: %[[GENERIC:.+]]:2 = linalg.generic
140140
// CHECK-DAG: check.expect_almost_eq(%[[GENERIC]]#0,
141141
// CHECK-DAG: check.expect_almost_eq(%[[GENERIC]]#1,
142+
143+
// -----
144+
145+
util.func public @use_in_generic(%arg0 : tensor<1x20x128x2x8xf32>) -> tensor<1x20x128x2x8xf32> {
146+
%cst = arith.constant dense_resource<__elided__> : tensor<128x2x8xf32>
147+
%cst_0 = arith.constant dense_resource<__elided__> : tensor<128x2x8xf32>
148+
%cst_1 = arith.constant 2.500000e-01 : f32
149+
%c0 = arith.constant 0 : index
150+
%c1 = arith.constant 1 : index
151+
%1 = tensor.empty() : tensor<1x20x128x2x8xf32>
152+
%2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%arg0: tensor<1x20x128x2x8xf32>) outs(%1 : tensor<1x20x128x2x8xf32>) {
153+
^bb0(%in: f32, %out: f32):
154+
%6 = arith.mulf %in, %cst_1 : f32
155+
linalg.yield %6 : f32
156+
} -> tensor<1x20x128x2x8xf32>
157+
%3 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d4)>, affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%2, %cst_0, %cst : tensor<1x20x128x2x8xf32>, tensor<128x2x8xf32>, tensor<128x2x8xf32>) outs(%1 : tensor<1x20x128x2x8xf32>) {
158+
^bb0(%in: f32, %in_2: f32, %in_3: f32, %out: f32):
159+
%6 = linalg.index 0 : index
160+
%7 = linalg.index 1 : index
161+
%8 = linalg.index 2 : index
162+
%9 = linalg.index 3 : index
163+
%10 = linalg.index 4 : index
164+
%11 = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 20)>()[%7, %6]
165+
%12 = arith.subi %c1, %9 : index
166+
%extracted = tensor.extract %2[%c0, %11, %8, %12, %10] : tensor<1x20x128x2x8xf32>
167+
%13 = arith.negf %extracted : f32
168+
%14 = arith.cmpi eq, %12, %c1 : index
169+
%15 = arith.select %14, %13, %extracted : f32
170+
%16 = arith.mulf %15, %in_3 : f32
171+
%17 = arith.mulf %in, %in_2 : f32
172+
%18 = arith.addf %17, %16 : f32
173+
linalg.yield %18 : f32
174+
} -> tensor<1x20x128x2x8xf32>
175+
util.return %3 : tensor<1x20x128x2x8xf32>
176+
}
177+
178+
// These cannot be fused because %2 is both an operand of %3 and read inside %3's body (via tensor.extract).
179+
//
180+
// CHECK-LABEL: util.func public @use_in_generic(
181+
// CHECK: %[[GENERIC0:.+]] = linalg.generic
182+
// CHECK: %[[GENERIC1:.+]] = linalg.generic
183+
// CHECK-SAME: ins(%[[GENERIC0]]
184+
// CHECK: util.return %[[GENERIC1]]

0 commit comments

Comments
 (0)