Commit b7604a9
[Intel] Sync AxisInfo from upstream (#3558)
Signed-off-by: Whitney Tsang <[email protected]>
1 parent cb6bfd6 commit b7604a9

File tree

1 file changed: +38 -10 lines changed

third_party/intel/lib/Analysis/AxisInfo.cpp

Lines changed: 38 additions & 10 deletions
@@ -1,5 +1,6 @@
 #include "mlir/Analysis/DataFlowFramework.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
 
@@ -273,6 +274,28 @@ class ConstantOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
   }
 };
 
+class PoisonOpAxisInfoVisitor final : public AxisInfoVisitorImpl<ub::PoisonOp> {
+public:
+  using AxisInfoVisitorImpl::AxisInfoVisitorImpl;
+
+  AxisInfo
+  getAxisInfo(ub::PoisonOp op,
+              ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
+    constexpr int64_t largePowerOf2 = int64_t(1) << 32;
+    // Poison values are never accessed, thus assume optimistic values.
+    if (auto shape = dyn_cast<mlir::ShapedType>(op.getType())) {
+      unsigned rank = shape.getRank();
+      return AxisInfo(
+          /*contiguity=*/AxisInfo::DimVectorT(rank, largePowerOf2),
+          /*divisibility=*/AxisInfo::DimVectorT(rank, largePowerOf2),
+          /*constancy=*/AxisInfo::DimVectorT(shape.getShape()));
+    }
+
+    return AxisInfo(/*contiguity=*/{1}, /*divisibility=*/{largePowerOf2},
+                    /*constancy=*/{1});
+  }
+};
+
 template <typename OpTy>
 class AddSubOpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
 public:
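
The visitor above hands poison values maximally optimistic contiguity and divisibility (2^32), since a poison value is never actually accessed. Below is a minimal standalone sketch of why such a large power of two is a safe default, assuming a gcd-style meet when a poison operand is later combined with a concrete one; the std::gcd call is illustrative only and is not the upstream lattice API.

// Standalone illustration (not part of the patch).
#include <cassert>
#include <cstdint>
#include <numeric>

int main() {
  constexpr int64_t largePowerOf2 = int64_t(1) << 32; // optimistic default
  for (int64_t d : {int64_t(1), int64_t(16), int64_t(1) << 20}) {
    // gcd(2^32, 2^k) == 2^k for k <= 32, so a gcd-style meet with the
    // poison operand preserves whatever the non-poison operand proved.
    assert(std::gcd(largePowerOf2, d) == d);
  }
  return 0;
}
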
@@ -946,7 +969,7 @@ class ShROpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
       // Treat [2^n,2^n+1,...]'s divisibility as 1 instead of 2^n
       lhsDivisibility = 1;
     }
-    return std::max<int64_t>(1, lhsDivisibility / (1 << shift));
+    return std::max<int64_t>(1, lhsDivisibility / (int64_t(1) << shift));
   }
 
   int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
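
The change above widens the shift: with a plain int literal, `1 << shift` is undefined behavior once the shift amount reaches the width of int, and for 64-bit operands the shift can legitimately be 32 or more. A standalone sketch of the fixed expression, with hypothetical values:

// Standalone illustration of the widened shift (not part of the patch).
#include <algorithm>
#include <cstdint>
#include <iostream>

int main() {
  int64_t lhsDivisibility = int64_t(1) << 32; // matches the optimistic default
  int64_t shift = 32;                         // plausible for a 64-bit shr
  // The old form `1 << shift` would shift a 32-bit int by its full width
  // (undefined behavior); widening the literal keeps the division well defined.
  int64_t divisibility =
      std::max<int64_t>(1, lhsDivisibility / (int64_t(1) << shift));
  std::cout << divisibility << "\n"; // prints 1
  return 0;
}
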
@@ -1092,6 +1115,7 @@ AxisInfoAnalysis::AxisInfoAnalysis(DataFlowSolver &solver)
   // TODO: Remove rules for LLVM::ConstantOp, LLVM::AddOp
   // when scf.for supports integer induction variables
   visitors.append<MakeRangeOpAxisInfoVisitor>();
+  visitors.append<PoisonOpAxisInfoVisitor>();
   visitors.append<ConstantOpAxisInfoVisitor<arith::ConstantOp>,
                   ConstantOpAxisInfoVisitor<LLVM::ConstantOp>>();
   visitors.append<AddSubOpAxisInfoVisitor<triton::AddPtrOp>,
@@ -1184,15 +1208,16 @@ unsigned ModuleAxisInfoAnalysis::getPtrContiguity(Value ptr) {
   auto tensorTy = ttgi::getRankedTensorType(ptr.getType());
   if (!tensorTy)
     return 1;
-  auto layout = tensorTy.getEncoding();
 
-  // Here order should be ordered by contiguous first, so the first element
-  // should have the largest contiguous.
-  auto order = triton::gpu::getOrder(layout);
+  // FIXME: This is not as good as it could be, as we don't need to restrict
+  // the analysis to one dimension. We should determine contiguity on the
+  // flattenOuts() layout
+  auto linAttr =
+      gpu::toLinearEncoding(tensorTy.getEncoding(), tensorTy.getShape());
+  auto order = linAttr.getOrder();
   unsigned align = getPtrAlignment(ptr);
 
-  auto uniqueContigPerThread =
-      triton::gpu::getUniqueContigPerThread(layout, tensorTy.getShape());
+  auto uniqueContigPerThread = linAttr.getContigPerThread();
   assert(order[0] < uniqueContigPerThread.size() &&
          "Unexpected uniqueContigPerThread size");
   unsigned contiguity = uniqueContigPerThread[order[0]];
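
In the rewritten getPtrContiguity, both the dimension order and the per-thread contiguity now come from the linear encoding. The plain-C++ sketch below shows how the looked-up quantities relate; all values are hypothetical, and the final min() clamp is an assumption about how contiguity and pointer alignment are combined in the rest of the function, which is not part of this hunk.

// Standalone illustration (no MLIR), hypothetical values throughout.
#include <algorithm>
#include <array>
#include <iostream>

int main() {
  // order[0] names the fastest-varying dimension of the linear encoding.
  std::array<unsigned, 2> order = {1, 0};
  // Elements each thread holds contiguously, per dimension.
  std::array<unsigned, 2> uniqueContigPerThread = {1, 8};
  unsigned align = 4; // stand-in for getPtrAlignment(ptr)
  unsigned contiguity = uniqueContigPerThread[order[0]]; // 8
  // A vectorized access can be no wider than either bound allows.
  std::cout << std::min(align, contiguity) << "\n"; // prints 4
  return 0;
}
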
@@ -1209,8 +1234,9 @@ unsigned ModuleAxisInfoAnalysis::getPtrAlignment(Value ptr) {
   auto *axisInfo = getAxisInfo(ptr);
   if (!axisInfo)
     return 1;
-  auto layout = tensorTy.getEncoding();
-  auto order = triton::gpu::getOrder(layout);
+  auto linAttr =
+      gpu::toLinearEncoding(tensorTy.getEncoding(), tensorTy.getShape());
+  auto order = linAttr.getOrder();
   auto maxMultipleBytes = axisInfo->getDivisibility(order[0]);
   auto maxContig = axisInfo->getContiguity(order[0]);
   unsigned elemNumBits = isTensorPointerType(ptr.getType())
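
In getPtrAlignment, the divisibility reported by AxisInfo is in bytes while contiguity is in elements, so the remainder of the function (outside this hunk) has to reconcile the units. A worked sketch with hypothetical values, assuming the result is capped by both quantities:

// Standalone illustration (hypothetical values; not the upstream code).
#include <algorithm>
#include <cstdint>
#include <iostream>

int main() {
  int64_t maxMultipleBytes = 128; // divisibility along the fastest dimension
  int64_t maxContig = 32;         // contiguity along the same dimension
  unsigned elemNumBits = 16;      // e.g. fp16 elements
  int64_t elemNumBytes = std::max<int64_t>(elemNumBits / 8, 1);
  // A 128-byte divisibility of fp16 data means a multiple of 64 elements,
  // but only 32 of them are known to be contiguous.
  int64_t alignment =
      std::min<int64_t>(maxMultipleBytes / elemNumBytes, maxContig);
  std::cout << alignment << "\n"; // prints 32
  return 0;
}
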
@@ -1239,7 +1265,9 @@ unsigned ModuleAxisInfoAnalysis::getMaskAlignment(Value mask) {
   auto *axisInfo = getAxisInfo(mask);
   if (!axisInfo)
     return 1;
-  auto maskOrder = triton::gpu::getOrder(tensorTy.getEncoding());
+  auto linAttr =
+      gpu::toLinearEncoding(tensorTy.getEncoding(), tensorTy.getShape());
+  auto maskOrder = linAttr.getOrder();
   auto alignment = std::max<unsigned>(axisInfo->getConstancy(maskOrder[0]), 1);
   LDBG("getMaskAlignment maskOrder[0] " << maskOrder[0] << " alignment "
                                         << alignment);
