diff --git a/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h b/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
index b1fbf4477428c..7164ade6ea53a 100644
--- a/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
+++ b/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
@@ -610,6 +610,14 @@ FailureOr<AffineValueMap>
 simplifyConstrainedMinMaxOp(Operation *op,
                             FlatAffineValueConstraints constraints);
 
+/// Find the innermost common `Block` of `a` and `b` in the affine scope
+/// that `a` and `b` are part of. Return nullptr if they belong to different
+/// affine scopes. Also, return nullptr if they do not have a common `Block`
+/// ancestor (e.g., when they are part of the `then` and `else` regions
+/// of an op that itself starts an affine scope).
+mlir::Block *findInnermostCommonBlockInScope(mlir::Operation *a,
+                                             mlir::Operation *b);
+
 } // namespace affine
 } // namespace mlir
 
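For intuition, here is a small example of the new helper's semantics (mine, not part of the patch). `affine.if` does not start an affine scope, so ops in its `then` and `else` regions still share a `Block` ancestor; it is only when the multi-region op itself starts an affine scope that no common block exists:

```mlir
func.func @innermost_common_block(%m: memref<8xf32>) {
  %cf0 = arith.constant 0.0 : f32
  affine.for %i = 0 to 8 {
    affine.if affine_set<(d0) : (d0 - 4 >= 0)>(%i) {
      affine.store %cf0, %m[%i] : memref<8xf32>  // op `a`
    } else {
      affine.store %cf0, %m[%i] : memref<8xf32>  // op `b`
    }
  }
  return
}
// The innermost common block of `a` and `b` is the body block of %i. If the
// op holding the two regions started its own affine scope, `a` and `b` would
// have no common block ancestor and the helper would return nullptr.
```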
diff --git a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
index 9c0b5dbf52d29..10de0d04cbea6 100644
--- a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
@@ -12,6 +12,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Affine/Analysis/Utils.h"
+
 #include "mlir/Analysis/Presburger/PresburgerRelation.h"
 #include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
 #include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
@@ -2297,3 +2298,41 @@ FailureOr<AffineValueMap> mlir::affine::simplifyConstrainedMinMaxOp(
   affine::canonicalizeMapAndOperands(&newMap, &newOperands);
   return AffineValueMap(newMap, newOperands);
 }
+
+Block *mlir::affine::findInnermostCommonBlockInScope(Operation *a,
+                                                     Operation *b) {
+  Region *aScope = mlir::affine::getAffineScope(a);
+  Region *bScope = mlir::affine::getAffineScope(b);
+  if (aScope != bScope)
+    return nullptr;
+
+  // Get the block ancestry of `op` while stopping at the affine scope `aScope`
+  // and store it in `ancestry`.
+  auto getBlockAncestry = [&](Operation *op,
+                              SmallVectorImpl<Block *> &ancestry) {
+    Operation *curOp = op;
+    do {
+      ancestry.push_back(curOp->getBlock());
+      if (curOp->getParentRegion() == aScope)
+        break;
+      curOp = curOp->getParentOp();
+    } while (curOp);
+    assert(curOp && "can't reach root op without passing through affine scope");
+    std::reverse(ancestry.begin(), ancestry.end());
+  };
+
+  SmallVector<Block *, 4> aAncestors, bAncestors;
+  getBlockAncestry(a, aAncestors);
+  getBlockAncestry(b, bAncestors);
+  assert(!aAncestors.empty() && !bAncestors.empty() &&
+         "at least one Block ancestor expected");
+
+  Block *innermostCommonBlock = nullptr;
+  for (unsigned i = 0, j = 0, e = aAncestors.size(), f = bAncestors.size();
+       i < e && j < f; ++i, ++j) {
+    if (aAncestors[i] != bAncestors[j])
+      break;
+    innermostCommonBlock = aAncestors[i];
+  }
+  return innermostCommonBlock;
+}
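A sketch of how the two ancestry chains compare for ops in sibling loop nests (illustrative only; the function body region is the affine scope here, and the innermost common block is the last block on which the chains agree):

```mlir
func.func @ancestry(%m: memref<8xf32>) {
  %cf0 = arith.constant 0.0 : f32
  affine.for %i = 0 to 8 {
    affine.store %cf0, %m[%i] : memref<8xf32>  // ancestry: [entry block, body of %i]
  }
  affine.for %j = 0 to 8 {
    %v = affine.load %m[%j] : memref<8xf32>    // ancestry: [entry block, body of %j]
  }
  // The chains agree only on their first entry, so the innermost common
  // block of the store and the load is the function's entry block.
  return
}
```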
diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
index c22ec213be95c..fe6cf0f434cb7 100644
--- a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
@@ -41,7 +41,7 @@ namespace affine {
 } // namespace affine
 } // namespace mlir
 
-#define DEBUG_TYPE "affine-loop-fusion"
+#define DEBUG_TYPE "affine-fusion"
 
 using namespace mlir;
 using namespace mlir::affine;
@@ -237,29 +237,67 @@ static void sinkSequentialLoops(MemRefDependenceGraph::Node *node) {
   node->op = newRootForOp;
 }
 
-// Creates and returns a private (single-user) memref for fused loop rooted
-// at 'forOp', with (potentially reduced) memref size based on the
-// MemRefRegion written to by 'srcStoreOpInst' at depth 'dstLoopDepth'.
-// TODO: consider refactoring the common code from generateDma and
-// this one.
-static Value createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
+/// Get the operation that should act as a dominance filter while replacing
+/// memref uses with a private memref for which `producerStores` and
+/// `sliceInsertionBlock` are provided. This effectively determines in what
+/// part of the IR we should be performing the replacement.
+static Operation *
+getDominanceFilterForPrivateMemRefRepl(Block *sliceInsertionBlock,
+                                       ArrayRef<Operation *> producerStores) {
+  assert(!producerStores.empty() && "expected producer store");
+
+  // We first find the common block that contains both the producer stores
+  // and the slice computation. Among the ancestors of the producer stores
+  // in that common block, the one appearing earliest is the dominance
+  // filter to use for the replacement.
+  Block *commonBlock = nullptr;
+  // Find the common block of all relevant operations.
+  for (Operation *store : producerStores) {
+    Operation *otherOp =
+        !commonBlock ? &*sliceInsertionBlock->begin() : &*commonBlock->begin();
+    commonBlock = findInnermostCommonBlockInScope(store, otherOp);
+  }
+  assert(commonBlock &&
+         "common block of producer stores and slice should exist");
+
+  // Among the ancestors of `producerStores` in `commonBlock`, find the one
+  // that appears earliest.
+  Operation *firstAncestor = nullptr;
+  for (Operation *store : producerStores) {
+    Operation *ancestor = commonBlock->findAncestorOpInBlock(*store);
+    assert(ancestor && "producer store should be contained in common block");
+    firstAncestor = !firstAncestor || ancestor->isBeforeInBlock(firstAncestor)
+                        ? ancestor
+                        : firstAncestor;
+  }
+  return firstAncestor;
+}
+
+// Creates and returns a private (single-user) memref for fused loop rooted at
+// 'forOp', with (potentially reduced) memref size based on the memref region
+// written to by `storeOps` at depth 'dstLoopDepth'. 'sliceInsertionBlock'
+// specifies the block in which the slice was/will be inserted.
+static Value createPrivateMemRef(AffineForOp forOp,
+                                 ArrayRef<Operation *> storeOps,
                                  unsigned dstLoopDepth,
                                  std::optional<unsigned> fastMemorySpace,
+                                 Block *sliceInsertionBlock,
                                  uint64_t localBufSizeThreshold) {
-  Operation *forInst = forOp.getOperation();
+  assert(!storeOps.empty() && "no source stores supplied");
+  Operation *srcStoreOp = storeOps[0];
 
   // Create builder to insert alloc op just before 'forOp'.
-  OpBuilder b(forInst);
+  OpBuilder b(forOp);
   // Builder to create constants at the top level.
-  OpBuilder top(forInst->getParentRegion());
+  OpBuilder top(forOp->getParentRegion());
   // Create new memref type based on slice bounds.
-  auto oldMemRef = cast<AffineWriteOpInterface>(srcStoreOpInst).getMemRef();
+  auto oldMemRef = cast<AffineWriteOpInterface>(srcStoreOp).getMemRef();
   auto oldMemRefType = cast<MemRefType>(oldMemRef.getType());
   unsigned rank = oldMemRefType.getRank();
 
   // Compute MemRefRegion for 'srcStoreOpInst' at depth 'dstLoopDepth'.
-  MemRefRegion region(srcStoreOpInst->getLoc());
-  bool validRegion = succeeded(region.compute(srcStoreOpInst, dstLoopDepth));
+  MemRefRegion region(srcStoreOp->getLoc());
+  bool validRegion = succeeded(region.compute(srcStoreOp, dstLoopDepth));
   (void)validRegion;
   assert(validRegion && "unexpected memref region failure");
   SmallVector<int64_t, 4> newShape;
@@ -332,11 +370,12 @@ static Value createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
       AffineMap::get(outerIVs.size() + rank, 0, remapExprs, forOp.getContext());
 
   // Replace all users of 'oldMemRef' with 'newMemRef'.
-  LogicalResult res =
-      replaceAllMemRefUsesWith(oldMemRef, newMemRef, {}, indexRemap,
-                               /*extraOperands=*/outerIVs,
-                               /*symbolOperands=*/{},
-                               /*domOpFilter=*/&*forOp.getBody()->begin());
+  Operation *domFilter =
+      getDominanceFilterForPrivateMemRefRepl(sliceInsertionBlock, storeOps);
+  LogicalResult res = replaceAllMemRefUsesWith(
+      oldMemRef, newMemRef, /*extraIndices=*/{}, indexRemap,
+      /*extraOperands=*/outerIVs,
+      /*symbolOperands=*/{}, domFilter);
   assert(succeeded(res) &&
          "replaceAllMemrefUsesWith should always succeed here");
   (void)res;
@@ -944,6 +983,10 @@ struct GreedyFusion {
 
       // Create private memrefs.
      if (!privateMemrefs.empty()) {
+        // Note the block into which fusion was performed. This can be used to
+        // place `alloc`s that create private memrefs.
+        Block *sliceInsertionBlock = bestSlice.insertPoint->getBlock();
+
        // Gather stores for all the private-to-be memrefs.
        DenseMap<Value, SmallVector<Operation *, 4>> privateMemRefToStores;
        dstAffineForOp.walk([&](AffineWriteOpInterface storeOp) {
@@ -962,8 +1005,8 @@ struct GreedyFusion {
          SmallVector<Operation *, 4> &storesForMemref =
              memrefToStoresPair.second;
          Value newMemRef = createPrivateMemRef(
-              dstAffineForOp, storesForMemref[0], bestDstLoopDepth,
-              fastMemorySpace, localBufSizeThreshold);
+              dstAffineForOp, storesForMemref, bestDstLoopDepth,
+              fastMemorySpace, sliceInsertionBlock, localBufSizeThreshold);
          // Create new node in dependence graph for 'newMemRef' alloc op.
          unsigned newMemRefNodeId = mdg->addNode(newMemRef.getDefiningOp());
          // Add edge from 'newMemRef' node to dstNode.
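To illustrate the new dominance filter (a hand-written sketch, not pass output): the old code always filtered on the first op of the fused loop's body, which no longer suffices once there are multiple producer stores or the slice is inserted in a different block. With the change, only uses of the to-be-privatized memref dominated by the earliest producer-store ancestor in the common block are rewritten:

```mlir
func.func @replacement_scope(%in: memref<32xf32>, %out: memref<16xf32>) {
  %buf = memref.alloc() : memref<1xf32>  // to be privatized
  affine.for %i = 0 to 16 {
    %v = affine.load %in[%i] : memref<32xf32>
    // Producer store: its ancestor in the common block (this loop body) acts
    // as the dominance filter for the replacement.
    affine.store %v, %buf[0] : memref<1xf32>
    %w = affine.load %buf[0] : memref<1xf32>  // dominated, hence rewritten
    affine.store %w, %out[%i] : memref<16xf32>
  }
  return
}
```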
diff --git a/mlir/test/Dialect/Affine/loop-fusion-4.mlir b/mlir/test/Dialect/Affine/loop-fusion-4.mlir
index ea144f73bb21c..2830235431c76 100644
--- a/mlir/test/Dialect/Affine/loop-fusion-4.mlir
+++ b/mlir/test/Dialect/Affine/loop-fusion-4.mlir
@@ -285,3 +285,63 @@ module {
     spirv.ReturnValue %3 : !spirv.array<8192 x f32>
   }
 }
+
+// -----
+
+// PRODUCER-CONSUMER-LABEL: func @same_memref_load_store
+func.func @same_memref_load_store(%producer : memref<32xf32>, %consumer: memref<16xf32>){
+  %cst = arith.constant 2.000000e+00 : f32
+  // Source isn't removed.
+  // PRODUCER-CONSUMER: affine.for %{{.*}} = 0 to 32
+  affine.for %arg3 = 0 to 32 {
+    %0 = affine.load %producer[%arg3] : memref<32xf32>
+    %2 = arith.mulf %0, %cst : f32
+    affine.store %2, %producer[%arg3] : memref<32xf32>
+  }
+  affine.for %arg3 = 0 to 16 {
+    %0 = affine.load %producer[%arg3] : memref<32xf32>
+    %2 = arith.addf %0, %cst : f32
+    affine.store %2, %consumer[%arg3] : memref<16xf32>
+  }
+  // Fused nest.
+  // PRODUCER-CONSUMER: affine.for %{{.*}} = 0 to 16
+  // PRODUCER-CONSUMER-NEXT: affine.load %{{.*}}[%{{.*}}] : memref<32xf32>
+  // PRODUCER-CONSUMER-NEXT: arith.mulf
+  // PRODUCER-CONSUMER-NEXT: affine.store %{{.*}}, %{{.*}}[0] : memref<1xf32>
+  // PRODUCER-CONSUMER-NEXT: affine.load %{{.*}}[0] : memref<1xf32>
+  // PRODUCER-CONSUMER-NEXT: arith.addf
+  // PRODUCER-CONSUMER-NEXT: affine.store
+  // PRODUCER-CONSUMER-NEXT: }
+  return
+}
+
+// PRODUCER-CONSUMER-LABEL: func @same_memref_load_multiple_stores
+func.func @same_memref_load_multiple_stores(%producer : memref<32xf32>, %producer_2 : memref<32xf32>, %consumer: memref<16xf32>){
+  %cst = arith.constant 2.000000e+00 : f32
+  // Source isn't removed.
+  // PRODUCER-CONSUMER: affine.for %{{.*}} = 0 to 32
+  affine.for %arg3 = 0 to 32 {
+    %0 = affine.load %producer[%arg3] : memref<32xf32>
+    %2 = arith.mulf %0, %cst : f32
+    affine.store %2, %producer[%arg3] : memref<32xf32>
+    affine.store %2, %producer_2[%arg3] : memref<32xf32>
+  }
+  affine.for %arg3 = 0 to 16 {
+    %0 = affine.load %producer[%arg3] : memref<32xf32>
+    %1 = affine.load %producer_2[%arg3] : memref<32xf32>
+    %2 = arith.addf %0, %1 : f32
+    affine.store %2, %consumer[%arg3] : memref<16xf32>
+  }
+  // Fused nest.
+  // PRODUCER-CONSUMER: affine.for %{{.*}} = 0 to 16
+  // PRODUCER-CONSUMER-NEXT: affine.load %{{.*}}[%{{.*}}] : memref<32xf32>
+  // PRODUCER-CONSUMER-NEXT: arith.mulf
+  // PRODUCER-CONSUMER-NEXT: affine.store %{{.*}}, %{{.*}}[0] : memref<1xf32>
+  // PRODUCER-CONSUMER-NEXT: affine.store %{{.*}}, %{{.*}}[0] : memref<1xf32>
+  // PRODUCER-CONSUMER-NEXT: affine.load %{{.*}}[0] : memref<1xf32>
+  // PRODUCER-CONSUMER-NEXT: affine.load %{{.*}}[0] : memref<1xf32>
+  // PRODUCER-CONSUMER-NEXT: arith.addf
+  // PRODUCER-CONSUMER-NEXT: affine.store
+  // PRODUCER-CONSUMER-NEXT: }
+  return
+}
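For reference, the fused nest that the second test's CHECK lines describe looks roughly as follows (reconstructed by hand, not tool output; `%priv` and `%priv_2` stand for the two compiler-generated private memrefs):

```mlir
affine.for %arg3 = 0 to 16 {
  %0 = affine.load %producer[%arg3] : memref<32xf32>
  %1 = arith.mulf %0, %cst : f32
  affine.store %1, %priv[0] : memref<1xf32>    // privatized %producer
  affine.store %1, %priv_2[0] : memref<1xf32>  // privatized %producer_2
  %2 = affine.load %priv[0] : memref<1xf32>
  %3 = affine.load %priv_2[0] : memref<1xf32>
  %4 = arith.addf %2, %3 : f32
  affine.store %4, %consumer[%arg3] : memref<16xf32>
}
```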