Skip to content

Commit f275eda

Browse files
committed
[MLIR][Affine] Fix private memref creation bug in affine fusion
Fix private memref creation bug in affine fusion exposed in the case of the same memref being loaded from/stored to in producer nest. Make the private memref replacement sound. Change affine fusion debug string to affine-fusion - more compact. Fixes: #48703
1 parent 44f638f commit f275eda

File tree

4 files changed

+134
-12
lines changed

4 files changed

+134
-12
lines changed

mlir/include/mlir/Dialect/Affine/Analysis/Utils.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -610,6 +610,14 @@ FailureOr<AffineValueMap>
610610
simplifyConstrainedMinMaxOp(Operation *op,
611611
FlatAffineValueConstraints constraints);
612612

613+
/// Find the innermost common `Block` of `a` and `b` in the affine scope
614+
/// that `a` and `b` are part of. Return nullptr if they belong to different
615+
/// affine scopes. Also, return null if they do not have a common `Block`
616+
/// ancestor (for eg., when they are part of the `then` and `else` regions
617+
/// of an op that itself starts an affine scope.
618+
mlir::Block *findInnermostCommonBlockInScope(mlir::Operation *a,
619+
mlir::Operation *b);
620+
613621
} // namespace affine
614622
} // namespace mlir
615623

mlir/lib/Dialect/Affine/Analysis/Utils.cpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
//===----------------------------------------------------------------------===//
1313

1414
#include "mlir/Dialect/Affine/Analysis/Utils.h"
15+
1516
#include "mlir/Analysis/Presburger/PresburgerRelation.h"
1617
#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
1718
#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
@@ -2297,3 +2298,40 @@ FailureOr<AffineValueMap> mlir::affine::simplifyConstrainedMinMaxOp(
22972298
affine::canonicalizeMapAndOperands(&newMap, &newOperands);
22982299
return AffineValueMap(newMap, newOperands);
22992300
}
2301+
2302+
Block *mlir::affine::findInnermostCommonBlockInScope(Operation *a,
2303+
Operation *b) {
2304+
Region *aScope = mlir::affine::getAffineScope(a);
2305+
Region *bScope = mlir::affine::getAffineScope(b);
2306+
if (aScope != bScope)
2307+
return nullptr;
2308+
2309+
// Get the block ancestry of `a` while stopping at the affine scope.
2310+
auto getBlockAncestry = [&](Operation *op,
2311+
SmallVectorImpl<Block *> &ancestry) {
2312+
Operation *curOp = op;
2313+
do {
2314+
ancestry.push_back(curOp->getBlock());
2315+
if (curOp->getParentRegion() == aScope)
2316+
break;
2317+
curOp = curOp->getParentOp();
2318+
} while (curOp);
2319+
assert(curOp && "can't reach root op without passing through affine scope");
2320+
std::reverse(ancestry.begin(), ancestry.end());
2321+
};
2322+
2323+
SmallVector<Block *, 4> aAncestors, bAncestors;
2324+
getBlockAncestry(a, aAncestors);
2325+
getBlockAncestry(b, bAncestors);
2326+
assert(!aAncestors.empty() && !bAncestors.empty() &&
2327+
"at least one Block ancestor expected");
2328+
2329+
Block *innermostCommonBlock = nullptr;
2330+
for (unsigned a = 0, b = 0, e = aAncestors.size(), f = bAncestors.size();
2331+
a < e && b < f; ++a, ++b) {
2332+
if (aAncestors[a] != bAncestors[b])
2333+
break;
2334+
innermostCommonBlock = aAncestors[a];
2335+
}
2336+
return innermostCommonBlock;
2337+
}

mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp

Lines changed: 59 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ namespace affine {
4141
} // namespace affine
4242
} // namespace mlir
4343

44-
#define DEBUG_TYPE "affine-loop-fusion"
44+
#define DEBUG_TYPE "affine-fusion"
4545

4646
using namespace mlir;
4747
using namespace mlir::affine;
@@ -237,29 +237,71 @@ static void sinkSequentialLoops(MemRefDependenceGraph::Node *node) {
237237
node->op = newRootForOp;
238238
}
239239

240+
/// Get the operation that should act as a dominance filter while replacing
241+
/// memref uses with a private memref for which `producerStores` and
242+
/// `sliceInsertionBlock` are provided. This effectively determines in what
243+
/// part of the IR we should be performing the replacement.
244+
static Operation *
245+
getDominanceFilterForPrivateMemRefRepl(Block *sliceInsertionBlock,
246+
ArrayRef<Operation *> producerStores) {
247+
assert(!producerStores.empty() && "expected producer store");
248+
249+
// We first find the common block that contains the producer stores and
250+
// the slice computation. The first ancestor among the ancestors of the
251+
// producer stores in that common block is the dominance filter to use for
252+
// replacement.
253+
Block *commonBlock = nullptr;
254+
// Find the common block of all relevant operations.
255+
for (Operation *store : producerStores) {
256+
if (!commonBlock)
257+
commonBlock = findInnermostCommonBlockInScope(
258+
store, &*sliceInsertionBlock->begin());
259+
else
260+
commonBlock =
261+
findInnermostCommonBlockInScope(store, &*commonBlock->begin());
262+
}
263+
assert(commonBlock &&
264+
"common block of producer stores and slice should exist");
265+
266+
// Find the first ancestor among the ancestors of `producerStores` in
267+
// `commonBlock`.
268+
Operation *firstAncestor = nullptr;
269+
for (Operation *store : producerStores) {
270+
Operation *ancestor = commonBlock->findAncestorOpInBlock(*store);
271+
assert(ancestor && "producer store should be contained in common block");
272+
firstAncestor = !firstAncestor || ancestor->isBeforeInBlock(firstAncestor)
273+
? ancestor
274+
: firstAncestor;
275+
}
276+
return firstAncestor;
277+
}
278+
240279
// Creates and returns a private (single-user) memref for fused loop rooted
241280
// at 'forOp', with (potentially reduced) memref size based on the
242281
// MemRefRegion written to by 'srcStoreOpInst' at depth 'dstLoopDepth'.
243282
// TODO: consider refactoring the common code from generateDma and
244283
// this one.
245-
static Value createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
284+
static Value createPrivateMemRef(AffineForOp forOp,
285+
ArrayRef<Operation *> storeOps,
246286
unsigned dstLoopDepth,
247287
std::optional<unsigned> fastMemorySpace,
288+
Block *sliceInsertionBlock,
248289
uint64_t localBufSizeThreshold) {
249-
Operation *forInst = forOp.getOperation();
290+
assert(!storeOps.empty() && "no source stores supplied");
291+
Operation *srcStoreOp = storeOps[0];
250292

251293
// Create builder to insert alloc op just before 'forOp'.
252-
OpBuilder b(forInst);
294+
OpBuilder b(forOp);
253295
// Builder to create constants at the top level.
254-
OpBuilder top(forInst->getParentRegion());
296+
OpBuilder top(forOp->getParentRegion());
255297
// Create new memref type based on slice bounds.
256-
auto oldMemRef = cast<AffineWriteOpInterface>(srcStoreOpInst).getMemRef();
298+
auto oldMemRef = cast<AffineWriteOpInterface>(srcStoreOp).getMemRef();
257299
auto oldMemRefType = cast<MemRefType>(oldMemRef.getType());
258300
unsigned rank = oldMemRefType.getRank();
259301

260302
// Compute MemRefRegion for 'srcStoreOpInst' at depth 'dstLoopDepth'.
261-
MemRefRegion region(srcStoreOpInst->getLoc());
262-
bool validRegion = succeeded(region.compute(srcStoreOpInst, dstLoopDepth));
303+
MemRefRegion region(srcStoreOp->getLoc());
304+
bool validRegion = succeeded(region.compute(srcStoreOp, dstLoopDepth));
263305
(void)validRegion;
264306
assert(validRegion && "unexpected memref region failure");
265307
SmallVector<int64_t, 4> newShape;
@@ -332,11 +374,12 @@ static Value createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
332374
AffineMap::get(outerIVs.size() + rank, 0, remapExprs, forOp.getContext());
333375

334376
// Replace all users of 'oldMemRef' with 'newMemRef'.
377+
Operation *domFilter =
378+
getDominanceFilterForPrivateMemRefRepl(sliceInsertionBlock, storeOps);
335379
LogicalResult res =
336380
replaceAllMemRefUsesWith(oldMemRef, newMemRef, {}, indexRemap,
337381
/*extraOperands=*/outerIVs,
338-
/*symbolOperands=*/{},
339-
/*domOpFilter=*/&*forOp.getBody()->begin());
382+
/*symbolOperands=*/{}, domFilter);
340383
assert(succeeded(res) &&
341384
"replaceAllMemrefUsesWith should always succeed here");
342385
(void)res;
@@ -944,6 +987,10 @@ struct GreedyFusion {
944987

945988
// Create private memrefs.
946989
if (!privateMemrefs.empty()) {
990+
// Note the block into which fusion was performed. This can be used to
991+
// place `alloc`s that create private memrefs.
992+
Block *sliceInsertionBlock = bestSlice.insertPoint->getBlock();
993+
947994
// Gather stores for all the private-to-be memrefs.
948995
DenseMap<Value, SmallVector<Operation *, 4>> privateMemRefToStores;
949996
dstAffineForOp.walk([&](AffineWriteOpInterface storeOp) {
@@ -962,8 +1009,8 @@ struct GreedyFusion {
9621009
SmallVector<Operation *, 4> &storesForMemref =
9631010
memrefToStoresPair.second;
9641011
Value newMemRef = createPrivateMemRef(
965-
dstAffineForOp, storesForMemref[0], bestDstLoopDepth,
966-
fastMemorySpace, localBufSizeThreshold);
1012+
dstAffineForOp, storesForMemref, bestDstLoopDepth,
1013+
fastMemorySpace, sliceInsertionBlock, localBufSizeThreshold);
9671014
// Create new node in dependence graph for 'newMemRef' alloc op.
9681015
unsigned newMemRefNodeId = mdg->addNode(newMemRef.getDefiningOp());
9691016
// Add edge from 'newMemRef' node to dstNode.

mlir/test/Dialect/Affine/loop-fusion-4.mlir

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,3 +285,32 @@ module {
285285
spirv.ReturnValue %3 : !spirv.array<8192 x f32>
286286
}
287287
}
288+
289+
// -----
290+
291+
// PRODUCER-CONSUMER-LABEL: func @same_memref_load_store
292+
func.func @same_memref_load_store(%producer : memref<32xf32>, %consumer: memref<16xf32>){
293+
%cst = arith.constant 2.000000e+00 : f32
294+
// Source isn't removed.
295+
// PRODUCER-CONSUMER: affine.for %{{.*}} = 0 to 32
296+
affine.for %arg3 = 0 to 32 {
297+
%0 = affine.load %producer[%arg3] : memref<32xf32>
298+
%2 = arith.mulf %0, %cst : f32
299+
affine.store %2, %producer[%arg3] : memref<32xf32>
300+
}
301+
affine.for %arg3 = 0 to 16 {
302+
%0 = affine.load %producer[%arg3] : memref<32xf32>
303+
%2 = arith.addf %0, %cst : f32
304+
affine.store %2, %consumer[%arg3] : memref<16xf32>
305+
}
306+
// Fused nest.
307+
// PRODUCER-CONSUMER: affine.for %{{.*}} = 0 to 16
308+
// PRODUCER-CONSUMER-NEXT: affine.load %{{.*}}[%{{.*}}] : memref<32xf32>
309+
// PRODUCER-CONSUMER-NEXT: arith.mulf
310+
// PRODUCER-CONSUMER-NEXT: affine.store %{{.*}}, %{{.*}}[0] : memref<1xf32>
311+
// PRODUCER-CONSUMER-NEXT: affine.load %{{.*}}[0] : memref<1xf32>
312+
// PRODUCER-CONSUMER-NEXT: arith.addf
313+
// PRODUCER-CONSUMER-NEXT: affine.store
314+
// PRODUCER-CONSUMER-NEXT: }
315+
return
316+
}

0 commit comments

Comments
 (0)