Skip to content

Commit d66d424

Browse files
authored
[NFI][Intel] Copy axis analysis (for further customization) (#2453)
Signed-off-by: Tiotto, Ettore <[email protected]>
1 parent a2c9ab1 commit d66d424

File tree

7 files changed

+2476
-0
lines changed

7 files changed

+2476
-0
lines changed

bin/RegisterTritonDialects.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,10 @@
3737

3838
namespace mlir {
3939
namespace test {
40+
namespace intel {
41+
void registerTestAxisInfoPass();
42+
}
43+
4044
void registerTestAliasPass();
4145
void registerTestAlignmentPass();
4246
void registerTestAllocationPass();
@@ -50,6 +54,7 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
5054
mlir::registerTritonPasses();
5155
mlir::triton::gpu::registerTritonGPUPasses();
5256
mlir::registerTritonNvidiaGPUPasses();
57+
mlir::test::intel::registerTestAxisInfoPass();
5358
mlir::test::registerTestAliasPass();
5459
mlir::test::registerTestAlignmentPass();
5560
mlir::test::registerTestAllocationPass();

test/Analysis/intel/test-alignment.mlir

Lines changed: 878 additions & 0 deletions
Large diffs are not rendered by default.

test/lib/Analysis/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
add_mlir_library(TritonTestAnalysis
2+
intel/TestAxisInfo.cpp
23
TestAlias.cpp
34
TestAxisInfo.cpp
45
TestAllocation.cpp
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
#include "intel/include/Analysis/AxisInfo.h"
2+
#include "mlir/Pass/Pass.h"
3+
4+
using namespace mlir;
5+
using namespace mlir::triton::intel;
6+
7+
namespace {
8+
9+
struct TestAxisInfoPass
10+
: public PassWrapper<TestAxisInfoPass, OperationPass<ModuleOp>> {
11+
12+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAxisInfoPass);
13+
14+
StringRef getArgument() const final { return "test-print-axis-info"; }
15+
StringRef getDescription() const final {
16+
return "print the result of the alignment analysis pass";
17+
}
18+
19+
void runOnOperation() override {
20+
Operation *operation = getOperation();
21+
ModuleOp moduleOp = cast<ModuleOp>(operation);
22+
ModuleAxisInfoAnalysis moduleAxisInfoAnalysis(moduleOp);
23+
moduleOp.walk([&](triton::FuncOp funcOp) {
24+
auto &os = llvm::errs();
25+
auto opName = SymbolTable::getSymbolName(funcOp).getValue().str();
26+
os << "@" << opName << "\n";
27+
funcOp.walk([&](Operation *op) {
28+
if (op->getNumResults() < 1)
29+
return;
30+
for (Value result : op->getResults()) {
31+
result.print(os);
32+
os << " => ";
33+
auto *axisInfo = moduleAxisInfoAnalysis.getAxisInfo(result);
34+
if (axisInfo)
35+
axisInfo->print(os);
36+
os << "\n";
37+
}
38+
});
39+
});
40+
}
41+
};
42+
43+
} // namespace
44+
45+
namespace mlir::test::intel {
46+
void registerTestAxisInfoPass() { PassRegistration<TestAxisInfoPass>(); }
47+
} // namespace mlir::test::intel
Lines changed: 215 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,215 @@
1+
#ifndef TRITON_INTEL_ANALYSIS_AXISINFO_H
2+
#define TRITON_INTEL_ANALYSIS_AXISINFO_H
3+
4+
#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
#include "llvm/Support/raw_ostream.h"

#include "mlir/Support/LLVM.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"

#include <cassert>
#include <optional>
#include <type_traits>
#include <utility>
15+
16+
namespace mlir::triton::intel {
17+
18+
//===----------------------------------------------------------------------===//
19+
// AxisInfo
20+
//===----------------------------------------------------------------------===//
21+
22+
/// This lattice value represents known information on the axes of a lattice.
23+
class AxisInfo {
24+
public:
25+
typedef SmallVector<int64_t> DimVectorT;
26+
27+
public:
28+
AxisInfo() : AxisInfo({}, {}, {}) {}
29+
30+
AxisInfo(DimVectorT contiguity, DimVectorT divisibility, DimVectorT constancy)
31+
: AxisInfo(contiguity, divisibility, constancy, std::nullopt) {}
32+
33+
AxisInfo(DimVectorT contiguity, DimVectorT divisibility, DimVectorT constancy,
34+
std::optional<int64_t> constantValue)
35+
: contiguity(contiguity), divisibility(divisibility),
36+
constancy(constancy), constantValue(constantValue) {
37+
assert(divisibility.size() == contiguity.size());
38+
assert(constancy.size() == contiguity.size());
39+
}
40+
41+
// contiguity[d] is the length of the shortest sequence of contiguous integers
42+
// along dimension d.
43+
//
44+
// If we have an array of N elements with a contiguity value C, then the array
45+
// can be divided into a list of N/C sequences of C contiguous elements.
46+
// Since we have N = 2^k, C must be a power of two.
47+
//
48+
// For example, the 2D array
49+
//
50+
// [[10, 11, 12, 13, 18, 19, 20, 21],
51+
// [20, 21, 22, 23, 28, 29, 30, 31]]
52+
//
53+
// has contiguity [1, 4], and
54+
//
55+
// [[12, 16, 20, 24],
56+
// [13, 17, 21, 25],
57+
// [14, 18, 22, 26],
58+
// [15, 19, 23, 27],
59+
// [18, 22, 26, 30],
60+
// [19, 23, 27, 31]]
61+
//
62+
// has contiguity [2, 1].
63+
int64_t getContiguity(size_t dim) const { return contiguity[dim]; }
64+
const DimVectorT &getContiguity() const { return contiguity; }
65+
66+
// divisibility[d] is the largest power of two that divides the first element
67+
// of all groups of length contiguity[d] along dimension d.
68+
//
69+
// For example,
70+
//
71+
// [[10, 11, 12, 13, 18, 19, 20, 21],
72+
// [20, 21, 22, 23, 28, 29, 30, 31]]
73+
//
74+
// has divisibility [1, 2], and
75+
//
76+
// [[12, 16, 20, 24],
77+
// [13, 17, 21, 25],
78+
// [14, 18, 22, 26],
79+
// [15, 19, 23, 27]]
80+
//
81+
// has divisibility [4, 1].
82+
//
83+
// On the other hand,
84+
//
85+
// [0, 1, 2, 0, 4, 5, 6, 7]
86+
//
87+
// has divisibility 1 because its contiguity is 1.
88+
int64_t getDivisibility(size_t dim) const { return divisibility[dim]; }
89+
const DimVectorT &getDivisibility() const { return divisibility; }
90+
91+
// constancy[d] is the length of the shortest sequence of repeating integers
92+
// along dimension d.
93+
//
94+
// This is particularly useful to infer the contiguity of operations (e.g.
95+
// add) involving a constant.
96+
//
97+
// If we have an array of N elements, with a constancy value C, then the array
98+
// can be divided into a list of N/C sequences of C elements with the same
99+
// value. Since we have N = 2^k, C must be a power of two.
100+
//
101+
// For example
102+
//
103+
// [[8, 8, 8, 8, 12, 12, 12, 12],
104+
// [16, 16, 16, 16, 20, 20, 20, 20]]
105+
//
106+
// has constancy [1, 4].
107+
int64_t getConstancy(size_t dim) const { return constancy[dim]; }
108+
const DimVectorT &getConstancy() const { return constancy; }
109+
110+
int getRank() const { return contiguity.size(); }
111+
112+
std::optional<int64_t> getConstantValue() const { return constantValue; }
113+
114+
template <class T>
115+
static void
116+
initPessimisticStateFromFunc(int argNumber, T funcOp, DimVectorT *contiguity,
117+
DimVectorT *divisibility, DimVectorT *constancy);
118+
119+
bool operator==(const AxisInfo &other) const {
120+
return contiguity == other.contiguity &&
121+
divisibility == other.divisibility && constancy == other.constancy &&
122+
constantValue == other.constantValue;
123+
}
124+
125+
static AxisInfo getPessimisticValueState(Value value);
126+
127+
// The gcd of both arguments for each dimension
128+
static AxisInfo join(const AxisInfo &lhs, const AxisInfo &rhs);
129+
130+
void print(raw_ostream &os) const {
131+
auto print = [&](StringRef name, DimVectorT vec) {
132+
os << name << " = [";
133+
llvm::interleaveComma(vec, os);
134+
os << "]";
135+
};
136+
print("contiguity", contiguity);
137+
print(", divisibility", divisibility);
138+
print(", constancy", constancy);
139+
os << ", constant_value = ";
140+
if (constantValue)
141+
os << *constantValue;
142+
else
143+
os << "<none>";
144+
}
145+
146+
private:
147+
DimVectorT contiguity;
148+
DimVectorT divisibility;
149+
DimVectorT constancy;
150+
151+
// The constant value of the lattice if we can infer it.
152+
std::optional<int64_t> constantValue;
153+
};
154+
155+
// Module level axis info analysis based on the call graph, assuming that we do
156+
// not have recursive functions.
157+
//
158+
// Since each function will be called multiple times, we need to calculate the
159+
// axis info based on the axis info of all the callers. In the future, we can
160+
// perform optimization using function cloning so that each call site will have
161+
// unique axis info.
162+
using AxisInfoMapT = DenseMap<Value, AxisInfo>;
163+
class ModuleAxisInfoAnalysis : public CallGraph<AxisInfoMapT> {
164+
public:
165+
explicit ModuleAxisInfoAnalysis(ModuleOp moduleOp)
166+
: CallGraph<AxisInfoMapT>(moduleOp) {
167+
SmallVector<FunctionOpInterface> funcs;
168+
for (auto root : getRoots()) {
169+
walk<WalkOrder::PreOrder, WalkOrder::PostOrder>(
170+
// Pre-order edge walk callback
171+
[](CallOpInterface callOp, FunctionOpInterface funcOp) {},
172+
// Post-order node walk callback
173+
[&](FunctionOpInterface funcOp) {
174+
funcs.push_back(funcOp);
175+
funcMap.try_emplace(funcOp, AxisInfoMapT{});
176+
});
177+
}
178+
SetVector<FunctionOpInterface> sortedFuncs(funcs.begin(), funcs.end());
179+
SymbolTableCollection symbolTable;
180+
for (auto funcOp : llvm::reverse(sortedFuncs)) {
181+
initialize(funcOp);
182+
funcOp.walk([&](CallOpInterface callOp) {
183+
auto callee = dyn_cast<FunctionOpInterface>(
184+
callOp.resolveCallableInTable(&symbolTable));
185+
update(callOp, callee);
186+
});
187+
}
188+
}
189+
190+
AxisInfo *getAxisInfo(Value value) {
191+
auto funcOp =
192+
value.getParentRegion()->getParentOfType<FunctionOpInterface>();
193+
auto *axisInfoMap = getFuncData(funcOp);
194+
if (!axisInfoMap) {
195+
return nullptr;
196+
}
197+
auto it = axisInfoMap->find(value);
198+
if (it == axisInfoMap->end()) {
199+
return nullptr;
200+
}
201+
return &(it->second);
202+
}
203+
204+
unsigned getPtrContiguity(Value ptr);
205+
unsigned getPtrAlignment(Value ptr);
206+
unsigned getMaskAlignment(Value mask);
207+
208+
private:
209+
void initialize(FunctionOpInterface funcOp);
210+
void update(CallOpInterface callOp, FunctionOpInterface funcOp);
211+
};
212+
213+
} // namespace mlir::triton::intel
214+
215+
#endif

0 commit comments

Comments
 (0)