diff --git a/mlir/include/mlir/Dialect/Affine/Utils.h b/mlir/include/mlir/Dialect/Affine/Utils.h index ae5a68a6be157..ac11f5a7c24c7 100644 --- a/mlir/include/mlir/Dialect/Affine/Utils.h +++ b/mlir/include/mlir/Dialect/Affine/Utils.h @@ -198,10 +198,9 @@ AffineExpr substWithMin(AffineExpr e, AffineExpr dim, AffineExpr min, /// of its input list. `indexRemap`'s dimensional inputs are expected to /// correspond to memref's indices, and its symbolic inputs if any should be /// provided in `symbolOperands`. -/// -/// `domOpFilter`, if non-null, restricts the replacement to only those -/// operations that are dominated by the former; similarly, `postDomOpFilter` -/// restricts replacement to only those operations that are postdominated by it. +/// +/// If `userFilterFn` is specified, restrict replacement to only those users +/// that pass the specified filter (i.e., the filter returns true). /// /// 'allowNonDereferencingOps', if set, allows replacement of non-dereferencing /// uses of a memref without any requirement for access index rewrites as long @@ -224,13 +223,14 @@ AffineExpr substWithMin(AffineExpr e, AffineExpr dim, AffineExpr min, // d1, d2) -> (d0 - d1, d2), and %ii will be the extra operand. Without any // extra operands, note that 'indexRemap' would just be applied to existing // indices (%i, %j). +// // TODO: allow extraIndices to be added at any position. 
LogicalResult replaceAllMemRefUsesWith( Value oldMemRef, Value newMemRef, ArrayRef extraIndices = {}, AffineMap indexRemap = AffineMap(), ArrayRef extraOperands = {}, - ArrayRef symbolOperands = {}, Operation *domOpFilter = nullptr, - Operation *postDomOpFilter = nullptr, bool allowNonDereferencingOps = false, - bool replaceInDeallocOp = false); + ArrayRef symbolOperands = {}, + llvm::function_ref userFilterFn = nullptr, + bool allowNonDereferencingOps = false, bool replaceInDeallocOp = false); /// Performs the same replacement as the other version above but only for the /// dereferencing uses of `oldMemRef` in `op`, except in cases where diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp index 4b4eb9ce37b4c..da05dec6e4af3 100644 --- a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp @@ -445,10 +445,15 @@ static Value createPrivateMemRef(AffineForOp forOp, // Replace all users of 'oldMemRef' with 'newMemRef'. 
Operation *domFilter = getDominanceFilterForPrivateMemRefRepl(sliceInsertionBlock, storeOps); + auto userFilterFn = [&](Operation *user) { + auto domInfo = std::make_unique( + domFilter->getParentOfType()); + return domInfo->dominates(domFilter, user); + }; LogicalResult res = replaceAllMemRefUsesWith( oldMemRef, newMemRef, /*extraIndices=*/{}, indexRemap, /*extraOperands=*/outerIVs, - /*symbolOperands=*/{}, domFilter); + /*symbolOperands=*/{}, userFilterFn); assert(succeeded(res) && "replaceAllMemrefUsesWith should always succeed here"); (void)res; diff --git a/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp index 4be99aa197380..92cb7075005a3 100644 --- a/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp @@ -115,13 +115,16 @@ static bool doubleBuffer(Value oldMemRef, AffineForOp forOp) { // replaceAllMemRefUsesWith will succeed unless the forOp body has // non-dereferencing uses of the memref (dealloc's are fine though). 
- if (failed(replaceAllMemRefUsesWith( - oldMemRef, newMemRef, - /*extraIndices=*/{ivModTwoOp}, - /*indexRemap=*/AffineMap(), - /*extraOperands=*/{}, - /*symbolOperands=*/{}, - /*domOpFilter=*/&*forOp.getBody()->begin()))) { + auto userFilterFn = [&](Operation *user) { + auto domInfo = std::make_unique( + forOp->getParentOfType()); + return domInfo->dominates(&*forOp.getBody()->begin(), user); + }; + if (failed(replaceAllMemRefUsesWith(oldMemRef, newMemRef, + /*extraIndices=*/{ivModTwoOp}, + /*indexRemap=*/AffineMap(), + /*extraOperands=*/{}, + /*symbolOperands=*/{}, userFilterFn))) { LLVM_DEBUG( forOp.emitError("memref replacement for double buffering failed")); ivModTwoOp.erase(); diff --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp index 4aa1fe318efa8..0501616ad912c 100644 --- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp +++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp @@ -1967,6 +1967,12 @@ static LogicalResult generateCopy( if (begin == end) return success(); + // Record the last op in the block for which we are performing copy + // generation. We later do the memref replacement only in [begin, lastCopyOp] + // so that the original memref's used in the data movement code themselves + // don't get replaced. + Operation *lastCopyOp = end->getPrevNode(); + // Is the copy out point at the end of the block where we are doing // explicit copying. bool isCopyOutAtEndOfBlock = (end == copyOutPlacementStart); @@ -2143,12 +2149,6 @@ static LogicalResult generateCopy( } } - // Record the last operation where we want the memref replacement to end. We - // later do the memref replacement only in [begin, postDomFilter] so - // that the original memref's used in the data movement code themselves don't - // get replaced. - auto postDomFilter = std::prev(end); - // Create fully composed affine maps for each memref. 
auto memAffineMap = b.getMultiDimIdentityMap(memIndices.size()); fullyComposeAffineMapAndOperands(&memAffineMap, &memIndices); @@ -2244,13 +2244,17 @@ static LogicalResult generateCopy( if (!isBeginAtStartOfBlock) prevOfBegin = std::prev(begin); + auto userFilterFn = [&](Operation *user) { + auto *ancestorUser = block->findAncestorOpInBlock(*user); + return ancestorUser && !ancestorUser->isBeforeInBlock(&*begin) && + !lastCopyOp->isBeforeInBlock(ancestorUser); + }; + // *Only* those uses within the range [begin, end) of 'block' are replaced. (void)replaceAllMemRefUsesWith(memref, fastMemRef, /*extraIndices=*/{}, indexRemap, /*extraOperands=*/regionSymbols, - /*symbolOperands=*/{}, - /*domOpFilter=*/&*begin, - /*postDomOpFilter=*/&*postDomFilter); + /*symbolOperands=*/{}, userFilterFn); *nBegin = isBeginAtStartOfBlock ? block->begin() : std::next(prevOfBegin); diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp index cde8223107859..66b3f2a4f93a5 100644 --- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp @@ -1305,9 +1305,9 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith( LogicalResult mlir::affine::replaceAllMemRefUsesWith( Value oldMemRef, Value newMemRef, ArrayRef extraIndices, AffineMap indexRemap, ArrayRef extraOperands, - ArrayRef symbolOperands, Operation *domOpFilter, - Operation *postDomOpFilter, bool allowNonDereferencingOps, - bool replaceInDeallocOp) { + ArrayRef symbolOperands, + llvm::function_ref userFilterFn, + bool allowNonDereferencingOps, bool replaceInDeallocOp) { unsigned newMemRefRank = cast(newMemRef.getType()).getRank(); (void)newMemRefRank; // unused in opt mode unsigned oldMemRefRank = cast(oldMemRef.getType()).getRank(); @@ -1328,61 +1328,52 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith( std::unique_ptr domInfo; std::unique_ptr postDomInfo; - if (domOpFilter) - domInfo = std::make_unique( - domOpFilter->getParentOfType()); - - if 
(postDomOpFilter) - postDomInfo = std::make_unique( - postDomOpFilter->getParentOfType()); // Walk all uses of old memref; collect ops to perform replacement. We use a // DenseSet since an operation could potentially have multiple uses of a // memref (although rare), and the replacement later is going to erase ops. DenseSet opsToReplace; - for (auto *op : oldMemRef.getUsers()) { - // Skip this use if it's not dominated by domOpFilter. - if (domOpFilter && !domInfo->dominates(domOpFilter, op)) - continue; - - // Skip this use if it's not post-dominated by postDomOpFilter. - if (postDomOpFilter && !postDomInfo->postDominates(postDomOpFilter, op)) + for (auto *user : oldMemRef.getUsers()) { + // Check if this user doesn't pass the filter. + if (userFilterFn && !userFilterFn(user)) continue; // Skip dealloc's - no replacement is necessary, and a memref replacement // at other uses doesn't hurt these dealloc's. - if (hasSingleEffect(op, oldMemRef) && + if (hasSingleEffect(user, oldMemRef) && !replaceInDeallocOp) continue; // Check if the memref was used in a non-dereferencing context. It is fine // for the memref to be used in a non-dereferencing way outside of the // region where this replacement is happening. - if (!isa(*op)) { + if (!isa(*user)) { if (!allowNonDereferencingOps) { - LLVM_DEBUG(llvm::dbgs() - << "Memref replacement failed: non-deferencing memref op: \n" - << *op << '\n'); + LLVM_DEBUG( + llvm::dbgs() + << "Memref replacement failed: non-deferencing memref user: \n" + << *user << '\n'); return failure(); } // Non-dereferencing ops with the MemRefsNormalizable trait are // supported for replacement. 
- if (!op->hasTrait()) { + if (!user->hasTrait()) { LLVM_DEBUG(llvm::dbgs() << "Memref replacement failed: use without a " "memrefs normalizable trait: \n" - << *op << '\n'); + << *user << '\n'); return failure(); } } - // We'll first collect and then replace --- since replacement erases the op - // that has the use, and that op could be postDomFilter or domFilter itself! - opsToReplace.insert(op); + // We'll first collect and then replace --- since replacement erases the + // user that has the use, and we are still iterating over the users of + // 'oldMemRef' at this point. + opsToReplace.insert(user); } - for (auto *op : opsToReplace) { + for (auto *user : opsToReplace) { if (failed(replaceAllMemRefUsesWith( - oldMemRef, newMemRef, op, extraIndices, indexRemap, extraOperands, + oldMemRef, newMemRef, user, extraIndices, indexRemap, extraOperands, symbolOperands, allowNonDereferencingOps))) llvm_unreachable("memref replacement guaranteed to succeed here"); } @@ -1763,8 +1754,7 @@ LogicalResult mlir::affine::normalizeMemRef(AllocLikeOp allocOp) { /*indexRemap=*/layoutMap, /*extraOperands=*/{}, /*symbolOperands=*/symbolOperands, - /*domOpFilter=*/nullptr, - /*postDomOpFilter=*/nullptr, + /*userFilterFn=*/nullptr, /*allowNonDereferencingOps=*/true))) { // If it failed (due to escapes for example), bail out. newAlloc.erase(); @@ -1854,8 +1844,7 @@ mlir::affine::normalizeMemRef(memref::ReinterpretCastOp reinterpretCastOp) { /*indexRemap=*/oldLayoutMap, /*extraOperands=*/{}, /*symbolOperands=*/oldStrides, - /*domOpFilter=*/nullptr, - /*postDomOpFilter=*/nullptr, + /*userFilterFn=*/nullptr, /*allowNonDereferencingOps=*/true))) { // If it failed (due to escapes for example), bail out. 
newReinterpretCast.erase(); diff --git a/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp index b408962690810..d6fcb8d9f0501 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp @@ -297,8 +297,7 @@ void NormalizeMemRefs::updateFunctionSignature(func::FuncOp funcOp, /*indexRemap=*/layoutMap, /*extraOperands=*/{}, /*symbolOperands=*/{}, - /*domOpFilter=*/nullptr, - /*postDomOpFilter=*/nullptr, + /*userFilterFn=*/nullptr, /*allowNonDereferencingOps=*/true, /*replaceInDeallocOp=*/true))) { // If it failed (due to escapes for example), bail out. @@ -407,8 +406,7 @@ void NormalizeMemRefs::normalizeFuncOpMemRefs(func::FuncOp funcOp, /*indexRemap=*/layoutMap, /*extraOperands=*/{}, /*symbolOperands=*/{}, - /*domOpFilter=*/nullptr, - /*postDomOpFilter=*/nullptr, + /*userFilterFn=*/nullptr, /*allowNonDereferencingOps=*/true, /*replaceInDeallocOp=*/true))) { // If it failed (due to escapes for example), bail out. Removing the @@ -457,8 +455,7 @@ void NormalizeMemRefs::normalizeFuncOpMemRefs(func::FuncOp funcOp, /*indexRemap=*/layoutMap, /*extraOperands=*/{}, /*symbolOperands=*/{}, - /*domOpFilter=*/nullptr, - /*postDomOpFilter=*/nullptr, + /*userFilterFn=*/nullptr, /*allowNonDereferencingOps=*/true, /*replaceInDeallocOp=*/true))) { newOp->erase(); diff --git a/mlir/test/Dialect/Affine/affine-data-copy.mlir b/mlir/test/Dialect/Affine/affine-data-copy.mlir index a1f0d952e7c63..a745271eb9ca8 100644 --- a/mlir/test/Dialect/Affine/affine-data-copy.mlir +++ b/mlir/test/Dialect/Affine/affine-data-copy.mlir @@ -447,3 +447,51 @@ func.func @memref_def_inside(%arg0: index) { // LIMITED-MEM-NEXT: memref.dealloc %{{.*}} : memref<1xf32> return } + +// Test with uses across multiple blocks. 
+ +memref.global "private" constant @__constant_1x2x1xi32_1 : memref<1x2x1xi32> = dense<0> {alignment = 64 : i64} + +// CHECK-LABEL: func @multiple_blocks +func.func @multiple_blocks(%arg0: index) -> memref<1x2x1xi32> { + %c1_i32 = arith.constant 1 : i32 + %c3_i32 = arith.constant 3 : i32 + %0 = memref.get_global @__constant_1x2x1xi32_1 : memref<1x2x1xi32> + %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x2x1xi32> + memref.copy %0, %alloc : memref<1x2x1xi32> to memref<1x2x1xi32> + cf.br ^bb1(%alloc : memref<1x2x1xi32>) +^bb1(%1: memref<1x2x1xi32>): // 2 preds: ^bb0, ^bb2 +// CHECK: ^bb1(%[[MEM:.*]]: memref<1x2x1xi32>): + %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<1x2x1xi1> + // CHECK: %[[BUF:.*]] = memref.alloc() : memref<1x2x1xi32> + affine.for %arg1 = 0 to 1 { + affine.for %arg2 = 0 to 2 { + affine.for %arg3 = 0 to 1 { + // CHECK: affine.load %[[BUF]] + %3 = affine.load %1[%arg1, %arg2, %arg3] : memref<1x2x1xi32> + %4 = arith.cmpi slt, %3, %c3_i32 : i32 + affine.store %4, %alloc_0[%arg1, %arg2, %arg3] : memref<1x2x1xi1> + } + } + } + // CHECK: memref.dealloc %[[BUF]] + %2 = memref.load %alloc_0[%arg0, %arg0, %arg0] : memref<1x2x1xi1> + cf.cond_br %2, ^bb2, ^bb3 +^bb2: // pred: ^bb1 +// CHECK: ^bb2 + %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x2x1xi32> + affine.for %arg1 = 0 to 1 { + affine.for %arg2 = 0 to 2 { + affine.for %arg3 = 0 to 1 { + // Ensure that this reference isn't replaced. + %3 = affine.load %1[%arg1, %arg2, %arg3] : memref<1x2x1xi32> + // CHECK: affine.load %[[MEM]] + %4 = arith.addi %3, %c1_i32 : i32 + affine.store %4, %alloc_1[%arg1, %arg2, %arg3] : memref<1x2x1xi32> + } + } + } + cf.br ^bb1(%alloc_1 : memref<1x2x1xi32>) +^bb3: // pred: ^bb1 + return %1 : memref<1x2x1xi32> +}