11#include " mlir/Analysis/DataFlowFramework.h"
22#include " mlir/Dialect/LLVMIR/LLVMDialect.h"
3+ #include " mlir/Dialect/UB/IR/UBOps.h"
34#include " llvm/Support/Debug.h"
45#include " llvm/Support/raw_ostream.h"
56
@@ -273,6 +274,28 @@ class ConstantOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
273274 }
274275};
275276
277+ class PoisonOpAxisInfoVisitor final : public AxisInfoVisitorImpl<ub::PoisonOp> {
278+ public:
279+ using AxisInfoVisitorImpl::AxisInfoVisitorImpl;
280+
281+ AxisInfo
282+ getAxisInfo (ub::PoisonOp op,
283+ ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
284+ constexpr int64_t largePowerOf2 = int64_t (1 ) << 32 ;
285+ // Poison values are never accessed, thus assume optimistic values.
286+ if (auto shape = dyn_cast<mlir::ShapedType>(op.getType ())) {
287+ unsigned rank = shape.getRank ();
288+ return AxisInfo (
289+ /* contiguity=*/ AxisInfo::DimVectorT (rank, largePowerOf2),
290+ /* divisibility=*/ AxisInfo::DimVectorT (rank, largePowerOf2),
291+ /* constancy=*/ AxisInfo::DimVectorT (shape.getShape ()));
292+ }
293+
294+ return AxisInfo (/* contiguity=*/ {1 }, /* divisibility=*/ {largePowerOf2},
295+ /* constancy=*/ {1 });
296+ }
297+ };
298+
276299template <typename OpTy>
277300class AddSubOpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
278301public:
@@ -946,7 +969,7 @@ class ShROpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
946969 // Treat [2^n,2^n+1,...]'s divisibility as 1 instead of 2^n
947970 lhsDivisibility = 1 ;
948971 }
949- return std::max<int64_t >(1 , lhsDivisibility / (1 << shift));
972+ return std::max<int64_t >(1 , lhsDivisibility / (int64_t ( 1 ) << shift));
950973 }
951974
952975 int64_t getConstancy (OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
@@ -1092,6 +1115,7 @@ AxisInfoAnalysis::AxisInfoAnalysis(DataFlowSolver &solver)
10921115 // TODO: Remove rules for LLVM::ConstantOp, LLVM::AddOp
10931116 // when scf.for supports integer induction variables
10941117 visitors.append <MakeRangeOpAxisInfoVisitor>();
1118+ visitors.append <PoisonOpAxisInfoVisitor>();
10951119 visitors.append <ConstantOpAxisInfoVisitor<arith::ConstantOp>,
10961120 ConstantOpAxisInfoVisitor<LLVM::ConstantOp>>();
10971121 visitors.append <AddSubOpAxisInfoVisitor<triton::AddPtrOp>,
@@ -1184,15 +1208,16 @@ unsigned ModuleAxisInfoAnalysis::getPtrContiguity(Value ptr) {
11841208 auto tensorTy = ttgi::getRankedTensorType (ptr.getType ());
11851209 if (!tensorTy)
11861210 return 1 ;
1187- auto layout = tensorTy.getEncoding ();
11881211
1189- // Here order should be ordered by contiguous first, so the first element
1190- // should have the largest contiguous.
1191- auto order = triton::gpu::getOrder (layout);
1212+ // FIXME: This is not as good as it could be, as we don't need to restrict
1213+ // the analysis to one dimension. We should determine contiguity on the
1214+ // flattenOuts() layout
1215+ auto linAttr =
1216+ gpu::toLinearEncoding (tensorTy.getEncoding (), tensorTy.getShape ());
1217+ auto order = linAttr.getOrder ();
11921218 unsigned align = getPtrAlignment (ptr);
11931219
1194- auto uniqueContigPerThread =
1195- triton::gpu::getUniqueContigPerThread (layout, tensorTy.getShape ());
1220+ auto uniqueContigPerThread = linAttr.getContigPerThread ();
11961221 assert (order[0 ] < uniqueContigPerThread.size () &&
11971222 " Unexpected uniqueContigPerThread size" );
11981223 unsigned contiguity = uniqueContigPerThread[order[0 ]];
@@ -1209,8 +1234,9 @@ unsigned ModuleAxisInfoAnalysis::getPtrAlignment(Value ptr) {
12091234 auto *axisInfo = getAxisInfo (ptr);
12101235 if (!axisInfo)
12111236 return 1 ;
1212- auto layout = tensorTy.getEncoding ();
1213- auto order = triton::gpu::getOrder (layout);
1237+ auto linAttr =
1238+ gpu::toLinearEncoding (tensorTy.getEncoding (), tensorTy.getShape ());
1239+ auto order = linAttr.getOrder ();
12141240 auto maxMultipleBytes = axisInfo->getDivisibility (order[0 ]);
12151241 auto maxContig = axisInfo->getContiguity (order[0 ]);
12161242 unsigned elemNumBits = isTensorPointerType (ptr.getType ())
@@ -1239,7 +1265,9 @@ unsigned ModuleAxisInfoAnalysis::getMaskAlignment(Value mask) {
12391265 auto *axisInfo = getAxisInfo (mask);
12401266 if (!axisInfo)
12411267 return 1 ;
1242- auto maskOrder = triton::gpu::getOrder (tensorTy.getEncoding ());
1268+ auto linAttr =
1269+ gpu::toLinearEncoding (tensorTy.getEncoding (), tensorTy.getShape ());
1270+ auto maskOrder = linAttr.getOrder ();
12431271 auto alignment = std::max<unsigned >(axisInfo->getConstancy (maskOrder[0 ]), 1 );
12441272 LDBG (" getMaskAlignment maskOrder[0] " << maskOrder[0 ] << " alignment "
12451273 << alignment);
0 commit comments