Skip to content

Commit 3edd217

Browse files
jtuylsftynse
andauthored
[codegen] more consumer fusion (iree-org#21848)
Resolves: iree-org#21814 Resolves: iree-org#21847 Resubmits iree-org#21521 after a revert (iree-org#21819) with an additional fix for the llama 405b compilation issue. Additionally resolves a compilation issue encountered for data-tiled llama3 8b after the revert (iree-org#21847). --------- Signed-off-by: Alex Zinenko <[email protected]> Signed-off-by: Jorn Tuyls <[email protected]> Co-authored-by: Oleksandr "Alex" Zinenko <[email protected]>
1 parent b2edd95 commit 3edd217

File tree

4 files changed

+281
-46
lines changed

4 files changed

+281
-46
lines changed

compiler/src/iree/compiler/Codegen/Common/TileAndFuseUtils.cpp

Lines changed: 86 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
1010
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.h"
1111
#include "llvm/ADT/STLExtras.h"
12+
#include "llvm/ADT/SmallPtrSet.h"
1213
#include "llvm/Support/Casting.h"
1314
#include "llvm/Support/Debug.h"
1415
#include "mlir/Analysis/TopologicalSortUtils.h"
@@ -87,16 +88,48 @@ void collectTiledAndFusedOps(Operation *rootOp,
8788
}
8889
}
8990

91+
namespace {
92+
// Entry for the pseudo-priority queue of consumer fusion candidates. Contains
93+
// the consumer (fusableUser) that can be fused and the set of slice operations
94+
// in the loop to fuse into that feed the consumer.
95+
struct ConsumerFusionQueueEntry {
96+
ConsumerFusionQueueEntry(SmallVector<Operation *> &&slices,
97+
Operation *fusableUser)
98+
: slices(std::move(slices)), fusableUser(fusableUser) {}
99+
100+
SmallVector<Operation *> slices;
101+
Operation *fusableUser;
102+
};
103+
} // namespace
104+
90105
FailureOr<std::queue<Operation *>>
91-
fuseConsumersIntoForall(RewriterBase &rewriter, Operation *tiledOp,
106+
fuseConsumersIntoForall(RewriterBase &rewriter, ArrayRef<Operation *> tiledOps,
92107
MutableArrayRef<LoopLikeOpInterface> loops,
93108
std::function<bool(Operation *)> filterFn) {
94109
// Collect the candidate slices which can be potential consumers that can be
95-
// fused.
96-
std::queue<SmallVector<Operation *>> candidates;
110+
// fused. Keep them in a vector reverse-sorted by dominance: the candidate
111+
// dominating others comes last (so it can be cheaply popped from the vector).
112+
// The most-dominating candidate is to be fused first since not fusing it may
113+
// prevent dominated candidates to be fused:
114+
//
115+
// A
116+
// |
117+
// B
118+
// / |
119+
// | D
120+
// | /
121+
// C
122+
//
123+
// here, B must be fused before both C and D, and D must be fused before C.
124+
// Candidates are kept in a vector rather than a priority queue since we may
125+
// update them as fusion happens, in particular, more slices may need to be
126+
// handled. For example, fusing B with A will create a slice of B that will
127+
// need to be handled correctly.
128+
SmallVector<ConsumerFusionQueueEntry> candidates;
97129
llvm::SmallDenseSet<tensor::ParallelInsertSliceOp> allCandidates;
98130
auto addCandidateSlices = [&candidates, &allCandidates,
99-
&filterFn](Operation *fusedOp) {
131+
&filterFn](Operation *fusedOp,
132+
DominanceInfo &dominanceInfo) {
100133
for (auto *userOp : fusedOp->getResults().getUsers()) {
101134
auto sliceOp = dyn_cast<tensor::ParallelInsertSliceOp>(userOp);
102135
if (!sliceOp || allCandidates.contains(sliceOp)) {
@@ -113,44 +146,63 @@ fuseConsumersIntoForall(RewriterBase &rewriter, Operation *tiledOp,
113146
continue;
114147
}
115148
mlir::computeTopologicalSorting(users);
116-
117-
Operation *fusableUser = users.front();
118-
// Check all operands from the `scf.forall`
119-
SmallVector<OpResult> loopResults;
120-
for (OpOperand &opOperand : fusableUser->getOpOperands()) {
121-
if (opOperand.get().getDefiningOp() == currLoop.getOperation()) {
122-
loopResults.push_back(cast<OpResult>(opOperand.get()));
149+
for (Operation *fusableUser : users) {
150+
// Check all operands from the `scf.forall`
151+
SmallVector<OpResult> loopResults;
152+
for (OpOperand &opOperand : fusableUser->getOpOperands()) {
153+
if (opOperand.get().getDefiningOp() == currLoop.getOperation()) {
154+
loopResults.push_back(cast<OpResult>(opOperand.get()));
155+
}
123156
}
124-
}
125157

126-
SmallVector<Operation *> fusedSlices;
127-
for (OpResult result : loopResults) {
128-
BlockArgument tiedBlockArg =
129-
currLoop.getTiedBlockArgument(currLoop.getTiedOpOperand(result));
130-
SmallVector<tensor::ParallelInsertSliceOp> slices = llvm::map_to_vector(
131-
currLoop.getCombiningOps(tiedBlockArg), [](Operation *op) {
132-
return cast<tensor::ParallelInsertSliceOp>(op);
133-
});
134-
llvm::append_range(fusedSlices, slices);
135-
allCandidates.insert_range(slices);
136-
}
137-
if (!fusedSlices.empty()) {
138-
candidates.emplace(std::move(fusedSlices));
158+
SmallVector<Operation *> fusedSlices;
159+
for (OpResult result : loopResults) {
160+
BlockArgument tiedBlockArg =
161+
currLoop.getTiedBlockArgument(currLoop.getTiedOpOperand(result));
162+
SmallVector<tensor::ParallelInsertSliceOp> slices =
163+
llvm::map_to_vector(
164+
currLoop.getCombiningOps(tiedBlockArg), [](Operation *op) {
165+
return cast<tensor::ParallelInsertSliceOp>(op);
166+
});
167+
llvm::append_range(fusedSlices, slices);
168+
allCandidates.insert_range(slices);
169+
}
170+
if (!fusedSlices.empty()) {
171+
ConsumerFusionQueueEntry entry(std::move(fusedSlices), fusableUser);
172+
173+
// Comparator that puts the dominating user last.
174+
auto comp = [&](const ConsumerFusionQueueEntry &lhs,
175+
const ConsumerFusionQueueEntry &rhs) {
176+
return dominanceInfo.properlyDominates(rhs.fusableUser,
177+
lhs.fusableUser);
178+
};
179+
180+
// If the fusable user is already a candidate, update it with the new
181+
// list of slices to handle. Otherwise, insert it into the right
182+
// position based on dominance.
183+
auto *it = llvm::lower_bound(candidates, entry, comp);
184+
if (it != candidates.end() && it->fusableUser == fusableUser)
185+
*it = std::move(entry);
186+
else
187+
candidates.insert(it, std::move(entry));
188+
}
139189
}
140190
}
141191
};
142192

143-
addCandidateSlices(tiledOp);
193+
// Add slices from all tiled ops, not only the "main" one.
194+
DominanceInfo dominanceInfo;
195+
for (Operation *tiledOp : tiledOps) {
196+
addCandidateSlices(tiledOp, dominanceInfo);
197+
}
144198

145199
std::queue<Operation *> newFusionOpportunities;
146200
while (!candidates.empty()) {
147-
// Traverse the slices in BFS fashion.
148-
SmallVector<Operation *> candidateSlices = candidates.front();
149-
candidates.pop();
201+
// Get the next candidate.
202+
ConsumerFusionQueueEntry entry = candidates.pop_back_val();
150203

151204
FailureOr<scf::SCFFuseConsumerOfSliceResult> fusedResult =
152-
mlir::scf::tileAndFuseConsumerOfSlices(rewriter, candidateSlices,
153-
loops);
205+
mlir::scf::tileAndFuseConsumerOfSlices(rewriter, entry.slices, loops);
154206
if (failed(fusedResult)) {
155207
return failure();
156208
}
@@ -162,8 +214,10 @@ fuseConsumersIntoForall(RewriterBase &rewriter, Operation *tiledOp,
162214
// The result of the fused consumers might themselves be slices of
163215
// values produced by operations that implement the `TilingInterface`.
164216
// Add these operations to the worklist.
217+
DominanceInfo dominanceInfo;
165218
addCandidateSlices(
166-
fusedResult->tiledAndFusedConsumerOperands.front()->getOwner());
219+
fusedResult->tiledAndFusedConsumerOperands.front()->getOwner(),
220+
dominanceInfo);
167221

168222
// Add the list of new producer fusion opportunities.
169223
for (auto tiledOp : fusedResult.value().tiledOps) {

compiler/src/iree/compiler/Codegen/Common/TileAndFuseUtils.h

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,11 @@ void fuseProducersOfSlices(RewriterBase &rewriter,
3131
void collectTiledAndFusedOps(Operation *rootOp,
3232
llvm::SmallDenseSet<Operation *> &result);
3333

34-
/// Fuse all consumers of the given `tiledOp` into the surrounding `scf.forall`.
35-
/// Returns a list of new `tensor.extract_slice` ops with new fusion
36-
/// opportunities, as well as the new surrounding `scf.forall` (because consumer
37-
/// fusion replaces the loop).
34+
/// Fuse all consumers of the given `tiledOps` into the surrounding `scf.forall`
35+
/// unless specified otherwise by `filterFn`. Returns a list of new
36+
/// `tensor.extract_slice` ops with new fusion opportunities.
3837
FailureOr<std::queue<Operation *>> fuseConsumersIntoForall(
39-
RewriterBase &rewriter, Operation *tiledOp,
38+
RewriterBase &rewriter, ArrayRef<Operation *> tiledOps,
4039
MutableArrayRef<LoopLikeOpInterface> loops,
4140
std::function<bool(Operation *)> filterFn = [](Operation *) {
4241
return true;

compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingForall.cpp

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,18 @@ static bool verifyComputeOpsAfterDistribution(FunctionOpInterface funcOp) {
210210
// Pass implementation.
211211
//===---------------------------------------------------------------------===//
212212

213+
/// Returns true if any value produced by `producer` is used as an init value
214+
/// for the DPS `user`. Returns false if the user is not in DPS.
215+
static bool isUsedAsInit(Operation *producer, Operation *user) {
216+
auto dpsIface = dyn_cast<DestinationStyleOpInterface>(user);
217+
if (!dpsIface)
218+
return false;
219+
ValueRange results = producer->getResults();
220+
return llvm::any_of(dpsIface.getDpsInits(), [&](Value operand) {
221+
return llvm::is_contained(results, operand);
222+
});
223+
}
224+
213225
void TileAndDistributeToWorkgroupsUsingForallOpPass::runOnOperation() {
214226
auto funcOp = getOperation();
215227
auto *context = &getContext();
@@ -232,10 +244,14 @@ void TileAndDistributeToWorkgroupsUsingForallOpPass::runOnOperation() {
232244

233245
llvm::DenseSet<Operation *> yieldReplacementsFor;
234246
for (auto op : tiledAndFusedOps) {
235-
// If tiledAndFused ops doesn't contain the user; add an replacement
236-
// for that.
247+
// Require replacement for values that are used after the main tilable op or
248+
// by ops that will definitely not be fused. Note that if a value is used as
249+
// an init of a DPS op, the user currently cannot be fused. Having a
250+
// replacement for it would attempt fusion and fail, so avoid such cases.
237251
if (llvm::any_of(op->getUsers(), [&](Operation *user) {
238-
return dominanceInfo.properlyDominates(tilableOp, user) &&
252+
if (isUsedAsInit(op, user))
253+
return false;
254+
return dominanceInfo.properlyDominates(tilableOp, user) ||
239255
!tiledAndFusedOps.contains(user);
240256
})) {
241257
yieldReplacementsFor.insert(op);
@@ -317,16 +333,18 @@ void TileAndDistributeToWorkgroupsUsingForallOpPass::runOnOperation() {
317333
});
318334
}
319335
std::swap(tileAndFuseResult->loops, tilingLoops);
320-
Operation *rootTiledOp = tileAndFuseResult->tiledAndFusedOps.front();
336+
321337
FailureOr<std::queue<Operation *>> newFusionOpportunities =
322-
fuseConsumersIntoForall(rewriter, rootTiledOp, tilingLoops,
323-
[&tiledAndFusedOps](Operation *op) {
324-
return tiledAndFusedOps.contains(op);
325-
});
338+
fuseConsumersIntoForall(
339+
rewriter, tileAndFuseResult->tiledAndFusedOps.getArrayRef(),
340+
tilingLoops, [&tiledAndFusedOps](Operation *op) {
341+
return tiledAndFusedOps.contains(op);
342+
});
326343
if (failed(newFusionOpportunities)) {
327344
// Continue the work if the failure is allowed.
328345
if (!verifyComputeOpsAfterDistribution(funcOp)) {
329-
rootTiledOp->emitOpError("failed to fuse consumers");
346+
tileAndFuseResult->tiledAndFusedOps.front()->emitOpError(
347+
"failed to fuse consumers");
330348
return signalPassFailure();
331349
}
332350
} else {

0 commit comments

Comments
 (0)