Skip to content

Commit 3e4be55

Browse files
authored
[MLIR][Affine] Improve sibling fusion - handle memrefs from memref defining nodes (llvm#149641)
Improve sibling fusion by handling memrefs that come from memref-defining nodes, which were previously not considered. Remove the unnecessary restriction that limited MDG memref edge iteration to edges whose node is an affine.for op; nodes in the MDG can be other ops as well. Fixes: llvm#61825
1 parent 42017c6 commit 3e4be55

File tree

3 files changed

+34
-7
lines changed

3 files changed

+34
-7
lines changed

mlir/lib/Dialect/Affine/Analysis/Utils.cpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -710,15 +710,15 @@ void MemRefDependenceGraph::clearNodeLoadAndStores(unsigned id) {
710710
void MemRefDependenceGraph::forEachMemRefInputEdge(
711711
unsigned id, const std::function<void(Edge)> &callback) {
712712
if (inEdges.count(id) > 0)
713-
forEachMemRefEdge(inEdges[id], callback);
713+
forEachMemRefEdge(inEdges.at(id), callback);
714714
}
715715

716716
// Calls 'callback' for each output edge from node 'id' which carries a
717717
// memref dependence.
718718
void MemRefDependenceGraph::forEachMemRefOutputEdge(
719719
unsigned id, const std::function<void(Edge)> &callback) {
720720
if (outEdges.count(id) > 0)
721-
forEachMemRefEdge(outEdges[id], callback);
721+
forEachMemRefEdge(outEdges.at(id), callback);
722722
}
723723

724724
// Calls 'callback' for each edge in 'edges' which carries a memref
@@ -730,9 +730,6 @@ void MemRefDependenceGraph::forEachMemRefEdge(
730730
if (!isa<MemRefType>(edge.value.getType()))
731731
continue;
732732
assert(nodes.count(edge.id) > 0);
733-
// Skip if 'edge.id' is not a loop nest.
734-
if (!isa<AffineForOp>(getNode(edge.id)->op))
735-
continue;
736733
// Visit current input edge 'edge'.
737734
callback(edge);
738735
}

mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1473,9 +1473,11 @@ struct GreedyFusion {
14731473
SmallVector<MemRefDependenceGraph::Edge, 2> inEdges;
14741474
mdg->forEachMemRefInputEdge(
14751475
dstNode->id, [&](MemRefDependenceGraph::Edge inEdge) {
1476-
// Add 'inEdge' if it is a read-after-write dependence.
1476+
// Add 'inEdge' if it is a read-after-write dependence or an edge
1477+
// from a memref defining op (e.g. view-like op or alloc op).
14771478
if (dstNode->getLoadOpCount(inEdge.value) > 0 &&
1478-
mdg->getNode(inEdge.id)->getStoreOpCount(inEdge.value) > 0)
1479+
(mdg->getNode(inEdge.id)->getStoreOpCount(inEdge.value) > 0 ||
1480+
inEdge.value.getDefiningOp() == mdg->getNode(inEdge.id)->op))
14791481
inEdges.push_back(inEdge);
14801482
});
14811483

mlir/test/Dialect/Affine/loop-fusion-4.mlir

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -743,3 +743,31 @@ module {
743743
return
744744
}
745745
}
746+
747+
// SIBLING-MAXIMAL-LABEL: memref_cast_reused
748+
func.func @memref_cast_reused(%arg: memref<*xf32>) {
749+
%alloc = memref.cast %arg : memref<*xf32> to memref<10xf32>
750+
%alloc_0 = memref.alloc() : memref<10xf32>
751+
%alloc_1 = memref.alloc() : memref<10xf32>
752+
%cst = arith.constant 0.000000e+00 : f32
753+
%cst_2 = arith.constant 1.000000e+00 : f32
754+
affine.for %arg0 = 0 to 10 {
755+
%0 = affine.load %alloc[%arg0] : memref<10xf32>
756+
%1 = arith.addf %0, %cst_2 : f32
757+
affine.store %1, %alloc_0[%arg0] : memref<10xf32>
758+
}
759+
affine.for %arg0 = 0 to 10 {
760+
%0 = affine.load %alloc[%arg0] : memref<10xf32>
761+
%1 = affine.load %alloc_1[0] : memref<10xf32>
762+
%2 = arith.addf %0, %1 : f32
763+
affine.store %2, %alloc_1[0] : memref<10xf32>
764+
}
765+
// SIBLING-MAXIMAL: affine.for %{{.*}} = 0 to 10
766+
// SIBLING-MAXIMAL: addf
767+
// SIBLING-MAXIMAL-NEXT: affine.store
768+
// SIBLING-MAXIMAL-NEXT: affine.load
769+
// SIBLING-MAXIMAL-NEXT: affine.load
770+
// SIBLING-MAXIMAL-NEXT: addf
771+
// SIBLING-MAXIMAL-NEXT: affine.store
772+
return
773+
}

0 commit comments

Comments
 (0)