11#ifndef TRITON_INTEL_ANALYSIS_AXISINFO_H
22#define TRITON_INTEL_ANALYSIS_AXISINFO_H
33
4- #include " mlir/Analysis/DataFlow/SparseAnalysis.h"
5- #include " llvm/Support/raw_ostream.h"
6-
7- #include " mlir/Support/LLVM.h"
8- #include " triton/Analysis/Utility.h"
9- #include " triton/Dialect/Triton/IR/Dialect.h"
10- #include " triton/Dialect/Triton/IR/Utility.h"
11- #include " triton/Dialect/TritonGPU/IR/Dialect.h"
12-
13- #include < optional>
4+ #include " triton/Analysis/AxisInfo.h"
145
156namespace mlir ::triton::intel {
167
17- // ===----------------------------------------------------------------------===//
18- // AxisInfo
19- // ===----------------------------------------------------------------------===//
20-
21- /// This lattice value represents known information on the axes of a lattice: per-dimension contiguity, divisibility and constancy, plus an optional constant value.
22- class AxisInfo {
23- public:
24- typedef SmallVector<int64_t > DimVectorT; // One int64_t entry per dimension.
25-
26- public:
27- AxisInfo () : AxisInfo({}, {}, {}) {} // Rank-0 info: empty vectors, no constant value.
28- // Builds the info without a known constant value (delegates with std::nullopt).
29- AxisInfo (const DimVectorT &contiguity, const DimVectorT &divisibility,
30- const DimVectorT &constancy)
31- : AxisInfo(contiguity, divisibility, constancy, std::nullopt ) {}
32- // Main constructor: stores all fields and asserts that the three vectors have the same size.
33- AxisInfo (const DimVectorT &contiguity, const DimVectorT &divisibility,
34- const DimVectorT &constancy, std::optional<int64_t > constantValue)
35- : contiguity(contiguity), divisibility(divisibility),
36- constancy (constancy), constantValue(constantValue) {
37- assert (divisibility.size () == contiguity.size ());
38- assert (constancy.size () == contiguity.size ());
39- }
40-
41- // contiguity[d] is the length of the shortest sequence of contiguous integers
42- // along dimension d.
43- //
44- // If we have an array of N elements with a contiguity value C, then the array
45- // can be divided into a list of N/C sequences of C contiguous elements.
46- // Since we have N = 2^k, C must be a power of two.
47- //
48- // For example, the 2D array
49- //
50- // [[10, 11, 12, 13, 18, 19, 20, 21],
51- // [20, 21, 22, 23, 28, 29, 30, 31]]
52- //
53- // has contiguity [1, 4], and
54- //
55- // [[12, 16, 20, 24],
56- // [13, 17, 21, 25],
57- // [14, 18, 22, 26],
58- // [15, 19, 23, 27],
59- // [18, 22, 26, 30],
60- // [19, 23, 27, 31]]
61- //
62- // has contiguity [2, 1].
63- int64_t getContiguity (size_t dim) const { return contiguity[dim]; }
64- const DimVectorT &getContiguity () const { return contiguity; }
65-
66- // divisibility[d] is the largest power of two that divides the first element
67- // of all groups of length contiguity[d] along dimension d.
68- //
69- // For example,
70- //
71- // [[10, 11, 12, 13, 18, 19, 20, 21],
72- // [20, 21, 22, 23, 28, 29, 30, 31]]
73- //
74- // has divisibility [1, 2], and
75- //
76- // [[12, 16, 20, 24],
77- // [13, 17, 21, 25],
78- // [14, 18, 22, 26],
79- // [15, 19, 23, 27]]
80- //
81- // has divisibility [4, 1].
82- //
83- // On the other hand,
84- //
85- // [0, 1, 2, 0, 4, 5, 6, 7]
86- //
87- // has divisibility 1 because its contiguity is 1.
88- int64_t getDivisibility (size_t dim) const { return divisibility[dim]; }
89- const DimVectorT &getDivisibility () const { return divisibility; }
90-
91- // constancy[d] is the length of the shortest sequence of repeating integers
92- // along dimension d.
93- //
94- // This is particularly useful to infer the contiguity of operations (e.g.
95- // add) involving a constant.
96- //
97- // If we have an array of N elements, with a constancy value C, then the array
98- // can be divided into a list of N/C sequences of C elements with the same
99- // value. Since we have N = 2^k, C must be a power of two.
100- //
101- // For example
102- //
103- // [[8, 8, 8, 8, 12, 12, 12, 12],
104- // [16, 16, 16, 16, 20, 20, 20, 20]]
105- //
106- // has constancy [1, 4].
107- int64_t getConstancy (size_t dim) const { return constancy[dim]; }
108- const DimVectorT &getConstancy () const { return constancy; }
109-
110- int getRank () const { return contiguity.size (); } // Number of dimensions; NOTE(review): the size is implicitly narrowed to int.
111- // The constant value, or std::nullopt when none is known.
112- std::optional<int64_t > getConstantValue () const { return constantValue; }
113- // Declared here, defined elsewhere. NOTE(review): presumably fills the three output vectors for argument argNumber of funcOp -- confirm against the definition.
114- template <class T >
115- static void
116- initPessimisticStateFromFunc (int argNumber, T funcOp, DimVectorT *contiguity,
117- DimVectorT *divisibility, DimVectorT *constancy);
118- // Equal when all three vectors and the optional constant value compare equal.
119- bool operator ==(const AxisInfo &other) const {
120- return contiguity == other.contiguity &&
121- divisibility == other.divisibility && constancy == other.constancy &&
122- constantValue == other.constantValue ;
123- }
124- // Declared here, defined elsewhere. NOTE(review): presumably the most conservative state for value -- confirm against the definition.
125- static AxisInfo getPessimisticValueState (Value value);
126-
127- // The gcd of both arguments for each dimension (declared here, defined elsewhere).
128- static AxisInfo join (const AxisInfo &lhs, const AxisInfo &rhs);
129- // Streams the three vectors as comma-separated lists, then the constant value or a none marker, to os.
130- void print (raw_ostream &os) const {
131- auto print = [&](StringRef name, DimVectorT vec) { // NOTE(review): vec is passed by value, so each call copies the vector.
132- os << name << " = [" ;
133- llvm::interleaveComma (vec, os);
134- os << " ]" ;
135- };
136- print (" contiguity" , contiguity);
137- print (" , divisibility" , divisibility);
138- print (" , constancy" , constancy);
139- os << " , constant_value = " ;
140- if (constantValue)
141- os << *constantValue;
142- else
143- os << " <none>" ;
144- }
145-
146- private:
147- DimVectorT contiguity;
148- DimVectorT divisibility;
149- DimVectorT constancy;
150-
151- // The constant value of the lattice if we can infer it.
152- std::optional<int64_t > constantValue;
153- };
154-
1558// Module-level axis info analysis based on the call graph, assuming that we do
1569// not have recursive functions.
15710//
15811// Since each function may be called multiple times, we need to calculate the
15912// axis info based on the axis info of all the callers. In the future, we can
16013// optimize this using function cloning so that each call site will have its
16114// own unique axis info.
162- using AxisInfoMapT = DenseMap<Value, AxisInfo>;
163- class ModuleAxisInfoAnalysis : public CallGraph <AxisInfoMapT> {
15+ // using AxisInfoMapT = DenseMap<Value, AxisInfo>;
16+ class ModuleAxisInfoAnalysis : public triton ::ModuleAxisInfoAnalysis {
16417public:
16518 explicit ModuleAxisInfoAnalysis (ModuleOp moduleOp)
166- : CallGraph<AxisInfoMapT>(moduleOp) {
19+ : triton::ModuleAxisInfoAnalysis(moduleOp) {
20+ funcMap.clear ();
21+
16722 SmallVector<FunctionOpInterface> funcs;
16823 for (auto root : getRoots ()) {
16924 walk<WalkOrder::PreOrder, WalkOrder::PostOrder>(
@@ -187,10 +42,11 @@ class ModuleAxisInfoAnalysis : public CallGraph<AxisInfoMapT> {
18742 }
18843 }
18944
190- AxisInfo *getAxisInfo (Value value) {
45+ AxisInfo *getAxisInfo (Value value) const {
19146 auto funcOp =
19247 value.getParentRegion ()->getParentOfType <FunctionOpInterface>();
193- auto *axisInfoMap = getFuncData (funcOp);
48+ auto *axisInfoMap =
49+ const_cast <ModuleAxisInfoAnalysis *>(this )->getFuncData (funcOp);
19450 if (!axisInfoMap) {
19551 return nullptr ;
19652 }
@@ -201,9 +57,9 @@ class ModuleAxisInfoAnalysis : public CallGraph<AxisInfoMapT> {
20157 return &(it->second );
20258 }
20359
204- unsigned getPtrContiguity (Value ptr);
205- unsigned getPtrAlignment (Value ptr);
206- unsigned getMaskAlignment (Value mask);
60+ unsigned getPtrContiguity (Value ptr) const ;
61+ unsigned getPtrAlignment (Value ptr) const ;
62+ unsigned getMaskAlignment (Value mask) const ;
20763
20864private:
20965 void initialize (FunctionOpInterface funcOp);
0 commit comments