88#include " llvm/ADT/STLExtras.h"
99#include " llvm/Support/Casting.h"
1010#include " llvm/Support/Debug.h"
11+ #include " mlir/Analysis/TopologicalSortUtils.h"
1112#include " mlir/Dialect/Linalg/IR/Linalg.h"
1213#include " mlir/Dialect/Tensor/IR/Tensor.h"
1314
@@ -51,36 +52,6 @@ void fuseProducersOfSlices(RewriterBase &rewriter,
5152 }
5253}
5354
54- bool warForConsumerFusionSSAViolation (
55- Operation *rootOp,
56- const llvm::SmallDenseSet<Operation *> &tiledAndFusedOps) {
57- auto linalgRootOp = dyn_cast<linalg::LinalgOp>(rootOp);
58- if (!linalgRootOp) {
59- return false ;
60- }
61- SmallVector<utils::IteratorType> iteratorTypes =
62- linalgRootOp.getIteratorTypesArray ();
63- for (AffineMap map :
64- llvm::map_range (linalgRootOp.getIndexingMaps (), [](Attribute attr) {
65- return cast<AffineMapAttr>(attr).getValue ();
66- })) {
67- if (!compressUnusedDims (map).isIdentity ()) {
68- return false ;
69- }
70- }
71-
72- for (OpOperand &use : linalgRootOp->getUses ()) {
73- auto linalgUser = dyn_cast<linalg::LinalgOp>(use.getOwner ());
74- if (!linalgUser) {
75- return false ;
76- }
77- if (!linalgUser.getMatchingIndexingMap (&use).isIdentity ()) {
78- return false ;
79- }
80- }
81- return true ;
82- }
83-
8455void collectTiledAndFusedOps (Operation *rootOp,
8556 llvm::SmallDenseSet<Operation *> &result) {
8657 SmallVector<Operation *> worklist;
@@ -111,82 +82,72 @@ void collectTiledAndFusedOps(Operation *rootOp,
11182}
11283
11384FailureOr<std::queue<Operation *>>
114- fuseConsumersIntoLoops (RewriterBase &rewriter, Operation *tiledOp,
115- MutableArrayRef<LoopLikeOpInterface> loops,
116- bool useWARForConsumerFusionSSAViolation) {
117- auto addCandidateSlices = [](Operation *fusedOp,
118- std::queue<Operation *> &candidates) {
85+ fuseConsumersIntoForall (RewriterBase &rewriter, Operation *tiledOp,
86+ MutableArrayRef<LoopLikeOpInterface> loops,
87+ std::function<bool (Operation *)> filterFn) {
88+ // Collect the candidate slices which can be potential consumers that can be
89+ // fused.
90+ std::queue<SmallVector<Operation *>> candidates;
91+ llvm::SmallDenseSet<tensor::ParallelInsertSliceOp> allCandidates;
92+ auto addCandidateSlices = [&candidates, &allCandidates,
93+ &filterFn](Operation *fusedOp) {
11994 for (auto *userOp : fusedOp->getResults ().getUsers ()) {
120- if (llvm::isa<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
121- userOp)) {
122- // Users of tiledOp should either be all of type `tensor.insert_slice`
123- // or all of`tensor.parallel_insert_slice`.
124- //
125- // Pattern 1 - tileing with scf.for:
126- // %out = scf.for ... {
127- // %0 = scf.for ... {
128- // %t0 = op
129- // %t1 = op %t0 // <- `tiledOp`
130- // %1 = tensor.insert_slice %t1
131- // yield %1
132- // }
133- // yield %0
134- // }
135- //
136- // Pattern 2 - tiling with scf.forall:
137- // % out = scf.forall ... {
138- // %t0 = op
139- // %t1 = op %t0 // <- `tiledOp`
140- // scf.forall.in_parallel {
141- // tensor.parallel_insert_slice %tile
142- // }
143- // }
144- assert ((candidates.empty () ||
145- candidates.front ()->getName () == userOp->getName ()) &&
146- " expected all slice users to be of type tensor.insert_slice "
147- " or of tensor.parallel_insert_slice." );
148- candidates.push (userOp);
95+ auto sliceOp = dyn_cast<tensor::ParallelInsertSliceOp>(userOp);
96+ if (!sliceOp || allCandidates.contains (sliceOp)) {
97+ continue ;
98+ }
99+
100+ auto currLoop =
101+ cast<scf::ForallOp>(sliceOp->getParentOp ()->getParentOp ());
102+ OpResult loopResult = currLoop.getTiedOpResult (
103+ currLoop.getTiedOpOperand (cast<BlockArgument>(sliceOp.getDest ())));
104+ SmallVector<Operation *> users = llvm::to_vector (
105+ llvm::make_filter_range (loopResult.getUsers (), filterFn));
106+ if (users.empty ()) {
107+ continue ;
108+ }
109+ mlir::computeTopologicalSorting (users);
110+
111+ Operation *fusableUser = users.front ();
112+ // Check all operands from the `scf.forall`
113+ SmallVector<OpResult> loopResults;
114+ for (OpOperand &opOperand : fusableUser->getOpOperands ()) {
115+ if (opOperand.get ().getDefiningOp () == currLoop.getOperation ()) {
116+ loopResults.push_back (cast<OpResult>(opOperand.get ()));
117+ }
118+ }
119+
120+ SmallVector<Operation *> fusedSlices;
121+ for (OpResult result : loopResults) {
122+ BlockArgument tiedBlockArg =
123+ currLoop.getTiedBlockArgument (currLoop.getTiedOpOperand (result));
124+ SmallVector<tensor::ParallelInsertSliceOp> slices = llvm::map_to_vector (
125+ currLoop.getCombiningOps (tiedBlockArg), [](Operation *op) {
126+ return cast<tensor::ParallelInsertSliceOp>(op);
127+ });
128+ llvm::append_range (fusedSlices, slices);
129+ allCandidates.insert_range (slices);
130+ }
131+ if (!fusedSlices.empty ()) {
132+ candidates.emplace (std::move (fusedSlices));
149133 }
150134 }
151135 };
152136
153- // Collect the candidate slices which can be potential consumers that can be
154- // fused.
155- std::queue<Operation *> candidates;
156- addCandidateSlices (tiledOp, candidates);
137+ addCandidateSlices (tiledOp);
157138
158139 std::queue<Operation *> newFusionOpportunities;
159140 while (!candidates.empty ()) {
160141 // Traverse the slices in BFS fashion.
161- Operation *candidateSliceOp = candidates.front ();
142+ SmallVector< Operation *> candidateSlices = candidates.front ();
162143 candidates.pop ();
163144
164145 FailureOr<scf::SCFFuseConsumerOfSliceResult> fusedResult =
165- mlir::scf::tileAndFuseConsumerOfSlices (rewriter, candidateSliceOp ,
146+ mlir::scf::tileAndFuseConsumerOfSlices (rewriter, candidateSlices ,
166147 loops);
167148 if (failed (fusedResult)) {
168- LLVM_DEBUG (llvm::dbgs () << " failed to fuse consumer of slice: "
169- << candidateSliceOp << " \n " );
170- continue ;
171- }
172-
173- // Implement the WAR for consumer fusion SSA violation (as described in the
174- // comments for `warForConsumerFusionSSAViolation`)
175- if (useWARForConsumerFusionSSAViolation) {
176- for (auto [tiledOpResult, loopResult] :
177- llvm::zip (tiledOp->getResults (), loops.back ()->getResults ())) {
178- for (OpOperand &use : loopResult.getUses ()) {
179- Operation *user = use.getOwner ();
180- if (user->getParentOp () != loops.back ()) {
181- continue ;
182- }
183- auto slice = dyn_cast<tensor::ExtractSliceOp>(user);
184- if (!slice) {
185- return failure ();
186- }
187- rewriter.replaceAllOpUsesWith (slice, tiledOpResult);
188- }
189- }
149+ return candidateSlices.front ()->emitOpError (
150+ " failed to fuse consumer of slice" );
190151 }
191152
192153 // Replace the original consumer operation with the tiled implementation.
@@ -197,8 +158,7 @@ fuseConsumersIntoLoops(RewriterBase &rewriter, Operation *tiledOp,
197158 // values produced by operations that implement the `TilingInterface`.
198159 // Add these operations to the worklist.
199160 addCandidateSlices (
200- fusedResult->tiledAndFusedConsumerOperands .front ()->getOwner (),
201- candidates);
161+ fusedResult->tiledAndFusedConsumerOperands .front ()->getOwner ());
202162
203163 // Add the list of new producer fusion opportunities.
204164 for (auto tiledOp : fusedResult.value ().tiledOps ) {
0 commit comments