Skip to content

Commit a38b8ab

Browse files
committed
[MLIR][Affine] Fix fusion private memref creation for multiple producer stores
Fix private memref creation in affine fusion for the multiple producer store case. This scenario was not supported, but the lack of support was not properly checked. Fixes: #120227
1 parent ec54ec6 commit a38b8ab

File tree

3 files changed

+80
-11
lines changed

3 files changed

+80
-11
lines changed

mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -328,14 +328,34 @@ static std::optional<double> getAdditionalComputeFraction(
328328
// Creates and returns a private (single-user) memref for fused loop rooted at
329329
// 'forOp', with (potentially reduced) memref size based on the memref region
330330
// written to by `storeOps` at depth 'dstLoopDepth'. 'sliceInsertionBlock'
331-
// specifies the block in which the slice was/will be inserted.
331+
// specifies the block in which the slice was/will be inserted. The method
332+
// expects that all store ops to the memref have the same access function.
333+
// Returns nullptr if the creation failed.
332334
static Value createPrivateMemRef(AffineForOp forOp,
333335
ArrayRef<Operation *> storeOps,
334336
unsigned dstLoopDepth,
335337
std::optional<unsigned> fastMemorySpace,
336338
Block *sliceInsertionBlock,
337339
uint64_t localBufSizeThreshold) {
338340
assert(!storeOps.empty() && "no source stores supplied");
341+
342+
// Check if all stores have the same access function; we only support this
343+
// case.
344+
// TODO: Use union of memref write regions to compute private memref footprint
345+
// for store ops with different access functions.
346+
if (storeOps.size() > 1 &&
347+
!std::equal(std::next(storeOps.begin()), storeOps.end(), storeOps.begin(),
348+
[](Operation *a, Operation *b) {
349+
MemRefAccess aM(cast<AffineWriteOpInterface>(a));
350+
MemRefAccess bM(cast<AffineWriteOpInterface>(b));
351+
return aM == bM;
352+
})) {
353+
LLVM_DEBUG(llvm::dbgs()
354+
<< "Private memref creation unsupported for multiple producer "
355+
"stores with different access functions.\n");
356+
return nullptr;
357+
}
358+
339359
Operation *srcStoreOp = storeOps[0];
340360

341361
// Create builder to insert alloc op just before 'forOp'.
@@ -432,6 +452,8 @@ static Value createPrivateMemRef(AffineForOp forOp,
432452
assert(succeeded(res) &&
433453
"replaceAllMemrefUsesWith should always succeed here");
434454
(void)res;
455+
LLVM_DEBUG(llvm::dbgs() << "Created private memref of type: " << newMemRefType
456+
<< '\n');
435457
return newMemRef;
436458
}
437459

@@ -1123,13 +1145,12 @@ struct GreedyFusion {
11231145
// loads and stores. Any reference to the original ones becomes
11241146
// invalid after this point.
11251147
for (auto &memrefToStoresPair : privateMemRefToStores) {
1126-
// TODO: Use union of memref write regions to compute
1127-
// private memref footprint.
1128-
SmallVector<Operation *, 4> &storesForMemref =
1129-
memrefToStoresPair.second;
1148+
ArrayRef<Operation *> storesForMemref = memrefToStoresPair.second;
11301149
Value newMemRef = createPrivateMemRef(
11311150
dstAffineForOp, storesForMemref, bestDstLoopDepth,
11321151
fastMemorySpace, sliceInsertionBlock, localBufSizeThreshold);
1152+
if (!newMemRef)
1153+
continue;
11331154
// Create new node in dependence graph for 'newMemRef' alloc op.
11341155
unsigned newMemRefNodeId = mdg->addNode(newMemRef.getDefiningOp());
11351156
// Add edge from 'newMemRef' node to dstNode.

mlir/test/Dialect/Affine/loop-fusion-2.mlir

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,16 +39,20 @@ func.func @should_fuse_at_depth_above_loop_carried_dependence(%arg0: memref<64x4
3939

4040
// We can fuse source loop nest '%i0' into dst loop nest '%i2', but the
4141
// depth at which we can insert the src loop nest slice into the dst loop
42-
// lest must be decreased because of a loop carried dependence on loop '%i3'.
42+
// nest must be decreased because of a loop carried dependence on loop '%i3'.
4343
// As a result, the source loop nest is inserted at dst loop nest depth 1,
4444
// just above the loop with the carried dependence. In addition, the source
4545
// loop nest iteration bounds on its loop '%i1' are reduced to 1, so the
46-
// memref size can be reduced to 128x1xf32.
46+
// memref size can be reduced to 64x1xf32.
4747

48-
// CHECK: memref.alloc() : memref<64x1xf32>
48+
// In this case, since we have multiple producer stores (for %out) with
49+
// different access functions and we don't yet support private memref
50+
// computation in such cases, a 64x1 private memref isn't created.
51+
52+
// CHECK: memref.alloc() : memref<64x4xf32>
4953
// CHECK: affine.for %{{.*}} = 0 to 4 {
5054
// CHECK-NEXT: affine.for %{{.*}} = 0 to 64 {
51-
// CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}, 0] : memref<64x1xf32>
55+
// CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<64x4xf32>
5256
// CHECK-NEXT: }
5357
// CHECK-NEXT: affine.for %{{.*}} = 0 to 4 {
5458
// CHECK-NEXT: affine.for %{{.*}} = 0 to 16 {
@@ -62,9 +66,9 @@ func.func @should_fuse_at_depth_above_loop_carried_dependence(%arg0: memref<64x4
6266
// CHECK-NEXT: }
6367
// CHECK-NEXT: affine.for %{{.*}} = 0 to 16 {
6468
// CHECK-NEXT: %{{.*}} = "op2"() : () -> f32
65-
// CHECK: affine.load %{{.*}}[%{{.*}} * 16 + %{{.*}}, 0] : memref<64x1xf32>
69+
// CHECK: affine.load %{{.*}}[%{{.*}} * 16 + %{{.*}}, %{{.*}}] : memref<64x4xf32>
6670
// CHECK-NEXT: arith.addf %{{.*}}, %{{.*}} : f32
67-
// CHECK: affine.store %{{.*}}, %{{.*}}[%{{.*}} * 16 + %{{.*}}, 0] : memref<64x1xf32>
71+
// CHECK: affine.store %{{.*}}, %{{.*}}[%{{.*}} * 16 + %{{.*}}, %{{.*}}] : memref<64x4xf32>
6872
// CHECK-NEXT: }
6973
// CHECK-NEXT: }
7074
// CHECK-NEXT: }

mlir/test/Dialect/Affine/loop-fusion-4.mlir

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -622,3 +622,47 @@ func.func @zero_tolerance(%arg0: memref<65536xcomplex<f64>>, %arg1: memref<30x13
622622
}
623623
func.func private @__external_levelwise_forward_ntt(memref<30x131072xi64>)
624624
func.func private @__external_reduce_barrett(i64, i64, i64, i64, i128) -> i64
625+
626+
// An unrolled loop nest. Fusion here should correctly fuse while preserving
627+
// dependences between store-load pairs of the same memref. A private memref
628+
// of size 1x1x1 can't be created.
629+
630+
// PRODUCER-CONSUMER-MAXIMAL-LABEL: func @unrolled
631+
func.func @unrolled(%arg0: memref<2x4xf32>, %arg1: memref<1x2x4xf32>) {
632+
%alloc = memref.alloc() : memref<1x2x4xf32>
633+
affine.for %i = 0 to 1 {
634+
%0 = affine.load %arg0[0, 0] : memref<2x4xf32>
635+
%1 = affine.load %arg0[0, 1] : memref<2x4xf32>
636+
%2 = affine.load %arg0[0, 2] : memref<2x4xf32>
637+
%3 = affine.load %arg0[0, 3] : memref<2x4xf32>
638+
%4 = affine.load %arg0[1, 0] : memref<2x4xf32>
639+
%5 = affine.load %arg0[1, 1] : memref<2x4xf32>
640+
%6 = affine.load %arg0[1, 2] : memref<2x4xf32>
641+
%7 = affine.load %arg0[1, 3] : memref<2x4xf32>
642+
643+
affine.store %0, %alloc[0, 0, 0] : memref<1x2x4xf32>
644+
affine.store %1, %alloc[0, 0, 1] : memref<1x2x4xf32>
645+
affine.store %2, %alloc[0, 0, 2] : memref<1x2x4xf32>
646+
affine.store %3, %alloc[0, 0, 3] : memref<1x2x4xf32>
647+
affine.store %4, %alloc[0, 1, 0] : memref<1x2x4xf32>
648+
affine.store %5, %alloc[0, 1, 1] : memref<1x2x4xf32>
649+
affine.store %6, %alloc[0, 1, 2] : memref<1x2x4xf32>
650+
affine.store %7, %alloc[0, 1, 3] : memref<1x2x4xf32>
651+
}
652+
653+
affine.for %i = 0 to 2 {
654+
affine.for %j = 0 to 4 {
655+
%8 = affine.load %alloc[0, %i, %j] : memref<1x2x4xf32>
656+
%9 = arith.negf %8 : f32
657+
affine.store %9, %arg1[0, %i, %j] : memref<1x2x4xf32>
658+
}
659+
}
660+
// PRODUCER-CONSUMER-MAXIMAL: affine.for %{{.*}} = 0 to 2 {
661+
// PRODUCER-CONSUMER-MAXIMAL-NEXT: affine.for %{{.*}} = 0 to 4 {
662+
// PRODUCER-CONSUMER-MAXIMAL-NEXT: affine.load %{{.*}}[0, 0]
663+
// PRODUCER-CONSUMER-MAXIMAL: affine.load %{{.*}}[1, 3]
664+
// PRODUCER-CONSUMER-MAXIMAL: affine.store %{{.*}}[0, 0, 0]
665+
// PRODUCER-CONSUMER-MAXIMAL: affine.store %{{.*}}[0, 1, 3]
666+
// PRODUCER-CONSUMER-MAXIMAL: affine.load %{{.*}}[0, %{{.*}}, %{{.*}}]
667+
return
668+
}

0 commit comments

Comments
 (0)