Skip to content

Commit b0ba380

Browse files
[mlir][Codegen] Remove workaround for handling consumer fusion along multiple operands. (#21171)
With llvm/llvm-project#145193 it is possible to tile and fuse consumers when the consumer uses multiple results of the tiled loop (as long the as the slices of the uses/operands are consistent w.r.t to their use in the consumer). This removes the need for the workaround that was added to handle such cases and generalizes the cases of consumer fusion that can be handled. Fixes #21087 Fixes #21091 Depends on llvm/llvm-project#145193 Signed-off-by: MaheshRavishankar <[email protected]> Signed-off-by: MaheshRavishankar <[email protected]>
1 parent 407d8f2 commit b0ba380

File tree

5 files changed

+166
-177
lines changed

5 files changed

+166
-177
lines changed

compiler/src/iree/compiler/Codegen/Common/TileAndFuseUtils.cpp

Lines changed: 54 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include "llvm/ADT/STLExtras.h"
99
#include "llvm/Support/Casting.h"
1010
#include "llvm/Support/Debug.h"
11+
#include "mlir/Analysis/TopologicalSortUtils.h"
1112
#include "mlir/Dialect/Linalg/IR/Linalg.h"
1213
#include "mlir/Dialect/Tensor/IR/Tensor.h"
1314

@@ -51,36 +52,6 @@ void fuseProducersOfSlices(RewriterBase &rewriter,
5152
}
5253
}
5354

54-
bool warForConsumerFusionSSAViolation(
55-
Operation *rootOp,
56-
const llvm::SmallDenseSet<Operation *> &tiledAndFusedOps) {
57-
auto linalgRootOp = dyn_cast<linalg::LinalgOp>(rootOp);
58-
if (!linalgRootOp) {
59-
return false;
60-
}
61-
SmallVector<utils::IteratorType> iteratorTypes =
62-
linalgRootOp.getIteratorTypesArray();
63-
for (AffineMap map :
64-
llvm::map_range(linalgRootOp.getIndexingMaps(), [](Attribute attr) {
65-
return cast<AffineMapAttr>(attr).getValue();
66-
})) {
67-
if (!compressUnusedDims(map).isIdentity()) {
68-
return false;
69-
}
70-
}
71-
72-
for (OpOperand &use : linalgRootOp->getUses()) {
73-
auto linalgUser = dyn_cast<linalg::LinalgOp>(use.getOwner());
74-
if (!linalgUser) {
75-
return false;
76-
}
77-
if (!linalgUser.getMatchingIndexingMap(&use).isIdentity()) {
78-
return false;
79-
}
80-
}
81-
return true;
82-
}
83-
8455
void collectTiledAndFusedOps(Operation *rootOp,
8556
llvm::SmallDenseSet<Operation *> &result) {
8657
SmallVector<Operation *> worklist;
@@ -111,82 +82,72 @@ void collectTiledAndFusedOps(Operation *rootOp,
11182
}
11283

11384
FailureOr<std::queue<Operation *>>
114-
fuseConsumersIntoLoops(RewriterBase &rewriter, Operation *tiledOp,
115-
MutableArrayRef<LoopLikeOpInterface> loops,
116-
bool useWARForConsumerFusionSSAViolation) {
117-
auto addCandidateSlices = [](Operation *fusedOp,
118-
std::queue<Operation *> &candidates) {
85+
fuseConsumersIntoForall(RewriterBase &rewriter, Operation *tiledOp,
86+
MutableArrayRef<LoopLikeOpInterface> loops,
87+
std::function<bool(Operation *)> filterFn) {
88+
// Collect the candidate slices which can be potential consumers that can be
89+
// fused.
90+
std::queue<SmallVector<Operation *>> candidates;
91+
llvm::SmallDenseSet<tensor::ParallelInsertSliceOp> allCandidates;
92+
auto addCandidateSlices = [&candidates, &allCandidates,
93+
&filterFn](Operation *fusedOp) {
11994
for (auto *userOp : fusedOp->getResults().getUsers()) {
120-
if (llvm::isa<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
121-
userOp)) {
122-
// Users of tiledOp should either be all of type `tensor.insert_slice`
123-
// or all of`tensor.parallel_insert_slice`.
124-
//
125-
// Pattern 1 - tileing with scf.for:
126-
// %out = scf.for ... {
127-
// %0 = scf.for ... {
128-
// %t0 = op
129-
// %t1 = op %t0 // <- `tiledOp`
130-
// %1 = tensor.insert_slice %t1
131-
// yield %1
132-
// }
133-
// yield %0
134-
// }
135-
//
136-
// Pattern 2 - tiling with scf.forall:
137-
// % out = scf.forall ... {
138-
// %t0 = op
139-
// %t1 = op %t0 // <- `tiledOp`
140-
// scf.forall.in_parallel {
141-
// tensor.parallel_insert_slice %tile
142-
// }
143-
// }
144-
assert((candidates.empty() ||
145-
candidates.front()->getName() == userOp->getName()) &&
146-
"expected all slice users to be of type tensor.insert_slice "
147-
"or of tensor.parallel_insert_slice.");
148-
candidates.push(userOp);
95+
auto sliceOp = dyn_cast<tensor::ParallelInsertSliceOp>(userOp);
96+
if (!sliceOp || allCandidates.contains(sliceOp)) {
97+
continue;
98+
}
99+
100+
auto currLoop =
101+
cast<scf::ForallOp>(sliceOp->getParentOp()->getParentOp());
102+
OpResult loopResult = currLoop.getTiedOpResult(
103+
currLoop.getTiedOpOperand(cast<BlockArgument>(sliceOp.getDest())));
104+
SmallVector<Operation *> users = llvm::to_vector(
105+
llvm::make_filter_range(loopResult.getUsers(), filterFn));
106+
if (users.empty()) {
107+
continue;
108+
}
109+
mlir::computeTopologicalSorting(users);
110+
111+
Operation *fusableUser = users.front();
112+
// Check all operands from the `scf.forall`
113+
SmallVector<OpResult> loopResults;
114+
for (OpOperand &opOperand : fusableUser->getOpOperands()) {
115+
if (opOperand.get().getDefiningOp() == currLoop.getOperation()) {
116+
loopResults.push_back(cast<OpResult>(opOperand.get()));
117+
}
118+
}
119+
120+
SmallVector<Operation *> fusedSlices;
121+
for (OpResult result : loopResults) {
122+
BlockArgument tiedBlockArg =
123+
currLoop.getTiedBlockArgument(currLoop.getTiedOpOperand(result));
124+
SmallVector<tensor::ParallelInsertSliceOp> slices = llvm::map_to_vector(
125+
currLoop.getCombiningOps(tiedBlockArg), [](Operation *op) {
126+
return cast<tensor::ParallelInsertSliceOp>(op);
127+
});
128+
llvm::append_range(fusedSlices, slices);
129+
allCandidates.insert_range(slices);
130+
}
131+
if (!fusedSlices.empty()) {
132+
candidates.emplace(std::move(fusedSlices));
149133
}
150134
}
151135
};
152136

153-
// Collect the candidate slices which can be potential consumers that can be
154-
// fused.
155-
std::queue<Operation *> candidates;
156-
addCandidateSlices(tiledOp, candidates);
137+
addCandidateSlices(tiledOp);
157138

158139
std::queue<Operation *> newFusionOpportunities;
159140
while (!candidates.empty()) {
160141
// Traverse the slices in BFS fashion.
161-
Operation *candidateSliceOp = candidates.front();
142+
SmallVector<Operation *> candidateSlices = candidates.front();
162143
candidates.pop();
163144

164145
FailureOr<scf::SCFFuseConsumerOfSliceResult> fusedResult =
165-
mlir::scf::tileAndFuseConsumerOfSlices(rewriter, candidateSliceOp,
146+
mlir::scf::tileAndFuseConsumerOfSlices(rewriter, candidateSlices,
166147
loops);
167148
if (failed(fusedResult)) {
168-
LLVM_DEBUG(llvm::dbgs() << "failed to fuse consumer of slice: "
169-
<< candidateSliceOp << "\n");
170-
continue;
171-
}
172-
173-
// Implement the WAR for consumer fusion SSA violation (as described in the
174-
// comments for `warForConsumerFusionSSAViolation`)
175-
if (useWARForConsumerFusionSSAViolation) {
176-
for (auto [tiledOpResult, loopResult] :
177-
llvm::zip(tiledOp->getResults(), loops.back()->getResults())) {
178-
for (OpOperand &use : loopResult.getUses()) {
179-
Operation *user = use.getOwner();
180-
if (user->getParentOp() != loops.back()) {
181-
continue;
182-
}
183-
auto slice = dyn_cast<tensor::ExtractSliceOp>(user);
184-
if (!slice) {
185-
return failure();
186-
}
187-
rewriter.replaceAllOpUsesWith(slice, tiledOpResult);
188-
}
189-
}
149+
return candidateSlices.front()->emitOpError(
150+
"failed to fuse consumer of slice");
190151
}
191152

192153
// Replace the original consumer operation with the tiled implementation.
@@ -197,8 +158,7 @@ fuseConsumersIntoLoops(RewriterBase &rewriter, Operation *tiledOp,
197158
// values produced by operations that implement the `TilingInterface`.
198159
// Add these operations to the worklist.
199160
addCandidateSlices(
200-
fusedResult->tiledAndFusedConsumerOperands.front()->getOwner(),
201-
candidates);
161+
fusedResult->tiledAndFusedConsumerOperands.front()->getOwner());
202162

203163
// Add the list of new producer fusion opportunities.
204164
for (auto tiledOp : fusedResult.value().tiledOps) {

compiler/src/iree/compiler/Codegen/Common/TileAndFuseUtils.h

Lines changed: 12 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -23,83 +23,23 @@ void fuseProducersOfSlices(RewriterBase &rewriter,
2323
scf::SCFTileAndFuseOptions &options,
2424
MutableArrayRef<LoopLikeOpInterface> loops);
2525

26-
/// Consider the following case
27-
///
28-
/// ```mlir
29-
/// %0:2 = linalg.generic {
30-
/// indexing_maps = [....,
31-
/// affine_map<(d0, d1, d2) -> (d0, d1),
32-
/// affine_map<(d0, d1, d2) -> (d0, d1)>]}
33-
/// %1 = linalg.generic ins(%0#0, %0#1) {
34-
/// indexing_maps = [affine_map<(d0, d1) -> (d0, d1),
35-
/// affine_map<(d0, d1) -> (d0, d1)]}
36-
/// ```
37-
///
38-
/// After tiling the first op we get
39-
///
40-
/// ```
41-
/// %0:2 = scf.forall ... {
42-
/// %1:2 = linalg.generic {
43-
/// indexing_maps = [....,
44-
/// affine_map<(d0, d1, d2) -> (d0, d1),
45-
/// affine_map<(d0, d1, d2) -> (d0, d1)>]}
46-
/// }
47-
/// }
48-
/// %2 = linalg.generic ins(%0#0, %0#1) {
49-
/// indexing_maps = [affine_map<(d0, d1) -> (d0, d1),
50-
/// affine_map<(d0, d1) -> (d0, d1)]}
51-
/// ```
52-
///
53-
/// Due to a quirk of the fusion of consumers, fusing this consumer into the
54-
/// loop results in
55-
///
56-
/// ```
57-
/// %0:2 = scf.forall ... {
58-
/// %1:2 = linalg.generic {
59-
/// indexing_maps = [....,
60-
/// affine_map<(d0, d1, d2) -> (d0, d1),
61-
/// affine_map<(d0, d1, d2) -> (d0, d1)>]}
62-
/// %2 = tensor.extract_slice %0#1 [...]
63-
/// %3 = linalg.generic ins(%1#0, %2) {
64-
/// indexing_maps = [affine_map<(d0, d1) -> (d0, d1),
65-
/// affine_map<(d0, d1) -> (d0, d1)]}
66-
/// }
67-
/// }
68-
/// ```
69-
///
70-
/// This is an SSA violation because of `%0#1` being used in the loop. This
71-
/// needs to be fixed upstream, but for cases where
72-
/// 1. The root operation produces results using an identity indexing map (when
73-
/// ignoring the iteration space dimensions corresponding to the reduction
74-
/// loops)
75-
/// 2. For all consumers of the results of the root operation, access the data
76-
/// using identity indexing map then for each consumer fusion step it is valid
77-
/// to replace all uses of slices of the outer loop that occur within the loop
78-
/// with the correponding tiled result value.
79-
/// This is a workaround till upstream transformation can fix this issue. The
80-
/// following method is testing if such a case exists to implement the
81-
/// work-around.
82-
bool warForConsumerFusionSSAViolation(
83-
Operation *rootOp,
84-
const llvm::SmallDenseSet<Operation *> &tiledAndFusedOps);
85-
8626
/// Starting from `op` walk all operands backwards to find all
8727
/// potentially fusible operations, i.e. operations that implement
8828
/// the `TilingInterface`.
8929
void collectTiledAndFusedOps(Operation *rootOp,
9030
llvm::SmallDenseSet<Operation *> &result);
91-
/// Fuses consumers of `tiledOp` into the surrounding `loops`.
92-
///
93-
/// For any previous producer consumer fusion it's expected that `tiledOp` was
94-
/// the consumer into which producers were fused, i.e. `loops` shouldn't contain
95-
/// a consumer of `tiledOp` that isn't an insert_slice op.
96-
/// `fuseConsumersIntoLoops` will fuse consumers of `tiledOp` into surrounding
97-
/// `scf.forall` or `scf.for` loops and return a list of slice ops that expose
98-
/// new fusion opportunities.
99-
FailureOr<std::queue<Operation *>>
100-
fuseConsumersIntoLoops(RewriterBase &rewriter, Operation *tiledOp,
101-
MutableArrayRef<LoopLikeOpInterface> loops,
102-
bool useWARForConsumerFusionSSAViolation);
31+
32+
/// Fuse all consumers of the given `tiledOp` into the surrounding `scf.forall`.
33+
/// Returns a list of new `tensor.extract_slice` ops with new fusion
34+
/// opportunities, as well as the new surrounding `scf.forall` (because consumer
35+
/// fusion replaces the loop).
36+
FailureOr<std::queue<Operation *>> fuseConsumersIntoForall(
37+
RewriterBase &rewriter, Operation *tiledOp,
38+
MutableArrayRef<LoopLikeOpInterface> loops,
39+
std::function<bool(Operation *)> filterFn = [](Operation *) {
40+
return true;
41+
});
42+
10343
} // namespace mlir::iree_compiler
10444

10545
#endif // IREE_COMPILER_CODEGEN_COMMON_TILEANDFUSEUTILS_H_

compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingForall.cpp

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -331,8 +331,6 @@ void TileAndDistributeToWorkgroupsUsingForallOpPass::runOnOperation() {
331331
mlir::DominanceInfo dominanceInfo(tilableOp);
332332
llvm::SmallDenseSet<Operation *> tiledAndFusedOps;
333333
collectTiledAndFusedOps(tilableOp, tiledAndFusedOps);
334-
bool useWARForConsumerFusionSSAViolation =
335-
warForConsumerFusionSSAViolation(tilableOp, tiledAndFusedOps);
336334

337335
llvm::DenseSet<Operation *> yieldReplacementsFor;
338336
for (auto op : tiledAndFusedOps) {
@@ -413,13 +411,20 @@ void TileAndDistributeToWorkgroupsUsingForallOpPass::runOnOperation() {
413411
return signalPassFailure();
414412
}
415413
for (auto [origValue, replacement] : tileAndFuseResult->replacements) {
416-
rewriter.replaceAllUsesWith(origValue, replacement);
414+
Value replacementCopy = replacement;
415+
rewriter.replaceUsesWithIf(origValue, replacement, [&](OpOperand &use) {
416+
Operation *user = use.getOwner();
417+
return !isa<tensor::DimOp>(user) &&
418+
dominanceInfo.dominates(replacementCopy, user);
419+
});
417420
}
418421
std::swap(tileAndFuseResult->loops, tilingLoops);
419422
Operation *rootTiledOp = tileAndFuseResult->tiledAndFusedOps.front();
420423
FailureOr<std::queue<Operation *>> newFusionOpportunities =
421-
fuseConsumersIntoLoops(rewriter, rootTiledOp, tilingLoops,
422-
useWARForConsumerFusionSSAViolation);
424+
fuseConsumersIntoForall(rewriter, rootTiledOp, tilingLoops,
425+
[&tiledAndFusedOps](Operation *op) {
426+
return tiledAndFusedOps.contains(op);
427+
});
423428
if (failed(newFusionOpportunities)) {
424429
rootTiledOp->emitOpError("failed to fuse consumers");
425430
return signalPassFailure();

0 commit comments

Comments
 (0)