@@ -41,7 +41,7 @@ namespace affine {
4141} // namespace affine
4242} // namespace mlir
4343
44- #define DEBUG_TYPE " affine-loop- fusion"
44+ #define DEBUG_TYPE " affine-fusion"
4545
4646using namespace mlir ;
4747using namespace mlir ::affine;
@@ -237,29 +237,67 @@ static void sinkSequentialLoops(MemRefDependenceGraph::Node *node) {
237237 node->op = newRootForOp;
238238}
239239
240- // Creates and returns a private (single-user) memref for fused loop rooted
241- // at 'forOp', with (potentially reduced) memref size based on the
242- // MemRefRegion written to by 'srcStoreOpInst' at depth 'dstLoopDepth'.
243- // TODO: consider refactoring the common code from generateDma and
244- // this one.
245- static Value createPrivateMemRef (AffineForOp forOp, Operation *srcStoreOpInst,
240+ /// Returns the operation to use as the dominance filter when replacing uses
241+ /// of a memref with a private memref, given the `producerStores` to that
242+ /// memref and the `sliceInsertionBlock`. The returned op effectively
243+ /// determines in what part of the IR the replacement should be performed.
244+ static Operation *
245+ getDominanceFilterForPrivateMemRefRepl (Block *sliceInsertionBlock,
246+ ArrayRef<Operation *> producerStores) {
247+ assert (!producerStores.empty () && " expected producer store" );
248+
249+ // We first find the common block that contains the producer stores and
250+ // the slice computation. Among the ancestors of the producer stores in that
251+ // common block, the one that appears first in the block is the dominance
252+ // filter to use for replacement.
253+ Block *commonBlock = nullptr ;
254+ // Seed with the first op of `sliceInsertionBlock`, then fold in each store.
255+ for (Operation *store : producerStores) {
256+ Operation *otherOp =
257+ !commonBlock ? &*sliceInsertionBlock->begin () : &*commonBlock->begin ();
258+ commonBlock = findInnermostCommonBlockInScope (store, otherOp);
259+ }
260+ assert (commonBlock &&
261+ " common block of producer stores and slice should exist" );
262+
263+ // Pick the earliest (per `isBeforeInBlock`) among the ancestors of
264+ // `producerStores` in `commonBlock`.
265+ Operation *firstAncestor = nullptr ;
266+ for (Operation *store : producerStores) {
267+ Operation *ancestor = commonBlock->findAncestorOpInBlock (*store);
268+ assert (ancestor && " producer store should be contained in common block" );
269+ firstAncestor = !firstAncestor || ancestor->isBeforeInBlock (firstAncestor)
270+ ? ancestor
271+ : firstAncestor;
272+ }
273+ return firstAncestor;
274+ }
275+
276+ // Creates and returns a private (single-user) memref for the fused loop
277+ // rooted at `forOp`, with (potentially reduced) memref size based on the
278+ // memref region written to by the first of `storeOps` at depth `dstLoopDepth`.
279+ // `sliceInsertionBlock` is the block in which the slice was/will be inserted.
280+ static Value createPrivateMemRef (AffineForOp forOp,
281+ ArrayRef<Operation *> storeOps,
246282 unsigned dstLoopDepth,
247283 std::optional<unsigned > fastMemorySpace,
284+ Block *sliceInsertionBlock,
248285 uint64_t localBufSizeThreshold) {
249- Operation *forInst = forOp.getOperation ();
286+ assert (!storeOps.empty () && " no source stores supplied" );
287+ Operation *srcStoreOp = storeOps[0 ];
250288
251289 // Create builder to insert alloc op just before 'forOp'.
252- OpBuilder b (forInst );
290+ OpBuilder b (forOp );
253291 // Builder to create constants at the top level.
254- OpBuilder top (forInst ->getParentRegion ());
292+ OpBuilder top (forOp ->getParentRegion ());
255293 // Create new memref type based on slice bounds.
256- auto oldMemRef = cast<AffineWriteOpInterface>(srcStoreOpInst ).getMemRef ();
294+ auto oldMemRef = cast<AffineWriteOpInterface>(srcStoreOp ).getMemRef ();
257295 auto oldMemRefType = cast<MemRefType>(oldMemRef.getType ());
258296 unsigned rank = oldMemRefType.getRank ();
259297
260298 // Compute MemRefRegion for 'srcStoreOp' at depth 'dstLoopDepth'.
261- MemRefRegion region (srcStoreOpInst ->getLoc ());
262- bool validRegion = succeeded (region.compute (srcStoreOpInst , dstLoopDepth));
299+ MemRefRegion region (srcStoreOp ->getLoc ());
300+ bool validRegion = succeeded (region.compute (srcStoreOp , dstLoopDepth));
263301 (void )validRegion;
264302 assert (validRegion && " unexpected memref region failure" );
265303 SmallVector<int64_t , 4 > newShape;
@@ -332,11 +370,12 @@ static Value createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
332370 AffineMap::get (outerIVs.size () + rank, 0 , remapExprs, forOp.getContext ());
333371
334372 // Replace all users of 'oldMemRef' with 'newMemRef'.
335- LogicalResult res =
336- replaceAllMemRefUsesWith (oldMemRef, newMemRef, {}, indexRemap,
337- /* extraOperands=*/ outerIVs,
338- /* symbolOperands=*/ {},
339- /* domOpFilter=*/ &*forOp.getBody ()->begin ());
373+ Operation *domFilter =
374+ getDominanceFilterForPrivateMemRefRepl (sliceInsertionBlock, storeOps);
375+ LogicalResult res = replaceAllMemRefUsesWith (
376+ oldMemRef, newMemRef, /* extraIndices=*/ {}, indexRemap,
377+ /* extraOperands=*/ outerIVs,
378+ /* symbolOperands=*/ {}, domFilter);
340379 assert (succeeded (res) &&
341380 " replaceAllMemrefUsesWith should always succeed here" );
342381 (void )res;
@@ -944,6 +983,10 @@ struct GreedyFusion {
944983
945984 // Create private memrefs.
946985 if (!privateMemrefs.empty ()) {
986+ // Note the block into which fusion was performed. This can be used to
987+ // place `alloc`s that create private memrefs.
988+ Block *sliceInsertionBlock = bestSlice.insertPoint ->getBlock ();
989+
947990 // Gather stores for all the private-to-be memrefs.
948991 DenseMap<Value, SmallVector<Operation *, 4 >> privateMemRefToStores;
949992 dstAffineForOp.walk ([&](AffineWriteOpInterface storeOp) {
@@ -962,8 +1005,8 @@ struct GreedyFusion {
9621005 SmallVector<Operation *, 4 > &storesForMemref =
9631006 memrefToStoresPair.second ;
9641007 Value newMemRef = createPrivateMemRef (
965- dstAffineForOp, storesForMemref[ 0 ] , bestDstLoopDepth,
966- fastMemorySpace, localBufSizeThreshold);
1008+ dstAffineForOp, storesForMemref, bestDstLoopDepth,
1009+ fastMemorySpace, sliceInsertionBlock, localBufSizeThreshold);
9671010 // Create new node in dependence graph for 'newMemRef' alloc op.
9681011 unsigned newMemRefNodeId = mdg->addNode (newMemRef.getDefiningOp ());
9691012 // Add edge from 'newMemRef' node to dstNode.
0 commit comments