Skip to content

Commit 2a5d123

Browse files
authored
Reapply "[DispatchCreation] Extend multi-use producer fusion" (#19032)
Since the upstream changes to `getBackwardSlice` have been integrated (llvm/llvm-project#114452), its now possible to reland #18855. The first commit relands the reverted changes. The second commit uses `BackwardSliceOptions::omitUsesFromAbove` to track all transitive definitions of the possibly fusible op preventing ops being moved before uses. Also, added two tests that check for this issue. Closes #18879 --------- Signed-off-by: Ian Wood <[email protected]>
1 parent 1afe2bc commit 2a5d123

File tree

6 files changed

+215
-74
lines changed

6 files changed

+215
-74
lines changed

.github/workflows/pkgci_regression_test.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ jobs:
220220
--goldentime-rocm-unet-ms 419.0 \
221221
--goldentime-rocm-clip-ms 18.5 \
222222
--goldentime-rocm-vae-ms 337.0 \
223-
--goldendispatch-rocm-unet 1531 \
223+
--goldendispatch-rocm-unet 1527 \
224224
--goldendispatch-rocm-clip 1141 \
225225
--goldendispatch-rocm-vae 246 \
226226
--goldensize-rocm-unet-bytes 2280000 \
@@ -241,7 +241,7 @@ jobs:
241241
--goldentime-rocm-unet-ms 95.0 \
242242
--goldentime-rocm-clip-ms 15.5 \
243243
--goldentime-rocm-vae-ms 80.0 \
244-
--goldendispatch-rocm-unet 1531 \
244+
--goldendispatch-rocm-unet 1527 \
245245
--goldendispatch-rocm-clip 1141 \
246246
--goldendispatch-rocm-vae 246 \
247247
--goldensize-rocm-unet-bytes 2270000 \

compiler/src/iree/compiler/DispatchCreation/FuseHorizontalContractions.cpp

Lines changed: 5 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
88
#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
99
#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
10+
#include "iree/compiler/DispatchCreation/FusionUtils.h"
1011
#include "iree/compiler/DispatchCreation/Passes.h"
1112
#include "mlir/Analysis/SliceAnalysis.h"
1213
#include "mlir/Analysis/TopologicalSortUtils.h"
@@ -107,25 +108,6 @@ static bool isEmptyFillContractionDAGRootOp(
107108
return true;
108109
}
109110

110-
/// Check that a given operation is "horizontal" to the group. The operation
111-
/// is horizontal if the `slice` of the operation does not contain any op
112-
/// from the group.
113-
static bool isHorizontalToGroup(Operation *op,
114-
const llvm::SetVector<Operation *> &currGroup,
115-
const DominanceInfo &dominanceInfo,
116-
Operation *seedOp) {
117-
BackwardSliceOptions options;
118-
// Limit the slice to the seed to make sure the slice is small.
119-
options.filter = [&](Operation *op) {
120-
return !dominanceInfo.properlyDominates(op, seedOp);
121-
};
122-
llvm::SetVector<Operation *> slice;
123-
getBackwardSlice(op, &slice, options);
124-
return !llvm::any_of(currGroup, [&](Operation *groupedOp) {
125-
return slice.contains(groupedOp);
126-
});
127-
}
128-
129111
/// Get user of operation that is a truncate operation.
130112
static std::optional<linalg::GenericOp>
131113
getTruncateOp(Operation *op,
@@ -149,8 +131,8 @@ getTruncateOp(Operation *op,
149131
if (!checkOperationEquivalence(genericOp, seedTruncateOp.value())) {
150132
return std::nullopt;
151133
}
152-
if (!isHorizontalToGroup(genericOp, groupedOperations, dominanceInfo,
153-
seedTruncateOp.value())) {
134+
if (!isHorizontalToGroup(genericOp, groupedOperations.getArrayRef(),
135+
dominanceInfo, seedTruncateOp.value())) {
154136
return std::nullopt;
155137
}
156138
}
@@ -226,7 +208,8 @@ static std::optional<HorizontalFusionGroup> getHorizontalFusionGroupMembers(
226208
if (!dominanceInfo.properlyDominates(seedOp, linalgOp)) {
227209
return false;
228210
}
229-
if (!isHorizontalToGroup(linalgOp, allOps, dominanceInfo, seedOp)) {
211+
if (!isHorizontalToGroup(linalgOp, allOps.getArrayRef(), dominanceInfo,
212+
seedOp)) {
230213
return false;
231214
}
232215
return true;
@@ -346,40 +329,6 @@ static AffineMap getConcatenatedIndexingMap(RewriterBase &rewriter,
346329
return newIndexingMap.insertResult(rewriter.getAffineDimExpr(0), 0);
347330
}
348331

349-
/// During horizontal fusion, there might be operands of the fused operations
350-
/// whose definitions are interspersed between the fused operations. For groups
351-
/// chosen to fuse horizontally, such operations can be moved before the
352-
/// seed contraction operation (where the fused operation is generated).
353-
template <typename T>
354-
static LogicalResult
355-
moveOperandDefs(RewriterBase &rewriter, ArrayRef<T> operations,
356-
Operation *insertionPoint, DominanceInfo &dominanceInfo,
357-
ArrayRef<linalg::LinalgOp> ignoreOperations = {}) {
358-
BackwardSliceOptions options;
359-
llvm::DenseSet<Operation *> ignoreOperationsSet;
360-
ignoreOperationsSet.insert(ignoreOperations.begin(), ignoreOperations.end());
361-
options.filter = [&](Operation *op) {
362-
return !dominanceInfo.properlyDominates(op, insertionPoint) &&
363-
!ignoreOperationsSet.contains(op);
364-
};
365-
// Set inclusive to true cause the slice is computed from the operand, and
366-
// we want to include the defining op (which is the point here)
367-
options.inclusive = true;
368-
369-
llvm::SetVector<Operation *> slice;
370-
for (auto op : operations) {
371-
for (auto operand : op->getOperands()) {
372-
getBackwardSlice(operand, &slice, options);
373-
}
374-
}
375-
376-
mlir::topologicalSort(slice);
377-
for (auto op : slice) {
378-
rewriter.moveOpBefore(op, insertionPoint);
379-
}
380-
return success();
381-
}
382-
383332
/// On finding this pattern
384333
/// ```
385334
/// %0 = linalg.matmul ins(%arg0, %arg1)

compiler/src/iree/compiler/DispatchCreation/FuseMultiUseElementwiseProducer.cpp

Lines changed: 60 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,13 @@
1616
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
1717
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
1818
#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
19+
#include "iree/compiler/DispatchCreation/FusionUtils.h"
1920
#include "iree/compiler/DispatchCreation/Passes.h"
21+
#include "llvm/ADT/ArrayRef.h"
22+
#include "llvm/ADT/STLExtras.h"
2023
#include "llvm/Support/CommandLine.h"
2124
#include "llvm/Support/Debug.h"
25+
#include "mlir/Analysis/SliceAnalysis.h"
2226
#include "mlir/Analysis/TopologicalSortUtils.h"
2327
#include "mlir/Dialect/Affine/IR/AffineOps.h"
2428
#include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -45,25 +49,55 @@ static llvm::cl::opt<int64_t> clLinalgMaxConstantFoldElements(
4549
llvm::cl::desc("Maximum number of elements to try to constant fold."),
4650
llvm::cl::init(0));
4751

52+
static Operation *getMostDominantUse(Operation *op,
53+
const DominanceInfo &dominanceInfo) {
54+
auto uses = op->getUses();
55+
auto it = llvm::find_if(uses, [&](OpOperand &source) {
56+
Operation *sourceOp = source.getOwner();
57+
58+
return llvm::all_of(uses, [&](OpOperand &target) {
59+
Operation *targetOp = target.getOwner();
60+
return dominanceInfo.dominates(sourceOp, targetOp);
61+
});
62+
});
63+
if (it != uses.end()) {
64+
return it->getOwner();
65+
}
66+
return nullptr;
67+
}
68+
4869
/// Check if any of the use dominates all other uses of the operation.
49-
static std::optional<OpOperand *> getFusableUse(Operation *op,
50-
DominanceInfo &dominanceInfo) {
70+
static Operation *getFusableUse(Operation *op,
71+
const DominanceInfo &dominanceInfo) {
5172
auto uses = op->getUses();
73+
Operation *fusableUse = nullptr;
5274
for (OpOperand &source : uses) {
5375
Operation *sourceOp = source.getOwner();
54-
bool dominatesAllUsers = true;
55-
for (OpOperand &target : uses) {
76+
77+
bool dominatesAllFusableOps = llvm::all_of(uses, [&](OpOperand &target) {
5678
Operation *targetOp = target.getOwner();
57-
if (!dominanceInfo.dominates(sourceOp, targetOp)) {
58-
dominatesAllUsers = false;
59-
break;
60-
}
61-
}
62-
if (dominatesAllUsers) {
63-
return &source;
79+
return !isa<linalg::GenericOp>(targetOp) ||
80+
dominanceInfo.dominates(sourceOp, targetOp);
81+
});
82+
if (dominatesAllFusableOps) {
83+
fusableUse = sourceOp;
84+
break;
6485
}
6586
}
66-
return std::nullopt;
87+
Operation *mostDominantOp = getMostDominantUse(op, dominanceInfo);
88+
if (!fusableUse || !mostDominantOp) {
89+
return nullptr;
90+
}
91+
92+
// If `fusableUse` dominates all other users, there's nothing else to do.
93+
if (fusableUse == mostDominantOp) {
94+
return fusableUse;
95+
}
96+
97+
SmallVector<Operation *> users(op->getUsers().begin(), op->getUsers().end());
98+
return isHorizontalToGroup(fusableUse, users, dominanceInfo, mostDominantOp)
99+
? fusableUse
100+
: nullptr;
67101
}
68102

69103
static OpOperand *getFirstUseInConsumer(Operation *producer,
@@ -91,6 +125,7 @@ static SmallVector<OpOperand *> getAllUsesInConsumer(Operation *producer,
91125
/// using elementwise fusion.
92126
static LogicalResult doMultiUseFusion(Operation *rootOp,
93127
llvm::SetVector<Operation *> &fusableOps,
128+
const DominanceInfo &dominanceInfo,
94129
RewriterBase &rewriter) {
95130
assert(rootOp && "root op cant be null");
96131

@@ -112,11 +147,20 @@ static LogicalResult doMultiUseFusion(Operation *rootOp,
112147
Operation *consumerOp = rootOp;
113148
OpBuilder::InsertionGuard g(rewriter);
114149
for (Operation *producerOp : llvm::reverse(fusedOpsVec)) {
150+
Operation *mostDominantUser = getMostDominantUse(producerOp, dominanceInfo);
115151
// Fuse all uses from producer -> consumer. It has been checked
116152
// before that all uses are fusable.
117153
while (OpOperand *fusedOperand =
118154
getFirstUseInConsumer(producerOp, consumerOp)) {
119155
rewriter.setInsertionPoint(consumerOp);
156+
157+
if (consumerOp != mostDominantUser &&
158+
failed(moveOperandDefs(rewriter, ArrayRef<Operation *>{consumerOp},
159+
mostDominantUser, dominanceInfo))) {
160+
return rewriter.notifyMatchFailure(consumerOp,
161+
"failed to move operand defs");
162+
}
163+
rewriter.moveOpBefore(consumerOp, mostDominantUser);
120164
FailureOr<linalg::ElementwiseOpFusionResult> fusionResult =
121165
linalg::fuseElementwiseOps(rewriter, fusedOperand);
122166
if (failed(fusionResult)) {
@@ -190,9 +234,8 @@ static FailureOr<unsigned> fuseMultiUseProducers(Operation *funcOp,
190234
}
191235

192236
// 6. Check that the `genericOp` dominates all uses of `producer`.
193-
std::optional<OpOperand *> fusableUse =
194-
getFusableUse(producer, dominanceInfo);
195-
if (!fusableUse || fusableUse.value()->getOwner() != genericOp) {
237+
Operation *fusableUse = getFusableUse(producer, dominanceInfo);
238+
if (!fusableUse || fusableUse != genericOp) {
196239
continue;
197240
}
198241

@@ -232,7 +275,8 @@ static FailureOr<unsigned> fuseMultiUseProducers(Operation *funcOp,
232275

233276
IRRewriter rewriter(context);
234277
for (auto it = fusedOps.rbegin(), ie = fusedOps.rend(); it != ie; ++it) {
235-
if (failed(doMultiUseFusion(it->first, it->second, rewriter))) {
278+
if (failed(
279+
doMultiUseFusion(it->first, it->second, dominanceInfo, rewriter))) {
236280
return funcOp->emitOpError("failed multi use fusion");
237281
}
238282
}

compiler/src/iree/compiler/DispatchCreation/FusionUtils.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,11 @@
1010
#include "compiler/src/iree/compiler/DispatchCreation/FusionUtils.h"
1111
#include "compiler/src/iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
1212
#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
13+
#include "mlir/Analysis/SliceAnalysis.h"
1314
#include "mlir/Dialect/Linalg/IR/Linalg.h"
15+
#include "mlir/IR/Dominance.h"
16+
#include "mlir/IR/OpDefinition.h"
17+
#include "mlir/Transforms/RegionUtils.h"
1418

1519
namespace mlir::iree_compiler::DispatchCreation {
1620

@@ -97,4 +101,22 @@ bool areFusableAsElementwiseOps(MLIRContext *context, OpOperand *fusedOperand,
97101
return true;
98102
}
99103

104+
bool isHorizontalToGroup(Operation *op, ArrayRef<Operation *> currGroup,
105+
const DominanceInfo &dominanceInfo,
106+
Operation *seedOp) {
107+
assert(dominanceInfo.properlyDominates(seedOp, op) &&
108+
op->getParentRegion() == seedOp->getParentRegion());
109+
BackwardSliceOptions options;
110+
options.omitUsesFromAbove = false;
111+
// Limit the slice to the seed to make sure the slice is small.
112+
options.filter = [&](Operation *op) {
113+
return !dominanceInfo.properlyDominates(op, seedOp);
114+
};
115+
llvm::SetVector<Operation *> slice;
116+
getBackwardSlice(op, &slice, options);
117+
return !llvm::any_of(currGroup, [&](Operation *groupedOp) {
118+
return slice.contains(groupedOp);
119+
});
120+
}
121+
100122
} // namespace mlir::iree_compiler::DispatchCreation

compiler/src/iree/compiler/DispatchCreation/FusionUtils.h

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@
1010
//
1111
//===----------------------------------------------------------------------===//
1212

13+
#include "mlir/Analysis/SliceAnalysis.h"
14+
#include "mlir/Analysis/TopologicalSortUtils.h"
15+
#include "mlir/Dialect/Linalg/IR/Linalg.h"
16+
#include "mlir/IR/Dominance.h"
1317
#include "mlir/IR/Operation.h"
1418

1519
namespace mlir::iree_compiler::DispatchCreation {
@@ -19,4 +23,45 @@ namespace mlir::iree_compiler::DispatchCreation {
1923
bool areFusableAsElementwiseOps(MLIRContext *context, OpOperand *operand,
2024
bool fuseMultiReduction);
2125

26+
/// Check that a given operation is "horizontal" to the group. The operation
27+
/// is horizontal if the program slice of the operation (from op back to seedOp)
28+
/// does not contain any op from the group.
29+
bool isHorizontalToGroup(Operation *op, ArrayRef<Operation *> currGroup,
30+
const DominanceInfo &dominanceInfo, Operation *seedOp);
31+
32+
/// Moves the operands and transitive defs for each op in `operations` directly
33+
/// after `insertionPoint`. Note: this does not check if it is legal to move the
34+
/// operands.
35+
template <typename T>
36+
static LogicalResult
37+
moveOperandDefs(RewriterBase &rewriter, ArrayRef<T> operations,
38+
Operation *insertionPoint, const DominanceInfo &dominanceInfo,
39+
ArrayRef<linalg::LinalgOp> ignoreOperations = {}) {
40+
BackwardSliceOptions options;
41+
options.omitUsesFromAbove = false;
42+
llvm::DenseSet<Operation *> ignoreOperationsSet;
43+
ignoreOperationsSet.insert(ignoreOperations.begin(), ignoreOperations.end());
44+
options.filter = [&](Operation *op) {
45+
return !dominanceInfo.properlyDominates(op, insertionPoint) &&
46+
!ignoreOperationsSet.contains(op);
47+
};
48+
// Set inclusive to true cause the slice is computed from the operand, and
49+
// we want to include the defining op (which is the point here)
50+
options.inclusive = true;
51+
52+
llvm::SetVector<Operation *> slice;
53+
for (auto op : operations) {
54+
assert(insertionPoint->getBlock() == op->getBlock());
55+
for (auto operand : op->getOperands()) {
56+
getBackwardSlice(operand, &slice, options);
57+
}
58+
}
59+
60+
mlir::topologicalSort(slice);
61+
for (auto op : slice) {
62+
rewriter.moveOpBefore(op, insertionPoint);
63+
}
64+
return success();
65+
}
66+
2267
} // namespace mlir::iree_compiler::DispatchCreation

0 commit comments

Comments
 (0)