Skip to content

Commit 191eb97

Browse files
committed
[MLIR][Affine] Make affine fusion MDG API const correct
Make affine fusion MDG API const correct. NFC changes otherwise.
1 parent 44f638f commit 191eb97

File tree

3 files changed

+84
-66
lines changed

3 files changed

+84
-66
lines changed

mlir/include/mlir/Dialect/Affine/Analysis/Utils.h

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -139,9 +139,11 @@ struct MemRefDependenceGraph {
139139

140140
// Map from node id to Node.
141141
DenseMap<unsigned, Node> nodes;
142-
// Map from node id to list of input edges.
142+
// Map from node id to list of input edges. The absence of an entry for a key
143+
// is also equivalent to the absence of any edges.
143144
DenseMap<unsigned, SmallVector<Edge, 2>> inEdges;
144-
// Map from node id to list of output edges.
145+
// Map from node id to list of output edges. The absence of an entry for a
146+
// node is also equivalent to the absence of any edges.
145147
DenseMap<unsigned, SmallVector<Edge, 2>> outEdges;
146148
// Map from memref to a count on the dependence edges associated with that
147149
// memref.
@@ -156,10 +158,21 @@ struct MemRefDependenceGraph {
156158
bool init();
157159

158160
// Returns the graph node for 'id'.
159-
Node *getNode(unsigned id);
161+
const Node *getNode(unsigned id) const;
162+
Node *getNode(unsigned id) {
163+
return const_cast<Node *>(
164+
static_cast<const MemRefDependenceGraph *>(this)->getNode(id));
165+
}
166+
167+
// Returns true if the graph has node with ID `id`.
168+
bool hasNode(unsigned id) const { return nodes.contains(id); }
160169

161170
// Returns the graph node for 'forOp'.
162-
Node *getForOpNode(AffineForOp forOp);
171+
const Node *getForOpNode(AffineForOp forOp) const;
172+
Node *getForOpNode(AffineForOp forOp) {
173+
return const_cast<Node *>(
174+
static_cast<const MemRefDependenceGraph *>(this)->getForOpNode(forOp));
175+
}
163176

164177
// Adds a node with 'op' to the graph and returns its unique identifier.
165178
unsigned addNode(Operation *op);
@@ -169,12 +182,12 @@ struct MemRefDependenceGraph {
169182

170183
// Returns true if node 'id' writes to any memref which escapes (or is an
171184
// argument to) the block. Returns false otherwise.
172-
bool writesToLiveInOrEscapingMemrefs(unsigned id);
185+
bool writesToLiveInOrEscapingMemrefs(unsigned id) const;
173186

174187
// Returns true iff there is an edge from node 'srcId' to node 'dstId' which
175188
// is for 'value' if non-null, or for any value otherwise. Returns false
176189
// otherwise.
177-
bool hasEdge(unsigned srcId, unsigned dstId, Value value = nullptr);
190+
bool hasEdge(unsigned srcId, unsigned dstId, Value value = nullptr) const;
178191

179192
// Adds an edge from node 'srcId' to node 'dstId' for 'value'.
180193
void addEdge(unsigned srcId, unsigned dstId, Value value);
@@ -185,23 +198,25 @@ struct MemRefDependenceGraph {
185198
// Returns true if there is a path in the dependence graph from node 'srcId'
186199
// to node 'dstId'. Returns false otherwise. `srcId`, `dstId`, and the
187200
// operations that the edges connected are expected to be from the same block.
188-
bool hasDependencePath(unsigned srcId, unsigned dstId);
201+
bool hasDependencePath(unsigned srcId, unsigned dstId) const;
189202

190203
// Returns the input edge count for node 'id' and 'memref' from src nodes
191204
// which access 'memref' with a store operation.
192-
unsigned getIncomingMemRefAccesses(unsigned id, Value memref);
205+
unsigned getIncomingMemRefAccesses(unsigned id, Value memref) const;
193206

194207
// Returns the output edge count for node 'id' and 'memref' (if non-null),
195208
// otherwise returns the total output edge count from node 'id'.
196-
unsigned getOutEdgeCount(unsigned id, Value memref = nullptr);
209+
unsigned getOutEdgeCount(unsigned id, Value memref = nullptr) const;
197210

198211
/// Return all nodes which define SSA values used in node 'id'.
199-
void gatherDefiningNodes(unsigned id, DenseSet<unsigned> &definingNodes);
212+
void gatherDefiningNodes(unsigned id,
213+
DenseSet<unsigned> &definingNodes) const;
200214

201215
// Computes and returns an insertion point operation, before which the
202216
// the fused <srcId, dstId> loop nest can be inserted while preserving
203217
// dependences. Returns nullptr if no such insertion point is found.
204-
Operation *getFusedLoopNestInsertionPoint(unsigned srcId, unsigned dstId);
218+
Operation *getFusedLoopNestInsertionPoint(unsigned srcId,
219+
unsigned dstId) const;
205220

206221
// Updates edge mappings from node 'srcId' to node 'dstId' after fusing them,
207222
// taking into account that:

mlir/lib/Dialect/Affine/Analysis/Utils.cpp

Lines changed: 38 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -187,8 +187,9 @@ static void getEffectedValues(Operation *op, SmallVectorImpl<Value> &values) {
187187

188188
/// Add `op` to MDG creating a new node and adding its memory accesses (affine
189189
/// or non-affine to memrefAccesses (memref -> list of nodes with accesses) map.
190-
Node *addNodeToMDG(Operation *nodeOp, MemRefDependenceGraph &mdg,
191-
DenseMap<Value, SetVector<unsigned>> &memrefAccesses) {
190+
static Node *
191+
addNodeToMDG(Operation *nodeOp, MemRefDependenceGraph &mdg,
192+
DenseMap<Value, SetVector<unsigned>> &memrefAccesses) {
192193
auto &nodes = mdg.nodes;
193194
// Create graph node 'id' to represent top-level 'forOp' and record
194195
// all loads and store accesses it contains.
@@ -358,14 +359,14 @@ bool MemRefDependenceGraph::init() {
358359
}
359360

360361
// Returns the graph node for 'id'.
361-
Node *MemRefDependenceGraph::getNode(unsigned id) {
362+
const Node *MemRefDependenceGraph::getNode(unsigned id) const {
362363
auto it = nodes.find(id);
363364
assert(it != nodes.end());
364365
return &it->second;
365366
}
366367

367368
// Returns the graph node for 'forOp'.
368-
Node *MemRefDependenceGraph::getForOpNode(AffineForOp forOp) {
369+
const Node *MemRefDependenceGraph::getForOpNode(AffineForOp forOp) const {
369370
for (auto &idAndNode : nodes)
370371
if (idAndNode.second.op == forOp)
371372
return &idAndNode.second;
@@ -389,7 +390,7 @@ void MemRefDependenceGraph::removeNode(unsigned id) {
389390
}
390391
}
391392
// Remove each edge in 'outEdges[id]'.
392-
if (outEdges.count(id) > 0) {
393+
if (outEdges.contains(id)) {
393394
SmallVector<Edge, 2> oldOutEdges = outEdges[id];
394395
for (auto &outEdge : oldOutEdges) {
395396
removeEdge(id, outEdge.id, outEdge.value);
@@ -403,8 +404,8 @@ void MemRefDependenceGraph::removeNode(unsigned id) {
403404

404405
// Returns true if node 'id' writes to any memref which escapes (or is an
405406
// argument to) the block. Returns false otherwise.
406-
bool MemRefDependenceGraph::writesToLiveInOrEscapingMemrefs(unsigned id) {
407-
Node *node = getNode(id);
407+
bool MemRefDependenceGraph::writesToLiveInOrEscapingMemrefs(unsigned id) const {
408+
const Node *node = getNode(id);
408409
for (auto *storeOpInst : node->stores) {
409410
auto memref = cast<AffineWriteOpInterface>(storeOpInst).getMemRef();
410411
auto *op = memref.getDefiningOp();
@@ -424,14 +425,14 @@ bool MemRefDependenceGraph::writesToLiveInOrEscapingMemrefs(unsigned id) {
424425
// is for 'value' if non-null, or for any value otherwise. Returns false
425426
// otherwise.
426427
bool MemRefDependenceGraph::hasEdge(unsigned srcId, unsigned dstId,
427-
Value value) {
428-
if (outEdges.count(srcId) == 0 || inEdges.count(dstId) == 0) {
428+
Value value) const {
429+
if (!outEdges.contains(srcId) || !inEdges.contains(dstId)) {
429430
return false;
430431
}
431-
bool hasOutEdge = llvm::any_of(outEdges[srcId], [=](Edge &edge) {
432+
bool hasOutEdge = llvm::any_of(outEdges.lookup(srcId), [=](const Edge &edge) {
432433
return edge.id == dstId && (!value || edge.value == value);
433434
});
434-
bool hasInEdge = llvm::any_of(inEdges[dstId], [=](Edge &edge) {
435+
bool hasInEdge = llvm::any_of(inEdges.lookup(dstId), [=](const Edge &edge) {
435436
return edge.id == srcId && (!value || edge.value == value);
436437
});
437438
return hasOutEdge && hasInEdge;
@@ -476,7 +477,8 @@ void MemRefDependenceGraph::removeEdge(unsigned srcId, unsigned dstId,
476477
// Returns true if there is a path in the dependence graph from node 'srcId'
477478
// to node 'dstId'. Returns false otherwise. `srcId`, `dstId`, and the
478479
// operations that the edges connected are expected to be from the same block.
479-
bool MemRefDependenceGraph::hasDependencePath(unsigned srcId, unsigned dstId) {
480+
bool MemRefDependenceGraph::hasDependencePath(unsigned srcId,
481+
unsigned dstId) const {
480482
// Worklist state is: <node-id, next-output-edge-index-to-visit>
481483
SmallVector<std::pair<unsigned, unsigned>, 4> worklist;
482484
worklist.push_back({srcId, 0});
@@ -489,13 +491,13 @@ bool MemRefDependenceGraph::hasDependencePath(unsigned srcId, unsigned dstId) {
489491
return true;
490492
// Pop and continue if node has no out edges, or if all out edges have
491493
// already been visited.
492-
if (outEdges.count(idAndIndex.first) == 0 ||
493-
idAndIndex.second == outEdges[idAndIndex.first].size()) {
494+
if (!outEdges.contains(idAndIndex.first) ||
495+
idAndIndex.second == outEdges.lookup(idAndIndex.first).size()) {
494496
worklist.pop_back();
495497
continue;
496498
}
497499
// Get graph edge to traverse.
498-
Edge edge = outEdges[idAndIndex.first][idAndIndex.second];
500+
const Edge edge = outEdges.lookup(idAndIndex.first)[idAndIndex.second];
499501
// Increment next output edge index for 'idAndIndex'.
500502
++idAndIndex.second;
501503
// Add node at 'edge.id' to the worklist. We don't need to consider
@@ -511,34 +513,34 @@ bool MemRefDependenceGraph::hasDependencePath(unsigned srcId, unsigned dstId) {
511513
// Returns the input edge count for node 'id' and 'memref' from src nodes
512514
// which access 'memref' with a store operation.
513515
unsigned MemRefDependenceGraph::getIncomingMemRefAccesses(unsigned id,
514-
Value memref) {
516+
Value memref) const {
515517
unsigned inEdgeCount = 0;
516-
if (inEdges.count(id) > 0)
517-
for (auto &inEdge : inEdges[id])
518-
if (inEdge.value == memref) {
519-
Node *srcNode = getNode(inEdge.id);
520-
// Only count in edges from 'srcNode' if 'srcNode' accesses 'memref'
521-
if (srcNode->getStoreOpCount(memref) > 0)
522-
++inEdgeCount;
523-
}
518+
for (const Edge &inEdge : inEdges.lookup(id)) {
519+
if (inEdge.value == memref) {
520+
const Node *srcNode = getNode(inEdge.id);
521+
// Only count in edges from 'srcNode' if 'srcNode' accesses 'memref'
522+
if (srcNode->getStoreOpCount(memref) > 0)
523+
++inEdgeCount;
524+
}
525+
}
524526
return inEdgeCount;
525527
}
526528

527529
// Returns the output edge count for node 'id' and 'memref' (if non-null),
528530
// otherwise returns the total output edge count from node 'id'.
529-
unsigned MemRefDependenceGraph::getOutEdgeCount(unsigned id, Value memref) {
531+
unsigned MemRefDependenceGraph::getOutEdgeCount(unsigned id,
532+
Value memref) const {
530533
unsigned outEdgeCount = 0;
531-
if (outEdges.count(id) > 0)
532-
for (auto &outEdge : outEdges[id])
533-
if (!memref || outEdge.value == memref)
534-
++outEdgeCount;
534+
for (const auto &outEdge : outEdges.lookup(id))
535+
if (!memref || outEdge.value == memref)
536+
++outEdgeCount;
535537
return outEdgeCount;
536538
}
537539

538540
/// Return all nodes which define SSA values used in node 'id'.
539541
void MemRefDependenceGraph::gatherDefiningNodes(
540-
unsigned id, DenseSet<unsigned> &definingNodes) {
541-
for (MemRefDependenceGraph::Edge edge : inEdges[id])
542+
unsigned id, DenseSet<unsigned> &definingNodes) const {
543+
for (const Edge &edge : inEdges.lookup(id))
542544
// By definition of edge, if the edge value is a non-memref value,
543545
// then the dependence is between a graph node which defines an SSA value
544546
// and another graph node which uses the SSA value.
@@ -551,8 +553,8 @@ void MemRefDependenceGraph::gatherDefiningNodes(
551553
// dependences. Returns nullptr if no such insertion point is found.
552554
Operation *
553555
MemRefDependenceGraph::getFusedLoopNestInsertionPoint(unsigned srcId,
554-
unsigned dstId) {
555-
if (outEdges.count(srcId) == 0)
556+
unsigned dstId) const {
557+
if (!outEdges.contains(srcId))
556558
return getNode(dstId)->op;
557559

558560
// Skip if there is any defining node of 'dstId' that depends on 'srcId'.
@@ -568,13 +570,13 @@ MemRefDependenceGraph::getFusedLoopNestInsertionPoint(unsigned srcId,
568570

569571
// Build set of insts in range (srcId, dstId) which depend on 'srcId'.
570572
SmallPtrSet<Operation *, 2> srcDepInsts;
571-
for (auto &outEdge : outEdges[srcId])
573+
for (auto &outEdge : outEdges.lookup(srcId))
572574
if (outEdge.id != dstId)
573575
srcDepInsts.insert(getNode(outEdge.id)->op);
574576

575577
// Build set of insts in range (srcId, dstId) on which 'dstId' depends.
576578
SmallPtrSet<Operation *, 2> dstDepInsts;
577-
for (auto &inEdge : inEdges[dstId])
579+
for (auto &inEdge : inEdges.lookup(dstId))
578580
if (inEdge.id != srcId)
579581
dstDepInsts.insert(getNode(inEdge.id)->op);
580582

@@ -634,7 +636,7 @@ void MemRefDependenceGraph::updateEdges(unsigned srcId, unsigned dstId,
634636
SmallVector<Edge, 2> oldInEdges = inEdges[srcId];
635637
for (auto &inEdge : oldInEdges) {
636638
// Add edge from 'inEdge.id' to 'dstId' if it's not a private memref.
637-
if (privateMemRefs.count(inEdge.value) == 0)
639+
if (!privateMemRefs.contains(inEdge.value))
638640
addEdge(inEdge.id, dstId, inEdge.value);
639641
}
640642
}

mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -78,13 +78,13 @@ struct LoopFusion : public affine::impl::AffineLoopFusionBase<LoopFusion> {
7878
static bool canRemoveSrcNodeAfterFusion(
7979
unsigned srcId, unsigned dstId, const ComputationSliceState &fusionSlice,
8080
Operation *fusedLoopInsPoint, const DenseSet<Value> &escapingMemRefs,
81-
MemRefDependenceGraph *mdg) {
81+
const MemRefDependenceGraph &mdg) {
8282

83-
Operation *dstNodeOp = mdg->getNode(dstId)->op;
83+
Operation *dstNodeOp = mdg.getNode(dstId)->op;
8484
bool hasOutDepsAfterFusion = false;
8585

86-
for (auto &outEdge : mdg->outEdges[srcId]) {
87-
Operation *depNodeOp = mdg->getNode(outEdge.id)->op;
86+
for (auto &outEdge : mdg.outEdges.lookup(srcId)) {
87+
Operation *depNodeOp = mdg.getNode(outEdge.id)->op;
8888
// Skip dependence with dstOp since it will be removed after fusion.
8989
if (depNodeOp == dstNodeOp)
9090
continue;
@@ -134,22 +134,23 @@ static bool canRemoveSrcNodeAfterFusion(
134134
/// held if the 'mdg' is reused from a previous fusion step or if the node
135135
/// creation order changes in the future to support more advance cases.
136136
// TODO: Move this to a loop fusion utility once 'mdg' is also moved.
137-
static void getProducerCandidates(unsigned dstId, MemRefDependenceGraph *mdg,
137+
static void getProducerCandidates(unsigned dstId,
138+
const MemRefDependenceGraph &mdg,
138139
SmallVectorImpl<unsigned> &srcIdCandidates) {
139140
// Skip if no input edges along which to fuse.
140-
if (mdg->inEdges.count(dstId) == 0)
141+
if (mdg.inEdges.count(dstId) == 0)
141142
return;
142143

143144
// Gather memrefs from loads in 'dstId'.
144-
auto *dstNode = mdg->getNode(dstId);
145+
auto *dstNode = mdg.getNode(dstId);
145146
DenseSet<Value> consumedMemrefs;
146147
for (Operation *load : dstNode->loads)
147148
consumedMemrefs.insert(cast<AffineReadOpInterface>(load).getMemRef());
148149

149150
// Traverse 'dstId' incoming edges and gather the nodes that contain a store
150151
// to one of the consumed memrefs.
151-
for (auto &srcEdge : mdg->inEdges[dstId]) {
152-
auto *srcNode = mdg->getNode(srcEdge.id);
152+
for (const auto &srcEdge : mdg.inEdges.lookup(dstId)) {
153+
const auto *srcNode = mdg.getNode(srcEdge.id);
153154
// Skip if 'srcNode' is not a loop nest.
154155
if (!isa<AffineForOp>(srcNode->op))
155156
continue;
@@ -169,10 +170,10 @@ static void getProducerCandidates(unsigned dstId, MemRefDependenceGraph *mdg,
169170
/// producer-consumer dependence between 'srcId' and 'dstId'.
170171
static void
171172
gatherProducerConsumerMemrefs(unsigned srcId, unsigned dstId,
172-
MemRefDependenceGraph *mdg,
173+
const MemRefDependenceGraph &mdg,
173174
DenseSet<Value> &producerConsumerMemrefs) {
174-
auto *dstNode = mdg->getNode(dstId);
175-
auto *srcNode = mdg->getNode(srcId);
175+
auto *dstNode = mdg.getNode(dstId);
176+
auto *srcNode = mdg.getNode(srcId);
176177
gatherProducerConsumerMemrefs(srcNode->stores, dstNode->loads,
177178
producerConsumerMemrefs);
178179
}
@@ -214,14 +215,14 @@ static bool isEscapingMemref(Value memref, Block *block) {
214215

215216
/// Returns in 'escapingMemRefs' the memrefs from affine store ops in node 'id'
216217
/// that escape the block or are accessed in a non-affine way.
217-
static void gatherEscapingMemrefs(unsigned id, MemRefDependenceGraph *mdg,
218+
static void gatherEscapingMemrefs(unsigned id, const MemRefDependenceGraph &mdg,
218219
DenseSet<Value> &escapingMemRefs) {
219-
auto *node = mdg->getNode(id);
220+
auto *node = mdg.getNode(id);
220221
for (Operation *storeOp : node->stores) {
221222
auto memref = cast<AffineWriteOpInterface>(storeOp).getMemRef();
222223
if (escapingMemRefs.count(memref))
223224
continue;
224-
if (isEscapingMemref(memref, &mdg->block))
225+
if (isEscapingMemref(memref, &mdg.block))
225226
escapingMemRefs.insert(memref);
226227
}
227228
}
@@ -787,7 +788,7 @@ struct GreedyFusion {
787788
// in 'srcIdCandidates'.
788789
dstNodeChanged = false;
789790
SmallVector<unsigned, 16> srcIdCandidates;
790-
getProducerCandidates(dstId, mdg, srcIdCandidates);
791+
getProducerCandidates(dstId, *mdg, srcIdCandidates);
791792

792793
for (unsigned srcId : llvm::reverse(srcIdCandidates)) {
793794
// Get 'srcNode' from which to attempt fusion into 'dstNode'.
@@ -802,7 +803,7 @@ struct GreedyFusion {
802803
continue;
803804

804805
DenseSet<Value> producerConsumerMemrefs;
805-
gatherProducerConsumerMemrefs(srcId, dstId, mdg,
806+
gatherProducerConsumerMemrefs(srcId, dstId, *mdg,
806807
producerConsumerMemrefs);
807808

808809
// Skip if 'srcNode' out edge count on any memref is greater than
@@ -817,7 +818,7 @@ struct GreedyFusion {
817818
// block (e.g., memref block arguments, returned memrefs,
818819
// memrefs passed to function calls, etc.).
819820
DenseSet<Value> srcEscapingMemRefs;
820-
gatherEscapingMemrefs(srcNode->id, mdg, srcEscapingMemRefs);
821+
gatherEscapingMemrefs(srcNode->id, *mdg, srcEscapingMemRefs);
821822

822823
// Compute an operation list insertion point for the fused loop
823824
// nest which preserves dependences.
@@ -911,7 +912,7 @@ struct GreedyFusion {
911912
// insertion point.
912913
bool removeSrcNode = canRemoveSrcNodeAfterFusion(
913914
srcId, dstId, bestSlice, fusedLoopInsPoint, srcEscapingMemRefs,
914-
mdg);
915+
*mdg);
915916

916917
DenseSet<Value> privateMemrefs;
917918
for (Value memref : producerConsumerMemrefs) {

0 commit comments

Comments
 (0)