
Commit e0b184c

[mlir][DispatchCreation] Avoid SSA violation due to consumer fusion while forming dispatches (iree-org#21186)
Consumer fusion would create illegal dispatches in the presence of diamond fusion patterns when one of the operations isn't fused into the dispatch. For example:

```
%producer = ...
%0 = "non_fused_op"(%producer)
%1 = "fused_op"(%producer, %0)
```

Moving `"fused_op"` into the same dispatch as `%producer` is an SSA violation: `"non_fused_op"` stays outside the dispatch yet both consumes a value produced inside it and feeds `"fused_op"` inside it, so a result of the dispatch would have to be available before the dispatch finishes executing. Avoid this fusion.

Fixes iree-org#21176

---------

Signed-off-by: MaheshRavishankar <[email protected]>
Signed-off-by: Ian Wood <[email protected]>
Co-authored-by: Ian Wood <[email protected]>
1 parent 1f1167d commit e0b184c

File tree

4 files changed: +83 -5 lines changed

compiler/src/iree/compiler/DispatchCreation/FoldUnitExtentDims.cpp

Lines changed: 8 additions & 4 deletions
```diff
@@ -259,8 +259,10 @@ void FoldUnitExtentDimsPass::runOnOperation() {
 
   RewritePatternSet foldUnitDimsPatterns(context);
   populatefoldUnitDimsPatterns(foldUnitDimsPatterns);
-  if (failed(
-          applyPatternsGreedily(moduleOp, std::move(foldUnitDimsPatterns)))) {
+  GreedyRewriteConfig rewriterConfig;
+  rewriterConfig.setMaxIterations(GreedyRewriteConfig::kNoLimit);
+  if (failed(applyPatternsGreedily(moduleOp, std::move(foldUnitDimsPatterns),
+                                   rewriterConfig))) {
     return signalPassFailure();
   }
 }
@@ -269,8 +271,10 @@ void FoldUnitExtentDimsForFuncPass::runOnOperation() {
   MLIRContext *context = &getContext();
   RewritePatternSet foldUnitDimsPatterns(context);
   populatefoldUnitDimsPatterns(foldUnitDimsPatterns);
-  if (failed(applyPatternsGreedily(getOperation(),
-                                   std::move(foldUnitDimsPatterns)))) {
+  GreedyRewriteConfig rewriterConfig;
+  rewriterConfig.setMaxIterations(GreedyRewriteConfig::kNoLimit);
+  if (failed(applyPatternsGreedily(
+          getOperation(), std::move(foldUnitDimsPatterns), rewriterConfig))) {
     return signalPassFailure();
   }
 }
```
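For context on this hunk, a minimal sketch of the driver setup it switches to, assuming only the MLIR greedy pattern rewrite driver API already visible in the diff (`applyUntilFixedPoint` is a hypothetical helper, not part of this commit). `GreedyRewriteConfig::kNoLimit` removes the default cap on rewrite iterations, so the driver keeps applying patterns until a true fixed point instead of reporting failure when it hits the iteration limit.

```cpp
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Hypothetical helper mirroring the setup in the hunk above: run the greedy
// driver with no iteration cap so it only stops once no pattern applies.
static LogicalResult applyUntilFixedPoint(Operation *op,
                                          RewritePatternSet &&patterns) {
  GreedyRewriteConfig config;
  config.setMaxIterations(GreedyRewriteConfig::kNoLimit);
  return applyPatternsGreedily(op, std::move(patterns), config);
}
```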

compiler/src/iree/compiler/DispatchCreation/FormDispatchRegions.cpp

Lines changed: 1 addition & 1 deletion
```diff
@@ -994,7 +994,7 @@ createFusionGroups(TensorDimTrackingRewriter &rewriter,
   }
 
   if (failed(moveOperandDefs(rewriter, consumer, regionOp, dominanceInfo,
-                             regionOp.getOperation()))) {
+                             {}))) {
     continue;
   }
 
```
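Reading of this one-line change: the final argument to `moveOperandDefs` appears to be the set of operations to ignore while walking backward slices. Previously the dispatch region op itself was ignored, so a consumer operand that transitively depended on the region's results never showed up in the slice; with `{}` the slice can now reach `regionOp`, and the new containment check added in `FusionUtils.cpp` below fails, so this fusion candidate is skipped via the `continue`.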
compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp

Lines changed: 12 additions & 0 deletions
```diff
@@ -120,6 +120,10 @@ LogicalResult moveOperandDefs(RewriterBase &rewriter,
   llvm::SetVector<Operation *> slice;
   for (auto op : operations) {
     for (auto operand : op->getOperands()) {
+      // If operand is the insertion point, there is nothing to move.
+      if (operand.getDefiningOp() == insertionPoint) {
+        continue;
+      }
       [[maybe_unused]] LogicalResult result =
           getBackwardSlice(operand, &slice, options);
       assert(result.succeeded());
@@ -131,12 +135,20 @@ LogicalResult moveOperandDefs(RewriterBase &rewriter,
     llvm::SetVector<Value> capturedVals;
     mlir::getUsedValuesDefinedAbove(regions, capturedVals);
     for (auto value : capturedVals) {
+      // If operand is the insertion point, there is nothing to move.
+      if (value.getDefiningOp() == insertionPoint) {
+        continue;
+      }
       [[maybe_unused]] LogicalResult result =
           getBackwardSlice(value, &slice, options);
       assert(result.succeeded());
     }
   }
 
+  if (slice.contains(insertionPoint)) {
+    return failure();
+  }
+
   mlir::topologicalSort(slice);
   for (auto op : slice) {
     rewriter.moveOpBefore(op, insertionPoint);
```
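To see the new guard in isolation, here is a hedged sketch of `moveOperandDefs` reduced to its essentials; `computeBackwardSlice` is a hypothetical stand-in for the `getBackwardSlice` calls above, and the region/captured-value handling is elided. The key point is the containment check: if the backward slice of the operands reaches the insertion point, some operand definition transitively consumes a result of the op being fused into, and moving it above that op would use a value before it is defined, which is exactly the diamond from the commit message.

```cpp
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SetVector.h"

using namespace mlir;

// Hypothetical stand-in for the getBackwardSlice() calls in the hunk above.
void computeBackwardSlice(Value value, llvm::SetVector<Operation *> &slice);

// Sketch only; simplified from the hunk above, not the actual implementation.
LogicalResult moveOperandDefsSketch(RewriterBase &rewriter,
                                    ArrayRef<Operation *> operations,
                                    Operation *insertionPoint) {
  llvm::SetVector<Operation *> slice;
  for (Operation *op : operations) {
    for (Value operand : op->getOperands()) {
      // Values produced by the insertion point itself do not need to move.
      if (operand.getDefiningOp() == insertionPoint)
        continue;
      computeBackwardSlice(operand, slice);
    }
  }
  // New guard: the slice reaching the insertion point means some operand def
  // depends on the op we are fusing into, so moving it would break SSA.
  if (slice.contains(insertionPoint))
    return failure();
  // Otherwise the moves are safe once performed in topological order.
  for (Operation *op : mlir::topologicalSort(slice))
    rewriter.moveOpBefore(op, insertionPoint);
  return success();
}
```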

compiler/src/iree/compiler/DispatchCreation/test/form_dispatch_regions.mlir

Lines changed: 62 additions & 0 deletions
```diff
@@ -1331,3 +1331,65 @@ util.func @attention_rope_fusion(%arg0: tensor<10x20x30x50xbf16>,
 // CHECK-SAME:     ins(%[[Q]], %[[K]], %[[V]]
 // CHECK:        flow.return %[[ATTENTION]]
 // CHECK:      util.return %[[DISPATCH]]
+
+// -----
+
+// Avoid fusing a consumer when the producer/consumer has the following structure:
+//
+// ```mlir
+// %producer = "producer_op"
+// %root = "root_op"(%producer)
+// %0 = "non_fusable_op"(%producer)
+// %1 = "consumer_op"(%producer, %root, %0)
+// ```
+//
+// Moving `"producer_op"`, `"root_op"`, and `"consumer_op"` into a dispatch
+// and leaving `"non_fusable_op"` out would lead to an SSA violation.
+util.func public @avoid_illegal_consumer_fusion(%arg0: tensor<75600x5120xf32>) -> tensor<75600x1x5120xbf16> {
+  %cst0 = arith.constant 0.0 : bf16
+  %0 = tensor.empty() : tensor<75600x5120xbf16>
+  %1 = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
+      iterator_types = ["parallel", "parallel"]}
+      ins(%arg0 : tensor<75600x5120xf32>) outs(%0 : tensor<75600x5120xbf16>) {
+  ^bb0(%in: f32, %out: bf16):
+    %13 = arith.truncf %in : f32 to bf16
+    linalg.yield %13 : bf16
+  } -> tensor<75600x5120xbf16>
+  %2 = tensor.empty() : tensor<75600xbf16>
+  %3 = linalg.fill ins(%cst0 : bf16) outs(%2 : tensor<75600xbf16>) -> tensor<75600xbf16>
+  %4 = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
+      iterator_types = ["parallel", "reduction"]}
+      ins(%1 : tensor<75600x5120xbf16>) outs(%3 : tensor<75600xbf16>) {
+  ^bb0(%in: bf16, %out: bf16):
+    %8 = arith.addf %in, %out : bf16
+    linalg.yield %8 : bf16
+  } -> tensor<75600xbf16>
+  %expanded = tensor.expand_shape %1 [[0], [1, 2]] output_shape [75600, 1, 5120]
+      : tensor<75600x5120xbf16> into tensor<75600x1x5120xbf16>
+  %5 = tensor.empty() : tensor<75600x1x5120xbf16>
+  %6 = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+                       affine_map<(d0, d1, d2) -> (d0)>,
+                       affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
+      iterator_types = ["parallel", "parallel", "parallel"]}
+      ins(%expanded, %4 : tensor<75600x1x5120xbf16>, tensor<75600xbf16>)
+      outs(%5 : tensor<75600x1x5120xbf16>) {
+  ^bb0(%in: bf16, %in_0: bf16, %out: bf16):
+    %9 = arith.subf %in, %in_0 : bf16
+    linalg.yield %9 : bf16
+  } -> tensor<75600x1x5120xbf16>
+  util.return %6 : tensor<75600x1x5120xbf16>
+}
+// CHECK-LABEL: @avoid_illegal_consumer_fusion(
+//       CHECK:   %[[DISPATCH:.+]]:2 = flow.dispatch.region
+//       CHECK:     %[[GENERIC0:.+]] = linalg.generic
+//       CHECK:     %[[GENERIC1:.+]] = linalg.generic
+//  CHECK-SAME:         ins(%[[GENERIC0]] :
+//       CHECK:     flow.return %[[GENERIC1]], %[[GENERIC0]]
+//       CHECK:   %[[EXPAND_SHAPE:.+]] = tensor.expand_shape %[[DISPATCH]]#1
+//       CHECK:   %[[GENERIC2:.+]] = linalg.generic
+//  CHECK-SAME:       ins(%[[EXPAND_SHAPE]], %[[DISPATCH]]#0 :
+//       CHECK:   util.return %[[GENERIC2]]
```
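Taken together, the CHECK lines pin down the intended outcome: the truncation and reduction generics are fused into a single `flow.dispatch.region` that now returns both of their results, while the `tensor.expand_shape` and the final elementwise generic (the would-be consumer) stay outside the dispatch and read the region's results, so no use of a value precedes its definition.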
