Skip to content

Commit 80625c1

Browse files
authored
[MLIR][Affine] Fix memref replacement in affine-data-copy-generate (llvm#139016)
Fixes llvm#130257. Fix affine-data-copy-generate in certain cases that involved users in multiple blocks, and perform the memref replacement correctly during copy generation. Improve/clean up the memref affine use replacement API: instead of supporting dominance and post-dominance filters (which aren't adequate in most cases) and expensively computing dominance info on every call to replaceAllMemRefUsesWith (RAMUW), provide a user filter callback, i.e., require callers to compute dominance themselves if needed.
1 parent 28d4cc6 commit 80625c1

File tree

7 files changed

+109
-63
lines changed

7 files changed

+109
-63
lines changed

mlir/include/mlir/Dialect/Affine/Utils.h

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -198,10 +198,9 @@ AffineExpr substWithMin(AffineExpr e, AffineExpr dim, AffineExpr min,
198198
/// of its input list. `indexRemap`'s dimensional inputs are expected to
199199
/// correspond to memref's indices, and its symbolic inputs if any should be
200200
/// provided in `symbolOperands`.
201-
///
202-
/// `domOpFilter`, if non-null, restricts the replacement to only those
203-
/// operations that are dominated by the former; similarly, `postDomOpFilter`
204-
/// restricts replacement to only those operations that are postdominated by it.
201+
///
202+
/// If `userFilterFn` is specified, restrict replacement to only those users
203+
/// that pass the specified filter (i.e., the filter returns true).
205204
///
206205
/// 'allowNonDereferencingOps', if set, allows replacement of non-dereferencing
207206
/// uses of a memref without any requirement for access index rewrites as long
@@ -224,13 +223,14 @@ AffineExpr substWithMin(AffineExpr e, AffineExpr dim, AffineExpr min,
224223
// d1, d2) -> (d0 - d1, d2), and %ii will be the extra operand. Without any
225224
// extra operands, note that 'indexRemap' would just be applied to existing
226225
// indices (%i, %j).
226+
//
227227
// TODO: allow extraIndices to be added at any position.
228228
LogicalResult replaceAllMemRefUsesWith(
229229
Value oldMemRef, Value newMemRef, ArrayRef<Value> extraIndices = {},
230230
AffineMap indexRemap = AffineMap(), ArrayRef<Value> extraOperands = {},
231-
ArrayRef<Value> symbolOperands = {}, Operation *domOpFilter = nullptr,
232-
Operation *postDomOpFilter = nullptr, bool allowNonDereferencingOps = false,
233-
bool replaceInDeallocOp = false);
231+
ArrayRef<Value> symbolOperands = {},
232+
llvm::function_ref<bool(Operation *)> userFilterFn = nullptr,
233+
bool allowNonDereferencingOps = false, bool replaceInDeallocOp = false);
234234

235235
/// Performs the same replacement as the other version above but only for the
236236
/// dereferencing uses of `oldMemRef` in `op`, except in cases where

mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -445,10 +445,15 @@ static Value createPrivateMemRef(AffineForOp forOp,
445445
// Replace all users of 'oldMemRef' with 'newMemRef'.
446446
Operation *domFilter =
447447
getDominanceFilterForPrivateMemRefRepl(sliceInsertionBlock, storeOps);
448+
auto userFilterFn = [&](Operation *user) {
449+
auto domInfo = std::make_unique<DominanceInfo>(
450+
domFilter->getParentOfType<FunctionOpInterface>());
451+
return domInfo->dominates(domFilter, user);
452+
};
448453
LogicalResult res = replaceAllMemRefUsesWith(
449454
oldMemRef, newMemRef, /*extraIndices=*/{}, indexRemap,
450455
/*extraOperands=*/outerIVs,
451-
/*symbolOperands=*/{}, domFilter);
456+
/*symbolOperands=*/{}, userFilterFn);
452457
assert(succeeded(res) &&
453458
"replaceAllMemrefUsesWith should always succeed here");
454459
(void)res;

mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -115,13 +115,16 @@ static bool doubleBuffer(Value oldMemRef, AffineForOp forOp) {
115115

116116
// replaceAllMemRefUsesWith will succeed unless the forOp body has
117117
// non-dereferencing uses of the memref (dealloc's are fine though).
118-
if (failed(replaceAllMemRefUsesWith(
119-
oldMemRef, newMemRef,
120-
/*extraIndices=*/{ivModTwoOp},
121-
/*indexRemap=*/AffineMap(),
122-
/*extraOperands=*/{},
123-
/*symbolOperands=*/{},
124-
/*domOpFilter=*/&*forOp.getBody()->begin()))) {
118+
auto userFilterFn = [&](Operation *user) {
119+
auto domInfo = std::make_unique<DominanceInfo>(
120+
forOp->getParentOfType<FunctionOpInterface>());
121+
return domInfo->dominates(&*forOp.getBody()->begin(), user);
122+
};
123+
if (failed(replaceAllMemRefUsesWith(oldMemRef, newMemRef,
124+
/*extraIndices=*/{ivModTwoOp},
125+
/*indexRemap=*/AffineMap(),
126+
/*extraOperands=*/{},
127+
/*symbolOperands=*/{}, userFilterFn))) {
125128
LLVM_DEBUG(
126129
forOp.emitError("memref replacement for double buffering failed"));
127130
ivModTwoOp.erase();

mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1967,6 +1967,12 @@ static LogicalResult generateCopy(
19671967
if (begin == end)
19681968
return success();
19691969

1970+
// Record the last op in the block for which we are performing copy
1971+
// generation. We later do the memref replacement only in [begin, lastCopyOp]
1972+
// so that uses of the original memref in the data movement code themselves
1973+
// don't get replaced.
1974+
Operation *lastCopyOp = end->getPrevNode();
1975+
19701976
// Is the copy out point at the end of the block where we are doing
19711977
// explicit copying.
19721978
bool isCopyOutAtEndOfBlock = (end == copyOutPlacementStart);
@@ -2143,12 +2149,6 @@ static LogicalResult generateCopy(
21432149
}
21442150
}
21452151

2146-
// Record the last operation where we want the memref replacement to end. We
2147-
// later do the memref replacement only in [begin, postDomFilter] so
2148-
// that the original memref's used in the data movement code themselves don't
2149-
// get replaced.
2150-
auto postDomFilter = std::prev(end);
2151-
21522152
// Create fully composed affine maps for each memref.
21532153
auto memAffineMap = b.getMultiDimIdentityMap(memIndices.size());
21542154
fullyComposeAffineMapAndOperands(&memAffineMap, &memIndices);
@@ -2244,13 +2244,17 @@ static LogicalResult generateCopy(
22442244
if (!isBeginAtStartOfBlock)
22452245
prevOfBegin = std::prev(begin);
22462246

2247+
auto userFilterFn = [&](Operation *user) {
2248+
auto *ancestorUser = block->findAncestorOpInBlock(*user);
2249+
return ancestorUser && !ancestorUser->isBeforeInBlock(&*begin) &&
2250+
!lastCopyOp->isBeforeInBlock(ancestorUser);
2251+
};
2252+
22472253
// *Only* those uses within the range [begin, end) of 'block' are replaced.
22482254
(void)replaceAllMemRefUsesWith(memref, fastMemRef,
22492255
/*extraIndices=*/{}, indexRemap,
22502256
/*extraOperands=*/regionSymbols,
2251-
/*symbolOperands=*/{},
2252-
/*domOpFilter=*/&*begin,
2253-
/*postDomOpFilter=*/&*postDomFilter);
2257+
/*symbolOperands=*/{}, userFilterFn);
22542258

22552259
*nBegin = isBeginAtStartOfBlock ? block->begin() : std::next(prevOfBegin);
22562260

mlir/lib/Dialect/Affine/Utils/Utils.cpp

Lines changed: 22 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1305,9 +1305,9 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
13051305
LogicalResult mlir::affine::replaceAllMemRefUsesWith(
13061306
Value oldMemRef, Value newMemRef, ArrayRef<Value> extraIndices,
13071307
AffineMap indexRemap, ArrayRef<Value> extraOperands,
1308-
ArrayRef<Value> symbolOperands, Operation *domOpFilter,
1309-
Operation *postDomOpFilter, bool allowNonDereferencingOps,
1310-
bool replaceInDeallocOp) {
1308+
ArrayRef<Value> symbolOperands,
1309+
llvm::function_ref<bool(Operation *)> userFilterFn,
1310+
bool allowNonDereferencingOps, bool replaceInDeallocOp) {
13111311
unsigned newMemRefRank = cast<MemRefType>(newMemRef.getType()).getRank();
13121312
(void)newMemRefRank; // unused in opt mode
13131313
unsigned oldMemRefRank = cast<MemRefType>(oldMemRef.getType()).getRank();
@@ -1328,61 +1328,52 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith(
13281328

13291329
std::unique_ptr<DominanceInfo> domInfo;
13301330
std::unique_ptr<PostDominanceInfo> postDomInfo;
1331-
if (domOpFilter)
1332-
domInfo = std::make_unique<DominanceInfo>(
1333-
domOpFilter->getParentOfType<FunctionOpInterface>());
1334-
1335-
if (postDomOpFilter)
1336-
postDomInfo = std::make_unique<PostDominanceInfo>(
1337-
postDomOpFilter->getParentOfType<FunctionOpInterface>());
13381331

13391332
// Walk all uses of old memref; collect ops to perform replacement. We use a
13401333
// DenseSet since an operation could potentially have multiple uses of a
13411334
// memref (although rare), and the replacement later is going to erase ops.
13421335
DenseSet<Operation *> opsToReplace;
1343-
for (auto *op : oldMemRef.getUsers()) {
1344-
// Skip this use if it's not dominated by domOpFilter.
1345-
if (domOpFilter && !domInfo->dominates(domOpFilter, op))
1346-
continue;
1347-
1348-
// Skip this use if it's not post-dominated by postDomOpFilter.
1349-
if (postDomOpFilter && !postDomInfo->postDominates(postDomOpFilter, op))
1336+
for (auto *user : oldMemRef.getUsers()) {
1337+
// Skip this user if it doesn't pass the filter.
1338+
if (userFilterFn && !userFilterFn(user))
13501339
continue;
13511340

13521341
// Skip dealloc's - no replacement is necessary, and a memref replacement
13531342
// at other uses doesn't hurt these dealloc's.
1354-
if (hasSingleEffect<MemoryEffects::Free>(op, oldMemRef) &&
1343+
if (hasSingleEffect<MemoryEffects::Free>(user, oldMemRef) &&
13551344
!replaceInDeallocOp)
13561345
continue;
13571346

13581347
// Check if the memref was used in a non-dereferencing context. It is fine
13591348
// for the memref to be used in a non-dereferencing way outside of the
13601349
// region where this replacement is happening.
1361-
if (!isa<AffineMapAccessInterface>(*op)) {
1350+
if (!isa<AffineMapAccessInterface>(*user)) {
13621351
if (!allowNonDereferencingOps) {
1363-
LLVM_DEBUG(llvm::dbgs()
1364-
<< "Memref replacement failed: non-deferencing memref op: \n"
1365-
<< *op << '\n');
1352+
LLVM_DEBUG(
1353+
llvm::dbgs()
1354+
<< "Memref replacement failed: non-deferencing memref user: \n"
1355+
<< *user << '\n');
13661356
return failure();
13671357
}
13681358
// Non-dereferencing ops with the MemRefsNormalizable trait are
13691359
// supported for replacement.
1370-
if (!op->hasTrait<OpTrait::MemRefsNormalizable>()) {
1360+
if (!user->hasTrait<OpTrait::MemRefsNormalizable>()) {
13711361
LLVM_DEBUG(llvm::dbgs() << "Memref replacement failed: use without a "
13721362
"memrefs normalizable trait: \n"
1373-
<< *op << '\n');
1363+
<< *user << '\n');
13741364
return failure();
13751365
}
13761366
}
13771367

1378-
// We'll first collect and then replace --- since replacement erases the op
1379-
// that has the use, and that op could be postDomFilter or domFilter itself!
1380-
opsToReplace.insert(op);
1368+
// We'll first collect and then replace --- since replacement erases the
1369+
// user that has the use, and that user could be an op that `userFilterFn`
1370+
// itself refers to!
1371+
opsToReplace.insert(user);
13811372
}
13821373

1383-
for (auto *op : opsToReplace) {
1374+
for (auto *user : opsToReplace) {
13841375
if (failed(replaceAllMemRefUsesWith(
1385-
oldMemRef, newMemRef, op, extraIndices, indexRemap, extraOperands,
1376+
oldMemRef, newMemRef, user, extraIndices, indexRemap, extraOperands,
13861377
symbolOperands, allowNonDereferencingOps)))
13871378
llvm_unreachable("memref replacement guaranteed to succeed here");
13881379
}
@@ -1763,8 +1754,7 @@ LogicalResult mlir::affine::normalizeMemRef(AllocLikeOp allocOp) {
17631754
/*indexRemap=*/layoutMap,
17641755
/*extraOperands=*/{},
17651756
/*symbolOperands=*/symbolOperands,
1766-
/*domOpFilter=*/nullptr,
1767-
/*postDomOpFilter=*/nullptr,
1757+
/*userFilterFn=*/nullptr,
17681758
/*allowNonDereferencingOps=*/true))) {
17691759
// If it failed (due to escapes for example), bail out.
17701760
newAlloc.erase();
@@ -1854,8 +1844,7 @@ mlir::affine::normalizeMemRef(memref::ReinterpretCastOp reinterpretCastOp) {
18541844
/*indexRemap=*/oldLayoutMap,
18551845
/*extraOperands=*/{},
18561846
/*symbolOperands=*/oldStrides,
1857-
/*domOpFilter=*/nullptr,
1858-
/*postDomOpFilter=*/nullptr,
1847+
/*userFilterFn=*/nullptr,
18591848
/*allowNonDereferencingOps=*/true))) {
18601849
// If it failed (due to escapes for example), bail out.
18611850
newReinterpretCast.erase();

mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -297,8 +297,7 @@ void NormalizeMemRefs::updateFunctionSignature(func::FuncOp funcOp,
297297
/*indexRemap=*/layoutMap,
298298
/*extraOperands=*/{},
299299
/*symbolOperands=*/{},
300-
/*domOpFilter=*/nullptr,
301-
/*postDomOpFilter=*/nullptr,
300+
/*userFilterFn=*/nullptr,
302301
/*allowNonDereferencingOps=*/true,
303302
/*replaceInDeallocOp=*/true))) {
304303
// If it failed (due to escapes for example), bail out.
@@ -407,8 +406,7 @@ void NormalizeMemRefs::normalizeFuncOpMemRefs(func::FuncOp funcOp,
407406
/*indexRemap=*/layoutMap,
408407
/*extraOperands=*/{},
409408
/*symbolOperands=*/{},
410-
/*domOpFilter=*/nullptr,
411-
/*postDomOpFilter=*/nullptr,
409+
/*userFilterFn=*/nullptr,
412410
/*allowNonDereferencingOps=*/true,
413411
/*replaceInDeallocOp=*/true))) {
414412
// If it failed (due to escapes for example), bail out. Removing the
@@ -457,8 +455,7 @@ void NormalizeMemRefs::normalizeFuncOpMemRefs(func::FuncOp funcOp,
457455
/*indexRemap=*/layoutMap,
458456
/*extraOperands=*/{},
459457
/*symbolOperands=*/{},
460-
/*domOpFilter=*/nullptr,
461-
/*postDomOpFilter=*/nullptr,
458+
/*userFilterFn=*/nullptr,
462459
/*allowNonDereferencingOps=*/true,
463460
/*replaceInDeallocOp=*/true))) {
464461
newOp->erase();

mlir/test/Dialect/Affine/affine-data-copy.mlir

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -447,3 +447,51 @@ func.func @memref_def_inside(%arg0: index) {
447447
// LIMITED-MEM-NEXT: memref.dealloc %{{.*}} : memref<1xf32>
448448
return
449449
}
450+
451+
// Test with uses across multiple blocks.
452+
453+
memref.global "private" constant @__constant_1x2x1xi32_1 : memref<1x2x1xi32> = dense<0> {alignment = 64 : i64}
454+
455+
// CHECK-LABEL: func @multiple_blocks
456+
func.func @multiple_blocks(%arg0: index) -> memref<1x2x1xi32> {
457+
%c1_i32 = arith.constant 1 : i32
458+
%c3_i32 = arith.constant 3 : i32
459+
%0 = memref.get_global @__constant_1x2x1xi32_1 : memref<1x2x1xi32>
460+
%alloc = memref.alloc() {alignment = 64 : i64} : memref<1x2x1xi32>
461+
memref.copy %0, %alloc : memref<1x2x1xi32> to memref<1x2x1xi32>
462+
cf.br ^bb1(%alloc : memref<1x2x1xi32>)
463+
^bb1(%1: memref<1x2x1xi32>): // 2 preds: ^bb0, ^bb2
464+
// CHECK: ^bb1(%[[MEM:.*]]: memref<1x2x1xi32>):
465+
%alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<1x2x1xi1>
466+
// CHECK: %[[BUF:.*]] = memref.alloc() : memref<1x2x1xi32>
467+
affine.for %arg1 = 0 to 1 {
468+
affine.for %arg2 = 0 to 2 {
469+
affine.for %arg3 = 0 to 1 {
470+
// CHECK: affine.load %[[BUF]]
471+
%3 = affine.load %1[%arg1, %arg2, %arg3] : memref<1x2x1xi32>
472+
%4 = arith.cmpi slt, %3, %c3_i32 : i32
473+
affine.store %4, %alloc_0[%arg1, %arg2, %arg3] : memref<1x2x1xi1>
474+
}
475+
}
476+
}
477+
// CHECK: memref.dealloc %[[BUF]]
478+
%2 = memref.load %alloc_0[%arg0, %arg0, %arg0] : memref<1x2x1xi1>
479+
cf.cond_br %2, ^bb2, ^bb3
480+
^bb2: // pred: ^bb1
481+
// CHECK: ^bb2
482+
%alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<1x2x1xi32>
483+
affine.for %arg1 = 0 to 1 {
484+
affine.for %arg2 = 0 to 2 {
485+
affine.for %arg3 = 0 to 1 {
486+
// Ensure that this reference isn't replaced.
487+
%3 = affine.load %1[%arg1, %arg2, %arg3] : memref<1x2x1xi32>
488+
// CHECK: affine.load %[[MEM]]
489+
%4 = arith.addi %3, %c1_i32 : i32
490+
affine.store %4, %alloc_1[%arg1, %arg2, %arg3] : memref<1x2x1xi32>
491+
}
492+
}
493+
}
494+
cf.br ^bb1(%alloc_1 : memref<1x2x1xi32>)
495+
^bb3: // pred: ^bb1
496+
return %1 : memref<1x2x1xi32>
497+
}

0 commit comments

Comments
 (0)