@@ -8,76 +8,6 @@ using namespace mlir;
88using namespace triton ;
99using namespace triton ::gpu;
1010
11- // ===----------------------------------------------------------------------===//
12- // PartitionGraph
13- // ===----------------------------------------------------------------------===//
14-
15- namespace {
16- // A temporary node structure that can be used to build a graph of partitions.
17- // The consumers have to be precomputed in order for the SCC iterator to have an
18- // acceptable runtime complexity. This assumes the underlying loop is immutable.
19- struct PartitionNode {
20- PartitionNode (const Partition *partition) : partition(partition) {}
21-
22- // The partition this node represents.
23- const Partition *partition;
24- // Partitions that consume the outputs of this partition.
25- SmallVector<std::pair<const PartitionNode *, OpOperand *>> consumers;
26- };
27-
28- // A graph of partitions that can be used to check for cycles and other schedule
29- // invariants.
30- struct PartitionGraph {
31- PartitionGraph (scf::ForOp loop, const WarpSchedule &schedule);
32-
33- PartitionNode root;
34- llvm::MapVector<const Partition *, PartitionNode> nodes;
35- };
36- } // namespace
37-
38- PartitionGraph::PartitionGraph (scf::ForOp loop, const WarpSchedule &schedule)
39- : root(schedule.getRootPartition()) {
40- // Create the nodes at once. Afterwards, the map won't re-allocate and the
41- // pointers will be stable.
42- for (Partition &partition : schedule.getPartitions ())
43- nodes.try_emplace (&partition, &partition);
44-
45- // Wire up the graph. Consider the root node to be consumed by all other
46- // partitions so that it can be used as a virtual root.
47- for (PartitionNode &node : llvm::make_second_range (nodes))
48- root.consumers .emplace_back (&node, nullptr );
49-
50- // Check the users of the partition outputs to wire the rest of the graph.
51- for (auto &[partition, node] : nodes) {
52- auto callback = [&, node = &node](Operation *owner, OpOperand &use) {
53- // Ignore uses in subsequent iterations.
54- if (isa<scf::YieldOp>(owner))
55- return ;
56- PartitionNode &consumer =
57- nodes.find (schedule.getPartition (owner))->second ;
58- node->consumers .emplace_back (&consumer, &use);
59- };
60- schedule.iterateOutputs (loop, partition, callback);
61- }
62- }
63-
64- namespace llvm {
65- template <> struct GraphTraits <PartitionGraph> {
66- using NodeRef = std::pair<const PartitionNode *, mlir::OpOperand *>;
67- static NodeRef getEntryNode (const PartitionGraph &graph) {
68- return {&graph.root , nullptr };
69- }
70-
71- using ChildIteratorType = SmallVector<NodeRef>::const_iterator;
72- static ChildIteratorType child_begin (NodeRef node) {
73- return node.first ->consumers .begin ();
74- }
75- static ChildIteratorType child_end (NodeRef node) {
76- return node.first ->consumers .end ();
77- }
78- };
79- } // namespace llvm
80-
8111// ===----------------------------------------------------------------------===//
8212// WarpSchedule
8313// ===----------------------------------------------------------------------===//
0 commit comments