99#include " iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
1010#include " iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.h"
1111#include " llvm/ADT/STLExtras.h"
12+ #include " llvm/ADT/SmallPtrSet.h"
1213#include " llvm/Support/Casting.h"
1314#include " llvm/Support/Debug.h"
1415#include " mlir/Analysis/TopologicalSortUtils.h"
@@ -87,16 +88,48 @@ void collectTiledAndFusedOps(Operation *rootOp,
8788 }
8889}
8990
91+ namespace {
92+ // Entry in the pseudo-priority queue of consumer fusion candidates. Holds the
93+ // consumer that can be fused (fusableUser) together with the slice operations,
94+ // located in the loop being fused into, that feed that consumer.
95+ struct ConsumerFusionQueueEntry {
96+ ConsumerFusionQueueEntry (SmallVector<Operation *> &&slices,
97+ Operation *fusableUser)
98+ : slices(std::move(slices)), fusableUser(fusableUser) {} // Moves `slices` into the entry.
99+
100+ SmallVector<Operation *> slices; // Slice operations feeding `fusableUser`.
101+ Operation *fusableUser; // Consumer operation that is the fusion candidate.
102+ };
103+ } // namespace
104+
90105FailureOr<std::queue<Operation *>>
91- fuseConsumersIntoForall (RewriterBase &rewriter, Operation *tiledOp ,
106+ fuseConsumersIntoForall (RewriterBase &rewriter, ArrayRef< Operation *> tiledOps ,
92107 MutableArrayRef<LoopLikeOpInterface> loops,
93108 std::function<bool (Operation *)> filterFn) {
94109 // Collect the candidate slices which can be potential consumers that can be
95- // fused.
96- std::queue<SmallVector<Operation *>> candidates;
110+ // fused. Keep them in a vector reverse-sorted by dominance: the candidate
111+ // dominating others comes last (so it can be cheaply popped from the vector).
112+ // The most-dominating candidate is fused first, since leaving it unfused may
113+ // prevent the candidates it dominates from being fused:
114+ //
115+ // A
116+ // |
117+ // B
118+ // / |
119+ // | D
120+ // | /
121+ // C
122+ //
123+ // here, B must be fused before both C and D, and D must be fused before C.
124+ // Candidates are kept in a vector rather than a priority queue since we may
125+ // update them as fusion happens, in particular, more slices may need to be
126+ // handled. For example, fusing B with A will create a slice of B that will
127+ // need to be handled correctly.
128+ SmallVector<ConsumerFusionQueueEntry> candidates;
97129 llvm::SmallDenseSet<tensor::ParallelInsertSliceOp> allCandidates;
98130 auto addCandidateSlices = [&candidates, &allCandidates,
99- &filterFn](Operation *fusedOp) {
131+ &filterFn](Operation *fusedOp,
132+ DominanceInfo &dominanceInfo) {
100133 for (auto *userOp : fusedOp->getResults ().getUsers ()) {
101134 auto sliceOp = dyn_cast<tensor::ParallelInsertSliceOp>(userOp);
102135 if (!sliceOp || allCandidates.contains (sliceOp)) {
@@ -113,44 +146,63 @@ fuseConsumersIntoForall(RewriterBase &rewriter, Operation *tiledOp,
113146 continue ;
114147 }
115148 mlir::computeTopologicalSorting (users);
116-
117- Operation *fusableUser = users. front ();
118- // Check all operands from the `scf.forall`
119- SmallVector<OpResult> loopResults;
120- for (OpOperand & opOperand : fusableUser-> getOpOperands ()) {
121- if ( opOperand.get (). getDefiningOp () == currLoop. getOperation ()) {
122- loopResults. push_back (cast<OpResult>(opOperand. get ()));
149+ for (Operation *fusableUser : users) {
150+ // Check all operands from the `scf.forall`
151+ SmallVector<OpResult> loopResults;
152+ for (OpOperand &opOperand : fusableUser-> getOpOperands ()) {
153+ if ( opOperand. get (). getDefiningOp () == currLoop. getOperation ()) {
154+ loopResults. push_back (cast<OpResult>( opOperand.get ()));
155+ }
123156 }
124- }
125157
126- SmallVector<Operation *> fusedSlices;
127- for (OpResult result : loopResults) {
128- BlockArgument tiedBlockArg =
129- currLoop.getTiedBlockArgument (currLoop.getTiedOpOperand (result));
130- SmallVector<tensor::ParallelInsertSliceOp> slices = llvm::map_to_vector (
131- currLoop.getCombiningOps (tiedBlockArg), [](Operation *op) {
132- return cast<tensor::ParallelInsertSliceOp>(op);
133- });
134- llvm::append_range (fusedSlices, slices);
135- allCandidates.insert_range (slices);
136- }
137- if (!fusedSlices.empty ()) {
138- candidates.emplace (std::move (fusedSlices));
158+ SmallVector<Operation *> fusedSlices;
159+ for (OpResult result : loopResults) {
160+ BlockArgument tiedBlockArg =
161+ currLoop.getTiedBlockArgument (currLoop.getTiedOpOperand (result));
162+ SmallVector<tensor::ParallelInsertSliceOp> slices =
163+ llvm::map_to_vector (
164+ currLoop.getCombiningOps (tiedBlockArg), [](Operation *op) {
165+ return cast<tensor::ParallelInsertSliceOp>(op);
166+ });
167+ llvm::append_range (fusedSlices, slices);
168+ allCandidates.insert_range (slices);
169+ }
170+ if (!fusedSlices.empty ()) {
171+ ConsumerFusionQueueEntry entry (std::move (fusedSlices), fusableUser);
172+
173+ // Comparator that puts the dominating user last.
174+ auto comp = [&](const ConsumerFusionQueueEntry &lhs,
175+ const ConsumerFusionQueueEntry &rhs) {
176+ return dominanceInfo.properlyDominates (rhs.fusableUser ,
177+ lhs.fusableUser );
178+ };
179+
180+ // If the fusable user is already a candidate, update it with the new
181+ // list of slices to handle. Otherwise, insert it into the right
182+ // position based on dominance.
183+ auto *it = llvm::lower_bound (candidates, entry, comp);
184+ if (it != candidates.end () && it->fusableUser == fusableUser)
185+ *it = std::move (entry);
186+ else
187+ candidates.insert (it, std::move (entry));
188+ }
139189 }
140190 }
141191 };
142192
143- addCandidateSlices (tiledOp);
193+ // Add slices from all tiled ops, not only the "main" one.
194+ DominanceInfo dominanceInfo;
195+ for (Operation *tiledOp : tiledOps) {
196+ addCandidateSlices (tiledOp, dominanceInfo);
197+ }
144198
145199 std::queue<Operation *> newFusionOpportunities;
146200 while (!candidates.empty ()) {
147- // Traverse the slices in BFS fashion.
148- SmallVector<Operation *> candidateSlices = candidates.front ();
149- candidates.pop ();
201+ // Get the next candidate.
202+ ConsumerFusionQueueEntry entry = candidates.pop_back_val ();
150203
151204 FailureOr<scf::SCFFuseConsumerOfSliceResult> fusedResult =
152- mlir::scf::tileAndFuseConsumerOfSlices (rewriter, candidateSlices,
153- loops);
205+ mlir::scf::tileAndFuseConsumerOfSlices (rewriter, entry.slices , loops);
154206 if (failed (fusedResult)) {
155207 return failure ();
156208 }
@@ -162,8 +214,10 @@ fuseConsumersIntoForall(RewriterBase &rewriter, Operation *tiledOp,
162214 // The result of the fused consumers might themselves be slices of
163215 // values produced by operations that implement the `TilingInterface`.
164216 // Add these operations to the worklist.
217+ DominanceInfo dominanceInfo;
165218 addCandidateSlices (
166- fusedResult->tiledAndFusedConsumerOperands .front ()->getOwner ());
219+ fusedResult->tiledAndFusedConsumerOperands .front ()->getOwner (),
220+ dominanceInfo);
167221
168222 // Add the list of new producer fusion opportunities.
169223 for (auto tiledOp : fusedResult.value ().tiledOps ) {
0 commit comments