@@ -41,7 +41,7 @@ namespace affine {
4141} // namespace affine
4242} // namespace mlir
4343
44- #define DEBUG_TYPE " affine-loop- fusion"
44+ #define DEBUG_TYPE " affine-fusion"
4545
4646using namespace mlir ;
4747using namespace mlir ::affine;
@@ -237,29 +237,71 @@ static void sinkSequentialLoops(MemRefDependenceGraph::Node *node) {
237237 node->op = newRootForOp;
238238}
239239
240+ // / Get the operation that should act as a dominance filter while replacing
241+ // / memref uses with a private memref for which `producerStores` and
242+ // / `sliceInsertionBlock` are provided. This effectively determines in what
243+ // / part of the IR we should be performing the replacement.
244+ static Operation *
245+ getDominanceFilterForPrivateMemRefRepl (Block *sliceInsertionBlock,
246+ ArrayRef<Operation *> producerStores) {
247+ assert (!producerStores.empty () && " expected producer store" );
248+
249+ // We first find the common block that contains the producer stores and
250+ // the slice computation. The first ancestor among the ancestors of the
251+ // producer stores in that common block is the dominance filter to use for
252+ // replacement.
253+ Block *commonBlock = nullptr ;
254+ // Find the common block of all relevant operations.
255+ for (Operation *store : producerStores) {
256+ if (!commonBlock)
257+ commonBlock = findInnermostCommonBlockInScope (
258+ store, &*sliceInsertionBlock->begin ());
259+ else
260+ commonBlock =
261+ findInnermostCommonBlockInScope (store, &*commonBlock->begin ());
262+ }
263+ assert (commonBlock &&
264+ " common block of producer stores and slice should exist" );
265+
266+ // Find the first ancestor among the ancestors of `producerStores` in
267+ // `commonBlock`.
268+ Operation *firstAncestor = nullptr ;
269+ for (Operation *store : producerStores) {
270+ Operation *ancestor = commonBlock->findAncestorOpInBlock (*store);
271+ assert (ancestor && " producer store should be contained in common block" );
272+ firstAncestor = !firstAncestor || ancestor->isBeforeInBlock (firstAncestor)
273+ ? ancestor
274+ : firstAncestor;
275+ }
276+ return firstAncestor;
277+ }
278+
240279// Creates and returns a private (single-user) memref for fused loop rooted
241280// at 'forOp', with (potentially reduced) memref size based on the
242281// MemRefRegion written to by 'srcStoreOpInst' at depth 'dstLoopDepth'.
243282// TODO: consider refactoring the common code from generateDma and
244283// this one.
245- static Value createPrivateMemRef (AffineForOp forOp, Operation *srcStoreOpInst,
284+ static Value createPrivateMemRef (AffineForOp forOp,
285+ ArrayRef<Operation *> storeOps,
246286 unsigned dstLoopDepth,
247287 std::optional<unsigned > fastMemorySpace,
288+ Block *sliceInsertionBlock,
248289 uint64_t localBufSizeThreshold) {
249- Operation *forInst = forOp.getOperation ();
290+ assert (!storeOps.empty () && " no source stores supplied" );
291+ Operation *srcStoreOp = storeOps[0 ];
250292
251293 // Create builder to insert alloc op just before 'forOp'.
252- OpBuilder b (forInst );
294+ OpBuilder b (forOp );
253295 // Builder to create constants at the top level.
254- OpBuilder top (forInst ->getParentRegion ());
296+ OpBuilder top (forOp ->getParentRegion ());
255297 // Create new memref type based on slice bounds.
256- auto oldMemRef = cast<AffineWriteOpInterface>(srcStoreOpInst ).getMemRef ();
298+ auto oldMemRef = cast<AffineWriteOpInterface>(srcStoreOp ).getMemRef ();
257299 auto oldMemRefType = cast<MemRefType>(oldMemRef.getType ());
258300 unsigned rank = oldMemRefType.getRank ();
259301
260302 // Compute MemRefRegion for 'srcStoreOpInst' at depth 'dstLoopDepth'.
261- MemRefRegion region (srcStoreOpInst ->getLoc ());
262- bool validRegion = succeeded (region.compute (srcStoreOpInst , dstLoopDepth));
303+ MemRefRegion region (srcStoreOp ->getLoc ());
304+ bool validRegion = succeeded (region.compute (srcStoreOp , dstLoopDepth));
263305 (void )validRegion;
264306 assert (validRegion && " unexpected memref region failure" );
265307 SmallVector<int64_t , 4 > newShape;
@@ -332,11 +374,12 @@ static Value createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
332374 AffineMap::get (outerIVs.size () + rank, 0 , remapExprs, forOp.getContext ());
333375
334376 // Replace all users of 'oldMemRef' with 'newMemRef'.
377+ Operation *domFilter =
378+ getDominanceFilterForPrivateMemRefRepl (sliceInsertionBlock, storeOps);
335379 LogicalResult res =
336380 replaceAllMemRefUsesWith (oldMemRef, newMemRef, {}, indexRemap,
337381 /* extraOperands=*/ outerIVs,
338- /* symbolOperands=*/ {},
339- /* domOpFilter=*/ &*forOp.getBody ()->begin ());
382+ /* symbolOperands=*/ {}, domFilter);
340383 assert (succeeded (res) &&
341384 " replaceAllMemrefUsesWith should always succeed here" );
342385 (void )res;
@@ -944,6 +987,10 @@ struct GreedyFusion {
944987
945988 // Create private memrefs.
946989 if (!privateMemrefs.empty ()) {
990+ // Note the block into which fusion was performed. This can be used to
991+ // place `alloc`s that create private memrefs.
992+ Block *sliceInsertionBlock = bestSlice.insertPoint ->getBlock ();
993+
947994 // Gather stores for all the private-to-be memrefs.
948995 DenseMap<Value, SmallVector<Operation *, 4 >> privateMemRefToStores;
949996 dstAffineForOp.walk ([&](AffineWriteOpInterface storeOp) {
@@ -962,8 +1009,8 @@ struct GreedyFusion {
9621009 SmallVector<Operation *, 4 > &storesForMemref =
9631010 memrefToStoresPair.second ;
9641011 Value newMemRef = createPrivateMemRef (
965- dstAffineForOp, storesForMemref[ 0 ] , bestDstLoopDepth,
966- fastMemorySpace, localBufSizeThreshold);
1012+ dstAffineForOp, storesForMemref, bestDstLoopDepth,
1013+ fastMemorySpace, sliceInsertionBlock, localBufSizeThreshold);
9671014 // Create new node in dependence graph for 'newMemRef' alloc op.
9681015 unsigned newMemRefNodeId = mdg->addNode (newMemRef.getDefiningOp ());
9691016 // Add edge from 'newMemRef' node to dstNode.
0 commit comments