|
1 | 1 | #ifndef TRITON_INTEL_ANALYSIS_AXISINFO_H |
2 | 2 | #define TRITON_INTEL_ANALYSIS_AXISINFO_H |
3 | 3 |
|
4 | | -#include "mlir/Analysis/DataFlow/SparseAnalysis.h" |
5 | | -#include "llvm/Support/raw_ostream.h" |
6 | | - |
7 | | -#include "mlir/Support/LLVM.h" |
8 | | -#include "triton/Analysis/Utility.h" |
9 | | -#include "triton/Dialect/Triton/IR/Dialect.h" |
10 | | -#include "triton/Dialect/Triton/IR/Utility.h" |
11 | | -#include "triton/Dialect/TritonGPU/IR/Dialect.h" |
12 | | - |
13 | | -#include <optional> |
| 4 | +#include "triton/Analysis/AxisInfo.h" |
14 | 5 |
|
15 | 6 | namespace mlir::triton::intel { |
16 | 7 |
|
17 | | -//===----------------------------------------------------------------------===// |
18 | | -// AxisInfo |
19 | | -//===----------------------------------------------------------------------===// |
20 | | - |
21 | | -/// This lattice value represents known information on the axes of a lattice. |
22 | | -class AxisInfo { |
23 | | -public: |
24 | | - typedef SmallVector<int64_t> DimVectorT; |
25 | | - |
26 | | -public: |
27 | | - AxisInfo() : AxisInfo({}, {}, {}) {} |
28 | | - |
29 | | - AxisInfo(const DimVectorT &contiguity, const DimVectorT &divisibility, |
30 | | - const DimVectorT &constancy) |
31 | | - : AxisInfo(contiguity, divisibility, constancy, std::nullopt) {} |
32 | | - |
33 | | - AxisInfo(const DimVectorT &contiguity, const DimVectorT &divisibility, |
34 | | - const DimVectorT &constancy, std::optional<int64_t> constantValue) |
35 | | - : contiguity(contiguity), divisibility(divisibility), |
36 | | - constancy(constancy), constantValue(constantValue) { |
37 | | - assert(divisibility.size() == contiguity.size()); |
38 | | - assert(constancy.size() == contiguity.size()); |
39 | | - } |
40 | | - |
41 | | - // contiguity[d] is the length of the shortest sequence of contiguous integers |
42 | | - // along dimension d. |
43 | | - // |
44 | | - // If we have an array of N elements with a contiguity value C, then the array |
45 | | - // can be divided into a list of N/C sequences of C contiguous elements. |
46 | | - // Since we have N = 2^k, C must be a power of two. |
47 | | - // |
48 | | - // For example, the 2D array |
49 | | - // |
50 | | - // [[10, 11, 12, 13, 18, 19, 20, 21], |
51 | | - // [20, 21, 22, 23, 28, 29, 30, 31]] |
52 | | - // |
53 | | - // has contiguity [1, 4], and |
54 | | - // |
55 | | - // [[12, 16, 20, 24], |
56 | | - // [13, 17, 21, 25], |
57 | | - // [14, 18, 22, 26], |
58 | | - // [15, 19, 23, 27], |
59 | | - // [18, 22, 26, 30], |
60 | | - // [19, 23, 27, 31]] |
61 | | - // |
62 | | - // has contiguity [2, 1]. |
63 | | - int64_t getContiguity(size_t dim) const { return contiguity[dim]; } |
64 | | - const DimVectorT &getContiguity() const { return contiguity; } |
65 | | - |
66 | | - // divisibility[d] is the largest power of two that divides the first element |
67 | | - // of all groups of length contiguity[d] along dimension d. |
68 | | - // |
69 | | - // For example, |
70 | | - // |
71 | | - // [[10, 11, 12, 13, 18, 19, 20, 21], |
72 | | - // [20, 21, 22, 23, 28, 29, 30, 31]] |
73 | | - // |
74 | | - // has divisibility [1, 2], and |
75 | | - // |
76 | | - // [[12, 16, 20, 24], |
77 | | - // [13, 17, 21, 25], |
78 | | - // [14, 18, 22, 26], |
79 | | - // [15, 19, 23, 27]] |
80 | | - // |
81 | | - // has divisibility [4, 1]. |
82 | | - // |
83 | | - // On the other hand, |
84 | | - // |
85 | | - // [0, 1, 2, 0, 4, 5, 6, 7] |
86 | | - // |
87 | | - // has divisibility 1 because its contiguity is 1. |
88 | | - int64_t getDivisibility(size_t dim) const { return divisibility[dim]; } |
89 | | - const DimVectorT &getDivisibility() const { return divisibility; } |
90 | | - |
91 | | - // constancy[d] is the length of the shortest sequence of repeating integers |
92 | | - // along dimension d. |
93 | | - // |
94 | | - // This is particularly useful to infer the contiguity of operations (e.g. |
95 | | - // add) involving a constant. |
96 | | - // |
97 | | - // If we have an array of N elements, with a constancy value C, then the array |
98 | | - // can be divided into a list of N/C sequences of C elements with the same |
99 | | - // value. Since we have N = 2^k, C must be a power of two. |
100 | | - // |
101 | | - // For example |
102 | | - // |
103 | | - // [[8, 8, 8, 8, 12, 12, 12, 12], |
104 | | - // [16, 16, 16, 16, 20, 20, 20, 20]] |
105 | | - // |
106 | | - // has constancy [1, 4]. |
107 | | - int64_t getConstancy(size_t dim) const { return constancy[dim]; } |
108 | | - const DimVectorT &getConstancy() const { return constancy; } |
109 | | - |
110 | | - int getRank() const { return contiguity.size(); } |
111 | | - |
112 | | - std::optional<int64_t> getConstantValue() const { return constantValue; } |
113 | | - |
114 | | - template <class T> |
115 | | - static void |
116 | | - initPessimisticStateFromFunc(int argNumber, T funcOp, DimVectorT *contiguity, |
117 | | - DimVectorT *divisibility, DimVectorT *constancy); |
118 | | - |
119 | | - bool operator==(const AxisInfo &other) const { |
120 | | - return contiguity == other.contiguity && |
121 | | - divisibility == other.divisibility && constancy == other.constancy && |
122 | | - constantValue == other.constantValue; |
123 | | - } |
124 | | - |
125 | | - static AxisInfo getPessimisticValueState(Value value); |
126 | | - |
127 | | - // The gcd of both arguments for each dimension |
128 | | - static AxisInfo join(const AxisInfo &lhs, const AxisInfo &rhs); |
129 | | - |
130 | | - void print(raw_ostream &os) const { |
131 | | - auto print = [&](StringRef name, DimVectorT vec) { |
132 | | - os << name << " = ["; |
133 | | - llvm::interleaveComma(vec, os); |
134 | | - os << "]"; |
135 | | - }; |
136 | | - print("contiguity", contiguity); |
137 | | - print(", divisibility", divisibility); |
138 | | - print(", constancy", constancy); |
139 | | - os << ", constant_value = "; |
140 | | - if (constantValue) |
141 | | - os << *constantValue; |
142 | | - else |
143 | | - os << "<none>"; |
144 | | - } |
145 | | - |
146 | | -private: |
147 | | - DimVectorT contiguity; |
148 | | - DimVectorT divisibility; |
149 | | - DimVectorT constancy; |
150 | | - |
151 | | - // The constant value of the lattice if we can infer it. |
152 | | - std::optional<int64_t> constantValue; |
153 | | -}; |
154 | | - |
155 | 8 | // Module level axis info analysis based on the call graph, assuming that we do |
156 | 9 | // not have recursive functions. |
157 | 10 | // |
158 | 11 | // Since each function will be called multiple times, we need to calculate the |
159 | 12 | // axis info based on the axis info of all the callers. In the future, we can |
160 | 13 | // perform optimization using function cloning so that each call site will have |
161 | 14 | // unique axis info. |
162 | | -using AxisInfoMapT = DenseMap<Value, AxisInfo>; |
163 | | -class ModuleAxisInfoAnalysis : public CallGraph<AxisInfoMapT> { |
| 15 | + |
| 16 | +class ModuleAxisInfoAnalysis : public triton::ModuleAxisInfoAnalysis { |
164 | 17 | public: |
165 | 18 | explicit ModuleAxisInfoAnalysis(ModuleOp moduleOp) |
166 | | - : CallGraph<AxisInfoMapT>(moduleOp) { |
| 19 | + : triton::ModuleAxisInfoAnalysis(moduleOp) { |
| 20 | + funcMap.clear(); |
| 21 | + |
167 | 22 | SmallVector<FunctionOpInterface> funcs; |
168 | 23 | for (auto root : getRoots()) { |
169 | 24 | walk<WalkOrder::PreOrder, WalkOrder::PostOrder>( |
|
0 commit comments