Skip to content

Commit 8806173

Browse files
authored
Revert "[DispatchCreation] Extend multi-use producer fusion" (#18917)
The reverted commit does not handle when the "consumer" uses a value defined above. See #18879 for the original issue. This is causing issue with ~15 onnx models. I have a PR (#18855) to fix this by including values used in an ops region in the backwards slice, but It is waiting on upstream changes to `getBackwardSlice`. Currently, the PR is using a wrapper around `getBackwardSlice` to acheive the same effect, but this will be updated once the upstream change lands (llvm/llvm-project#113478) Reverts #18551 --------- Signed-off-by: Ian Wood <[email protected]>
1 parent f8b8414 commit 8806173

File tree

6 files changed

+74
-169
lines changed

6 files changed

+74
-169
lines changed

.github/workflows/pkgci_regression_test.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ jobs:
220220
--goldentime-rocm-unet-ms 419.0 \
221221
--goldentime-rocm-clip-ms 18.5 \
222222
--goldentime-rocm-vae-ms 337.0 \
223-
--goldendispatch-rocm-unet 1527 \
223+
--goldendispatch-rocm-unet 1531 \
224224
--goldendispatch-rocm-clip 1139 \
225225
--goldendispatch-rocm-vae 247 \
226226
--goldensize-rocm-unet-bytes 2280000 \
@@ -241,7 +241,7 @@ jobs:
241241
--goldentime-rocm-unet-ms 95.0 \
242242
--goldentime-rocm-clip-ms 15.5 \
243243
--goldentime-rocm-vae-ms 80.0 \
244-
--goldendispatch-rocm-unet 1527 \
244+
--goldendispatch-rocm-unet 1531 \
245245
--goldendispatch-rocm-clip 1139 \
246246
--goldendispatch-rocm-vae 247 \
247247
--goldensize-rocm-unet-bytes 2270000 \

compiler/src/iree/compiler/DispatchCreation/FuseHorizontalContractions.cpp

Lines changed: 56 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
88
#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
99
#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
10-
#include "iree/compiler/DispatchCreation/FusionUtils.h"
1110
#include "iree/compiler/DispatchCreation/Passes.h"
1211
#include "mlir/Analysis/SliceAnalysis.h"
1312
#include "mlir/Analysis/TopologicalSortUtils.h"
@@ -108,6 +107,25 @@ static bool isEmptyFillContractionDAGRootOp(
108107
return true;
109108
}
110109

110+
/// Check that a given operation is "horizontal" to the group. The operation
111+
/// is horizontal if the `slice` of the operation does not contain any op
112+
/// from the group.
113+
static bool isHorizontalToGroup(Operation *op,
114+
const llvm::SetVector<Operation *> &currGroup,
115+
const DominanceInfo &dominanceInfo,
116+
Operation *seedOp) {
117+
BackwardSliceOptions options;
118+
// Limit the slice to the seed to make sure the slice is small.
119+
options.filter = [&](Operation *op) {
120+
return !dominanceInfo.properlyDominates(op, seedOp);
121+
};
122+
llvm::SetVector<Operation *> slice;
123+
getBackwardSlice(op, &slice, options);
124+
return !llvm::any_of(currGroup, [&](Operation *groupedOp) {
125+
return slice.contains(groupedOp);
126+
});
127+
}
128+
111129
/// Get user of operation that is a truncate operation.
112130
static std::optional<linalg::GenericOp>
113131
getTruncateOp(Operation *op,
@@ -131,8 +149,8 @@ getTruncateOp(Operation *op,
131149
if (!checkOperationEquivalence(genericOp, seedTruncateOp.value())) {
132150
return std::nullopt;
133151
}
134-
if (!isHorizontalToGroup(genericOp, groupedOperations.getArrayRef(),
135-
dominanceInfo, seedTruncateOp.value())) {
152+
if (!isHorizontalToGroup(genericOp, groupedOperations, dominanceInfo,
153+
seedTruncateOp.value())) {
136154
return std::nullopt;
137155
}
138156
}
@@ -208,8 +226,7 @@ static std::optional<HorizontalFusionGroup> getHorizontalFusionGroupMembers(
208226
if (!dominanceInfo.properlyDominates(seedOp, linalgOp)) {
209227
return false;
210228
}
211-
if (!isHorizontalToGroup(linalgOp, allOps.getArrayRef(), dominanceInfo,
212-
seedOp)) {
229+
if (!isHorizontalToGroup(linalgOp, allOps, dominanceInfo, seedOp)) {
213230
return false;
214231
}
215232
return true;
@@ -329,6 +346,40 @@ static AffineMap getConcatenatedIndexingMap(RewriterBase &rewriter,
329346
return newIndexingMap.insertResult(rewriter.getAffineDimExpr(0), 0);
330347
}
331348

349+
/// During horizontal fusion, there might be operands of the fused operations
350+
/// whose definitions are interspersed between the fused operations. For groups
351+
/// chosen to fuse horizontally, such operations can be moved before the
352+
/// seed contraction operation (where the fused operation is generated).
353+
template <typename T>
354+
static LogicalResult
355+
moveOperandDefs(RewriterBase &rewriter, ArrayRef<T> operations,
356+
Operation *insertionPoint, DominanceInfo &dominanceInfo,
357+
ArrayRef<linalg::LinalgOp> ignoreOperations = {}) {
358+
BackwardSliceOptions options;
359+
llvm::DenseSet<Operation *> ignoreOperationsSet;
360+
ignoreOperationsSet.insert(ignoreOperations.begin(), ignoreOperations.end());
361+
options.filter = [&](Operation *op) {
362+
return !dominanceInfo.properlyDominates(op, insertionPoint) &&
363+
!ignoreOperationsSet.contains(op);
364+
};
365+
// Set inclusive to true cause the slice is computed from the operand, and
366+
// we want to include the defining op (which is the point here)
367+
options.inclusive = true;
368+
369+
llvm::SetVector<Operation *> slice;
370+
for (auto op : operations) {
371+
for (auto operand : op->getOperands()) {
372+
getBackwardSlice(operand, &slice, options);
373+
}
374+
}
375+
376+
mlir::topologicalSort(slice);
377+
for (auto op : slice) {
378+
rewriter.moveOpBefore(op, insertionPoint);
379+
}
380+
return success();
381+
}
382+
332383
/// On finding this pattern
333384
/// ```
334385
/// %0 = linalg.matmul ins(%arg0, %arg1)

compiler/src/iree/compiler/DispatchCreation/FuseMultiUseElementwiseProducer.cpp

Lines changed: 16 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,9 @@
1616
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
1717
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
1818
#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
19-
#include "iree/compiler/DispatchCreation/FusionUtils.h"
2019
#include "iree/compiler/DispatchCreation/Passes.h"
21-
#include "llvm/ADT/ArrayRef.h"
22-
#include "llvm/ADT/STLExtras.h"
2320
#include "llvm/Support/CommandLine.h"
2421
#include "llvm/Support/Debug.h"
25-
#include "mlir/Analysis/SliceAnalysis.h"
2622
#include "mlir/Analysis/TopologicalSortUtils.h"
2723
#include "mlir/Dialect/Affine/IR/AffineOps.h"
2824
#include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -49,55 +45,25 @@ static llvm::cl::opt<int64_t> clLinalgMaxConstantFoldElements(
4945
llvm::cl::desc("Maximum number of elements to try to constant fold."),
5046
llvm::cl::init(0));
5147

52-
static Operation *getMostDominantUse(Operation *op,
53-
const DominanceInfo &dominanceInfo) {
54-
auto uses = op->getUses();
55-
auto it = llvm::find_if(uses, [&](OpOperand &source) {
56-
Operation *sourceOp = source.getOwner();
57-
58-
return llvm::all_of(uses, [&](OpOperand &target) {
59-
Operation *targetOp = target.getOwner();
60-
return dominanceInfo.dominates(sourceOp, targetOp);
61-
});
62-
});
63-
if (it != uses.end()) {
64-
return it->getOwner();
65-
}
66-
return nullptr;
67-
}
68-
6948
/// Check if any of the use dominates all other uses of the operation.
70-
static Operation *getFusableUse(Operation *op,
71-
const DominanceInfo &dominanceInfo) {
49+
static std::optional<OpOperand *> getFusableUse(Operation *op,
50+
DominanceInfo &dominanceInfo) {
7251
auto uses = op->getUses();
73-
Operation *fusableUse = nullptr;
7452
for (OpOperand &source : uses) {
7553
Operation *sourceOp = source.getOwner();
76-
77-
bool dominatesAllFusableOps = llvm::all_of(uses, [&](OpOperand &target) {
54+
bool dominatesAllUsers = true;
55+
for (OpOperand &target : uses) {
7856
Operation *targetOp = target.getOwner();
79-
return !isa<linalg::GenericOp>(targetOp) ||
80-
dominanceInfo.dominates(sourceOp, targetOp);
81-
});
82-
if (dominatesAllFusableOps) {
83-
fusableUse = sourceOp;
84-
break;
57+
if (!dominanceInfo.dominates(sourceOp, targetOp)) {
58+
dominatesAllUsers = false;
59+
break;
60+
}
61+
}
62+
if (dominatesAllUsers) {
63+
return &source;
8564
}
8665
}
87-
Operation *mostDominantOp = getMostDominantUse(op, dominanceInfo);
88-
if (!fusableUse || !mostDominantOp) {
89-
return nullptr;
90-
}
91-
92-
// If `fusableUse` dominates all other users, there's nothing else to do.
93-
if (fusableUse == mostDominantOp) {
94-
return fusableUse;
95-
}
96-
97-
SmallVector<Operation *> users(op->getUsers().begin(), op->getUsers().end());
98-
return isHorizontalToGroup(fusableUse, users, dominanceInfo, mostDominantOp)
99-
? fusableUse
100-
: nullptr;
66+
return std::nullopt;
10167
}
10268

10369
static OpOperand *getFirstUseInConsumer(Operation *producer,
@@ -125,7 +91,6 @@ static SmallVector<OpOperand *> getAllUsesInConsumer(Operation *producer,
12591
/// using elementwise fusion.
12692
static LogicalResult doMultiUseFusion(Operation *rootOp,
12793
llvm::SetVector<Operation *> &fusableOps,
128-
const DominanceInfo &dominanceInfo,
12994
RewriterBase &rewriter) {
13095
assert(rootOp && "root op cant be null");
13196

@@ -147,20 +112,11 @@ static LogicalResult doMultiUseFusion(Operation *rootOp,
147112
Operation *consumerOp = rootOp;
148113
OpBuilder::InsertionGuard g(rewriter);
149114
for (Operation *producerOp : llvm::reverse(fusedOpsVec)) {
150-
Operation *mostDominantUser = getMostDominantUse(producerOp, dominanceInfo);
151115
// Fuse all uses from producer -> consumer. It has been checked
152116
// before that all uses are fusable.
153117
while (OpOperand *fusedOperand =
154118
getFirstUseInConsumer(producerOp, consumerOp)) {
155119
rewriter.setInsertionPoint(consumerOp);
156-
157-
if (consumerOp != mostDominantUser &&
158-
failed(moveOperandDefs(rewriter, ArrayRef<Operation *>{consumerOp},
159-
mostDominantUser, dominanceInfo))) {
160-
return rewriter.notifyMatchFailure(consumerOp,
161-
"failed to move operand defs");
162-
}
163-
rewriter.moveOpBefore(consumerOp, mostDominantUser);
164120
FailureOr<linalg::ElementwiseOpFusionResult> fusionResult =
165121
linalg::fuseElementwiseOps(rewriter, fusedOperand);
166122
if (failed(fusionResult)) {
@@ -234,8 +190,9 @@ static FailureOr<unsigned> fuseMultiUseProducers(Operation *funcOp,
234190
}
235191

236192
// 6. Check that the `genericOp` dominates all uses of `producer`.
237-
Operation *fusableUse = getFusableUse(producer, dominanceInfo);
238-
if (!fusableUse || fusableUse != genericOp) {
193+
std::optional<OpOperand *> fusableUse =
194+
getFusableUse(producer, dominanceInfo);
195+
if (!fusableUse || fusableUse.value()->getOwner() != genericOp) {
239196
continue;
240197
}
241198

@@ -275,8 +232,7 @@ static FailureOr<unsigned> fuseMultiUseProducers(Operation *funcOp,
275232

276233
IRRewriter rewriter(context);
277234
for (auto it = fusedOps.rbegin(), ie = fusedOps.rend(); it != ie; ++it) {
278-
if (failed(
279-
doMultiUseFusion(it->first, it->second, dominanceInfo, rewriter))) {
235+
if (failed(doMultiUseFusion(it->first, it->second, rewriter))) {
280236
return funcOp->emitOpError("failed multi use fusion");
281237
}
282238
}

compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp

Lines changed: 0 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,7 @@
1010
#include "compiler/src/iree/compiler/DispatchCreation/FusionUtils.h"
1111
#include "compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
1212
#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
13-
#include "mlir/Analysis/SliceAnalysis.h"
1413
#include "mlir/Dialect/Linalg/IR/Linalg.h"
15-
#include "mlir/IR/Dominance.h"
16-
#include "mlir/IR/OpDefinition.h"
17-
#include "mlir/Transforms/RegionUtils.h"
1814

1915
namespace mlir::iree_compiler::DispatchCreation {
2016

@@ -101,33 +97,4 @@ bool areFusableAsElementwiseOps(MLIRContext *context, OpOperand *fusedOperand,
10197
return true;
10298
}
10399

104-
bool isHorizontalToGroup(Operation *op, ArrayRef<Operation *> currGroup,
105-
const DominanceInfo &dominanceInfo,
106-
Operation *seedOp) {
107-
assert(dominanceInfo.properlyDominates(seedOp, op) &&
108-
op->getParentRegion() == seedOp->getParentRegion());
109-
BackwardSliceOptions options;
110-
// Limit the slice to the seed to make sure the slice is small.
111-
options.filter = [&](Operation *op) {
112-
return !dominanceInfo.properlyDominates(op, seedOp);
113-
};
114-
llvm::SetVector<Operation *> slice;
115-
getBackwardSlice(op, &slice, options);
116-
117-
// `getBackwardSlice` doesnt track uses from within an ops region, so make
118-
// sure there are no values defined above.
119-
for (Operation *sliceOp : slice) {
120-
bool usesValuesFromAbove = false;
121-
mlir::visitUsedValuesDefinedAbove(
122-
sliceOp->getRegions(), [&](void *) { usesValuesFromAbove = true; });
123-
if (usesValuesFromAbove) {
124-
return false;
125-
}
126-
}
127-
128-
return !llvm::any_of(currGroup, [&](Operation *groupedOp) {
129-
return slice.contains(groupedOp);
130-
});
131-
}
132-
133100
} // namespace mlir::iree_compiler::DispatchCreation

compiler/src/iree/compiler/DispatchCreation/FusionUtils.h

Lines changed: 0 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,6 @@
1010
//
1111
//===----------------------------------------------------------------------===//
1212

13-
#include "mlir/Analysis/SliceAnalysis.h"
14-
#include "mlir/Analysis/TopologicalSortUtils.h"
15-
#include "mlir/Dialect/Linalg/IR/Linalg.h"
16-
#include "mlir/IR/Dominance.h"
1713
#include "mlir/IR/Operation.h"
1814

1915
namespace mlir::iree_compiler::DispatchCreation {
@@ -23,44 +19,4 @@ namespace mlir::iree_compiler::DispatchCreation {
2319
bool areFusableAsElementwiseOps(MLIRContext *context, OpOperand *operand,
2420
bool fuseMultiReduction);
2521

26-
/// Check that a given operation is "horizontal" to the group. The operation
27-
/// is horizontal if the program slice of the operation (from op back to seedOp)
28-
/// does not contain any op from the group.
29-
bool isHorizontalToGroup(Operation *op, ArrayRef<Operation *> currGroup,
30-
const DominanceInfo &dominanceInfo, Operation *seedOp);
31-
32-
/// Moves the operands and transitive defs for each op in `operations` directly
33-
/// after `insertionPoint`. Note: this does not check if it is legal to move the
34-
/// operands.
35-
template <typename T>
36-
static LogicalResult
37-
moveOperandDefs(RewriterBase &rewriter, ArrayRef<T> operations,
38-
Operation *insertionPoint, const DominanceInfo &dominanceInfo,
39-
ArrayRef<linalg::LinalgOp> ignoreOperations = {}) {
40-
BackwardSliceOptions options;
41-
llvm::DenseSet<Operation *> ignoreOperationsSet;
42-
ignoreOperationsSet.insert(ignoreOperations.begin(), ignoreOperations.end());
43-
options.filter = [&](Operation *op) {
44-
return !dominanceInfo.properlyDominates(op, insertionPoint) &&
45-
!ignoreOperationsSet.contains(op);
46-
};
47-
// Set inclusive to true cause the slice is computed from the operand, and
48-
// we want to include the defining op (which is the point here)
49-
options.inclusive = true;
50-
51-
llvm::SetVector<Operation *> slice;
52-
for (auto op : operations) {
53-
assert(insertionPoint->getBlock() == op->getBlock());
54-
for (auto operand : op->getOperands()) {
55-
getBackwardSlice(operand, &slice, options);
56-
}
57-
}
58-
59-
mlir::topologicalSort(slice);
60-
for (auto op : slice) {
61-
rewriter.moveOpBefore(op, insertionPoint);
62-
}
63-
return success();
64-
}
65-
6622
} // namespace mlir::iree_compiler::DispatchCreation

compiler/src/iree/compiler/DispatchCreation/test/fuse_multiuse_elementwise_producer.mlir

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -139,28 +139,3 @@ util.func public @math_sin() {
139139
// CHECK: %[[GENERIC:.+]]:2 = linalg.generic
140140
// CHECK-DAG: check.expect_almost_eq(%[[GENERIC]]#0,
141141
// CHECK-DAG: check.expect_almost_eq(%[[GENERIC]]#1,
142-
143-
// -----
144-
145-
#map = affine_map<(d0, d1) -> (d0, d1)>
146-
util.func public @fuse_by_moving_consumer(%arg0: tensor<5x5xf32>, %arg1: tensor<5x5xf32>) -> (tensor<5x5xf32>, tensor<25xf32>) {
147-
%cst = arith.constant 1.000000e+00 : f32
148-
%cst_0 = arith.constant 2.000000e+00 : f32
149-
%cst_1 = arith.constant 3.000000e+00 : f32
150-
%4 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<5x5xf32>) outs(%arg1 : tensor<5x5xf32>) {
151-
^bb0(%arg2: f32, %arg3: f32):
152-
%8 = arith.addf %arg2, %cst : f32
153-
linalg.yield %8 : f32
154-
} -> tensor<5x5xf32>
155-
// expected-note @below {{prior use here}}
156-
%collapsed = tensor.collapse_shape %4 [[0, 1]] : tensor<5x5xf32> into tensor<25xf32>
157-
%5 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%4 : tensor<5x5xf32>) outs(%arg1 : tensor<5x5xf32>) {
158-
^bb0(%arg2: f32, %arg3: f32):
159-
%8 = arith.subf %arg2, %cst_0 : f32
160-
linalg.yield %8 : f32
161-
} -> tensor<5x5xf32>
162-
util.return %5, %collapsed: tensor<5x5xf32>, tensor<25xf32>
163-
}
164-
// CHECK-LABEL: util.func public @fuse_by_moving_consumer
165-
// CHECK: linalg.generic
166-
// CHECK-NOT: linalg.generic

0 commit comments

Comments
 (0)