Skip to content

Commit dd36f6d

Browse files
authored
[NFI]: Make intel AxisInfo analysis derive from upstream implementation (#2598)
The Intel version of the AxisInfo analysis added support for blocked pointers (introduced in a previous PR); in order to use it in places where the upstream analysis is required this PR makes it a derived class of the upstream analysis. --------- Signed-off-by: Tiotto, Ettore <[email protected]>
1 parent b8fc4b9 commit dd36f6d

File tree

5 files changed

+13
-262
lines changed

5 files changed

+13
-262
lines changed

third_party/intel/include/Analysis/AxisInfo.h

Lines changed: 6 additions & 151 deletions
Original file line numberDiff line numberDiff line change
@@ -1,169 +1,24 @@
11
#ifndef TRITON_INTEL_ANALYSIS_AXISINFO_H
22
#define TRITON_INTEL_ANALYSIS_AXISINFO_H
33

4-
#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
5-
#include "llvm/Support/raw_ostream.h"
6-
7-
#include "mlir/Support/LLVM.h"
8-
#include "triton/Analysis/Utility.h"
9-
#include "triton/Dialect/Triton/IR/Dialect.h"
10-
#include "triton/Dialect/Triton/IR/Utility.h"
11-
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
12-
13-
#include <optional>
4+
#include "triton/Analysis/AxisInfo.h"
145

156
namespace mlir::triton::intel {
167

17-
//===----------------------------------------------------------------------===//
18-
// AxisInfo
19-
//===----------------------------------------------------------------------===//
20-
21-
/// This lattice value represents known information on the axes of a lattice.
22-
class AxisInfo {
23-
public:
24-
typedef SmallVector<int64_t> DimVectorT;
25-
26-
public:
27-
AxisInfo() : AxisInfo({}, {}, {}) {}
28-
29-
AxisInfo(const DimVectorT &contiguity, const DimVectorT &divisibility,
30-
const DimVectorT &constancy)
31-
: AxisInfo(contiguity, divisibility, constancy, std::nullopt) {}
32-
33-
AxisInfo(const DimVectorT &contiguity, const DimVectorT &divisibility,
34-
const DimVectorT &constancy, std::optional<int64_t> constantValue)
35-
: contiguity(contiguity), divisibility(divisibility),
36-
constancy(constancy), constantValue(constantValue) {
37-
assert(divisibility.size() == contiguity.size());
38-
assert(constancy.size() == contiguity.size());
39-
}
40-
41-
// contiguity[d] is the length of the shortest sequence of contiguous integers
42-
// along dimension d.
43-
//
44-
// If we have an array of N elements with a contiguity value C, then the array
45-
// can be divided into a list of N/C sequences of C contiguous elements.
46-
// Since we have N = 2^k, C must be a power of two.
47-
//
48-
// For example, the 2D array
49-
//
50-
// [[10, 11, 12, 13, 18, 19, 20, 21],
51-
// [20, 21, 22, 23, 28, 29, 30, 31]]
52-
//
53-
// has contiguity [1, 4], and
54-
//
55-
// [[12, 16, 20, 24],
56-
// [13, 17, 21, 25],
57-
// [14, 18, 22, 26],
58-
// [15, 19, 23, 27],
59-
// [18, 22, 26, 30],
60-
// [19, 23, 27, 31]]
61-
//
62-
// has contiguity [2, 1].
63-
int64_t getContiguity(size_t dim) const { return contiguity[dim]; }
64-
const DimVectorT &getContiguity() const { return contiguity; }
65-
66-
// divisibility[d] is the largest power of two that divides the first element
67-
// of all groups of length contiguity[d] along dimension d.
68-
//
69-
// For example,
70-
//
71-
// [[10, 11, 12, 13, 18, 19, 20, 21],
72-
// [20, 21, 22, 23, 28, 29, 30, 31]]
73-
//
74-
// has divisibility [1, 2], and
75-
//
76-
// [[12, 16, 20, 24],
77-
// [13, 17, 21, 25],
78-
// [14, 18, 22, 26],
79-
// [15, 19, 23, 27]]
80-
//
81-
// has divisibility [4, 1].
82-
//
83-
// On the other hand,
84-
//
85-
// [0, 1, 2, 0, 4, 5, 6, 7]
86-
//
87-
// has divisibility 1 because its contiguity is 1.
88-
int64_t getDivisibility(size_t dim) const { return divisibility[dim]; }
89-
const DimVectorT &getDivisibility() const { return divisibility; }
90-
91-
// constancy[d] is the length of the shortest sequence of repeating integers
92-
// along dimension d.
93-
//
94-
// This is particularly useful to infer the contiguity of operations (e.g.
95-
// add) involving a constant.
96-
//
97-
// If we have an array of N elements, with a constancy value C, then the array
98-
// can be divided into a list of N/C sequences of C elements with the same
99-
// value. Since we have N = 2^k, C must be a power of two.
100-
//
101-
// For example
102-
//
103-
// [[8, 8, 8, 8, 12, 12, 12, 12],
104-
// [16, 16, 16, 16, 20, 20, 20, 20]]
105-
//
106-
// has constancy [1, 4].
107-
int64_t getConstancy(size_t dim) const { return constancy[dim]; }
108-
const DimVectorT &getConstancy() const { return constancy; }
109-
110-
int getRank() const { return contiguity.size(); }
111-
112-
std::optional<int64_t> getConstantValue() const { return constantValue; }
113-
114-
template <class T>
115-
static void
116-
initPessimisticStateFromFunc(int argNumber, T funcOp, DimVectorT *contiguity,
117-
DimVectorT *divisibility, DimVectorT *constancy);
118-
119-
bool operator==(const AxisInfo &other) const {
120-
return contiguity == other.contiguity &&
121-
divisibility == other.divisibility && constancy == other.constancy &&
122-
constantValue == other.constantValue;
123-
}
124-
125-
static AxisInfo getPessimisticValueState(Value value);
126-
127-
// The gcd of both arguments for each dimension
128-
static AxisInfo join(const AxisInfo &lhs, const AxisInfo &rhs);
129-
130-
void print(raw_ostream &os) const {
131-
auto print = [&](StringRef name, DimVectorT vec) {
132-
os << name << " = [";
133-
llvm::interleaveComma(vec, os);
134-
os << "]";
135-
};
136-
print("contiguity", contiguity);
137-
print(", divisibility", divisibility);
138-
print(", constancy", constancy);
139-
os << ", constant_value = ";
140-
if (constantValue)
141-
os << *constantValue;
142-
else
143-
os << "<none>";
144-
}
145-
146-
private:
147-
DimVectorT contiguity;
148-
DimVectorT divisibility;
149-
DimVectorT constancy;
150-
151-
// The constant value of the lattice if we can infer it.
152-
std::optional<int64_t> constantValue;
153-
};
154-
1558
// Module level axis info analysis based on the call graph, assuming that we do
1569
// not have recursive functions.
15710
//
15811
// Since each function will be called multiple times, we need to calculate the
15912
// axis info based on the axis info of all the callers. In the future, we can
16013
// perform optimization using function cloning so that each call site will have
16114
// unique axis info.
162-
using AxisInfoMapT = DenseMap<Value, AxisInfo>;
163-
class ModuleAxisInfoAnalysis : public CallGraph<AxisInfoMapT> {
15+
16+
class ModuleAxisInfoAnalysis : public triton::ModuleAxisInfoAnalysis {
16417
public:
16518
explicit ModuleAxisInfoAnalysis(ModuleOp moduleOp)
166-
: CallGraph<AxisInfoMapT>(moduleOp) {
19+
: triton::ModuleAxisInfoAnalysis(moduleOp) {
20+
funcMap.clear();
21+
16722
SmallVector<FunctionOpInterface> funcs;
16823
for (auto root : getRoots()) {
16924
walk<WalkOrder::PreOrder, WalkOrder::PostOrder>(

third_party/intel/include/Dialect/TritonIntelGPU/IR/Utils.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ inline unsigned getNumElementsPerThread(
2828
? cast<RankedTensorType>(cast<PointerType>(valTy).getPointeeType())
2929
: cast<RankedTensorType>(valTy);
3030
auto shapePerCTA = getShapePerCTA(ty);
31-
mlir::triton::intel::AxisInfo &valInfo = *axisInfoAnalysis.getAxisInfo(val);
31+
mlir::triton::AxisInfo &valInfo = *axisInfoAnalysis.getAxisInfo(val);
3232

3333
unsigned elemNumBits = getElementBitWidth(ty);
3434
unsigned elemNumBytes = std::max(elemNumBits / 8, 1u);

third_party/intel/lib/Analysis/AxisInfo.cpp

Lines changed: 3 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -1159,112 +1159,6 @@ void AxisInfoAnalysis::visitForOpInductionVar(
11591159

11601160
} // anonymous namespace
11611161

1162-
template <class T>
1163-
void AxisInfo::initPessimisticStateFromFunc(int argNumber, T funcOp,
1164-
DimVectorT *contiguity,
1165-
DimVectorT *divisibility,
1166-
DimVectorT *constancy) {
1167-
// liast of attributes that we care about
1168-
SmallVector<std::pair<DimVectorT *, std::string>> retVecs;
1169-
retVecs.push_back({contiguity, "tt.contiguity"});
1170-
retVecs.push_back({divisibility, "tt.divisibility"});
1171-
retVecs.push_back({constancy, "tt.constancy"});
1172-
// initialize attributes one by one
1173-
for (auto [vec, attrName] : retVecs) {
1174-
Attribute attr = funcOp.getArgAttr(argNumber, attrName);
1175-
if (auto int_attr = dyn_cast_or_null<IntegerAttr>(attr))
1176-
*vec = DimVectorT(contiguity->size(), int_attr.getValue().getZExtValue());
1177-
if (auto dense_attr = dyn_cast_or_null<DenseElementsAttr>(attr)) {
1178-
auto vals = dense_attr.getValues<int>();
1179-
*vec = DimVectorT(vals.begin(), vals.end());
1180-
}
1181-
}
1182-
}
1183-
1184-
/*static*/ AxisInfo AxisInfo::getPessimisticValueState(Value value) {
1185-
auto rank = 1;
1186-
if (TensorType ty = dyn_cast<TensorType>(value.getType()))
1187-
rank = ty.getRank();
1188-
if (triton::PointerType ty = dyn_cast<triton::PointerType>(value.getType()))
1189-
if (TensorType elemTy = dyn_cast<TensorType>(ty.getPointeeType()))
1190-
rank = elemTy.getRank();
1191-
1192-
DimVectorT knownContiguity(rank, 1);
1193-
DimVectorT knownDivisibility(rank, 1);
1194-
DimVectorT knownConstancy(rank, 1);
1195-
1196-
BlockArgument blockArg = dyn_cast<BlockArgument>(value);
1197-
1198-
if (blockArg && blockArg.getOwner()->isEntryBlock()) {
1199-
Operation *op = blockArg.getOwner()->getParentOp();
1200-
if (auto fun = dyn_cast<FunctionOpInterface>(op))
1201-
initPessimisticStateFromFunc(blockArg.getArgNumber(), fun,
1202-
&knownContiguity, &knownDivisibility,
1203-
&knownConstancy);
1204-
// llvm codegen check alignment to generate vector load/store
1205-
// would be nice if this wasn't the case
1206-
else if (auto fun = dyn_cast<LLVM::LLVMFuncOp>(op))
1207-
initPessimisticStateFromFunc(blockArg.getArgNumber(), fun,
1208-
&knownContiguity, &knownDivisibility,
1209-
&knownConstancy);
1210-
else if (isa<RegionBranchOpInterface>(op)) {
1211-
// scf::ForOp, scf::IfOp, scf::WhileOp
1212-
// Control flow operations are initialized with "unknown" state:
1213-
// the maximum possible divisibility, contiguity, and constancy.
1214-
knownDivisibility = DimVectorT(rank, highestPowOf2Divisor<int64_t>(0));
1215-
knownConstancy = DimVectorT(rank, highestPowOf2Divisor<int64_t>(0));
1216-
knownContiguity = DimVectorT(rank, highestPowOf2Divisor<int64_t>(0));
1217-
}
1218-
} else if (Operation *op = value.getDefiningOp()) {
1219-
if (isa<RegionBranchOpInterface>(op)) {
1220-
// scf::ForOp, scf::IfOp, scf::WhileOp
1221-
// Control flow operations are initialized with "unknown" state:
1222-
// the maximum possible divisibility, contiguity, and constancy.
1223-
knownDivisibility = DimVectorT(rank, highestPowOf2Divisor<int64_t>(0));
1224-
knownConstancy = DimVectorT(rank, highestPowOf2Divisor<int64_t>(0));
1225-
knownContiguity = DimVectorT(rank, highestPowOf2Divisor<int64_t>(0));
1226-
}
1227-
// Other operations are conservatively initialized with the lowest possible
1228-
// divisibility, contiguity, and constancy unless they have specified.
1229-
if (Attribute attr = op->getDiscardableAttr("tt.divisibility")) {
1230-
auto vals = cast<DenseElementsAttr>(attr).getValues<int>();
1231-
knownDivisibility = DimVectorT(vals.begin(), vals.end());
1232-
}
1233-
if (Attribute attr = op->getDiscardableAttr("tt.contiguity")) {
1234-
auto vals = cast<DenseElementsAttr>(attr).getValues<int>();
1235-
knownContiguity = DimVectorT(vals.begin(), vals.end());
1236-
}
1237-
if (Attribute attr = op->getDiscardableAttr("tt.constancy")) {
1238-
auto vals = cast<DenseElementsAttr>(attr).getValues<int>();
1239-
knownConstancy = DimVectorT(vals.begin(), vals.end());
1240-
}
1241-
}
1242-
1243-
return AxisInfo(knownContiguity, knownDivisibility, knownConstancy);
1244-
}
1245-
1246-
/*static*/ AxisInfo AxisInfo::join(const AxisInfo &lhs, const AxisInfo &rhs) {
1247-
// If one argument is not initialized, return the other.
1248-
if (lhs.getRank() == 0)
1249-
return rhs;
1250-
if (rhs.getRank() == 0)
1251-
return lhs;
1252-
DimVectorT contiguity;
1253-
DimVectorT divisibility;
1254-
DimVectorT constancy;
1255-
for (auto d = 0; d < lhs.getRank(); ++d) {
1256-
contiguity.push_back(gcd(lhs.getContiguity(d), rhs.getContiguity(d)));
1257-
divisibility.push_back(gcd(lhs.getDivisibility(d), rhs.getDivisibility(d)));
1258-
constancy.push_back(gcd(lhs.getConstancy(d), rhs.getConstancy(d)));
1259-
}
1260-
std::optional<int64_t> constantValue;
1261-
if (lhs.getConstantValue().has_value() &&
1262-
rhs.getConstantValue().has_value() &&
1263-
lhs.getConstantValue() == rhs.getConstantValue())
1264-
constantValue = lhs.getConstantValue();
1265-
return AxisInfo(contiguity, divisibility, constancy, constantValue);
1266-
}
1267-
12681162
unsigned ModuleAxisInfoAnalysis::getPtrContiguity(Value ptr) {
12691163
auto tensorTy = ttgi::getRankedTensorType(ptr.getType());
12701164
if (!tensorTy)
@@ -1298,7 +1192,9 @@ unsigned ModuleAxisInfoAnalysis::getPtrAlignment(Value ptr) {
12981192
auto order = triton::gpu::getOrder(layout);
12991193
auto maxMultipleBytes = axisInfo->getDivisibility(order[0]);
13001194
auto maxContig = axisInfo->getContiguity(order[0]);
1301-
auto elemNumBits = triton::getPointeeBitWidth(tensorTy);
1195+
unsigned elemNumBits = isTensorPointerType(ptr.getType())
1196+
? tensorTy.getElementType().getIntOrFloatBitWidth()
1197+
: triton::getPointeeBitWidth(tensorTy);
13021198
auto elemNumBytes = std::max<unsigned>(elemNumBits / 8, 1);
13031199
auto maxMultiple = std::max<int64_t>(maxMultipleBytes / elemNumBytes, 1);
13041200
unsigned alignment = std::min(maxMultiple, maxContig);

third_party/intel/lib/TritonIntelGPUToLLVM/PatternTritonGPUOpToLLVM.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33

44
#include "TargetInfo.h"
55
#include "TritonGPUToLLVMBase.h"
6+
#include "intel/include/Analysis/AxisInfo.h"
67
#include "intel/include/TritonIntelGPUToLLVM/TypeConverter.h"
7-
#include "triton/Analysis/AxisInfo.h"
88

99
namespace mlir::triton::intel {
1010

third_party/intel/lib/TritonIntelGPUToLLVM/TritonGPUToLLVM.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,14 @@
77
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
88
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
99

10+
#include "intel/include/Analysis/AxisInfo.h"
1011
#include "intel/include/Dialect/TritonGEN/IR/TritonGENDialect.h"
1112
#include "intel/include/Dialect/TritonIntelGPU/IR/Dialect.h"
1213
#include "intel/include/GPUToTritonGEN/GPUToTritonGENPass.h"
1314
#include "intel/include/TritonGENToLLVM/TritonGENToLLVMPass.h"
1415
#include "intel/include/TritonIntelGPUToLLVM/Passes.h"
1516

1617
#include "triton/Analysis/Allocation.h"
17-
#include "triton/Analysis/AxisInfo.h"
1818
#include "triton/Analysis/Membar.h"
1919
#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h"
2020
#include "triton/Dialect/Triton/IR/Dialect.h"
@@ -114,7 +114,7 @@ struct ConvertTritonGPUToLLVM
114114
return signalPassFailure();
115115
}
116116

117-
ModuleAxisInfoAnalysis axisInfoAnalysis(mod);
117+
intel::ModuleAxisInfoAnalysis axisInfoAnalysis(mod);
118118
OpBuilder::InsertPoint indexInsertPoint;
119119

120120
RewritePatternSet patterns(context);

0 commit comments

Comments
 (0)