44#include " llvm/Support/raw_ostream.h"
55
66#include " intel/include/Analysis/AxisInfo.h"
7+ #include " intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h"
78#include " triton/Dialect/Triton/IR/Dialect.h"
89
910#define DEBUG_TYPE " intel-axis-info"
1011#define DBGS () (llvm::dbgs() << " [" DEBUG_TYPE " ]: " )
1112#define LDBG (X ) LLVM_DEBUG(DBGS() << X << " \n " )
1213
1314namespace mlir ::triton::intel {
15+
16+ namespace ttgi = mlir::triton::gpu::intel;
1417namespace {
1518
1619int64_t gcdImpl (int64_t a, int64_t b, int64_t *x, int64_t *y) {
@@ -50,12 +53,6 @@ int64_t multiplyDivisor(int64_t lhs, int64_t rhs) {
5053 return lhs * rhs;
5154}
5255
53- RankedTensorType getRankedTensorType (Type ptrTy) {
54- return isTensorPointerType (ptrTy)
55- ? cast<RankedTensorType>(cast<PointerType>(ptrTy).getPointeeType ())
56- : dyn_cast<RankedTensorType>(ptrTy);
57- }
58-
5956class AxisInfoVisitor {
6057public:
6158 AxisInfoVisitor () = default ;
@@ -415,7 +412,7 @@ class DivOpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
415412
416413 int64_t getConstancy (OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
417414 int dim) override {
418- auto resTy = getRankedTensorType (op.getType ());
415+ auto resTy = ttgi:: getRankedTensorType (op.getType ());
419416 if (!resTy)
420417 return BinaryOpVisitorImpl<OpTy>::getConstancy (op, lhs, rhs, dim);
421418 auto shape = resTy.getShape ();
@@ -470,7 +467,7 @@ class RemOpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
470467private:
471468 int64_t getContiguity (OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
472469 int dim) override {
473- auto resTy = getRankedTensorType (op.getType ());
470+ auto resTy = ttgi:: getRankedTensorType (op.getType ());
474471 if (!resTy)
475472 return BinaryOpVisitorImpl<OpTy>::getContiguity (op, lhs, rhs, dim);
476473 auto shape = resTy.getShape ();
@@ -504,7 +501,7 @@ class RemOpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
504501
505502 int64_t getConstancy (OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
506503 int dim) override {
507- auto resTy = getRankedTensorType (op.getType ());
504+ auto resTy = ttgi:: getRankedTensorType (op.getType ());
508505 if (!resTy)
509506 return BinaryOpVisitorImpl<OpTy>::getConstancy (op, lhs, rhs, dim);
510507 auto shape = resTy.getShape ();
@@ -653,7 +650,7 @@ class CmpOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
653650 AxisInfo
654651 getAxisInfo (OpTy op,
655652 ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
656- auto resTy = getRankedTensorType (op.getType ());
653+ auto resTy = ttgi:: getRankedTensorType (op.getType ());
657654 if (!resTy)
658655 return AxisInfo ();
659656 auto shape = resTy.getShape ();
@@ -1144,8 +1141,11 @@ LogicalResult AxisInfoAnalysis::visitOperation(
11441141
11451142void AxisInfoAnalysis::visitForOpInductionVar (
11461143 scf::ForOp op, ArrayRef<dataflow::Lattice<AxisInfo> *> argLattices) {
1147- const auto &lb = getLatticeElementFor (op, op.getLowerBound ())->getValue ();
1148- const auto &step = getLatticeElementFor (op, op.getStep ())->getValue ();
1144+ ProgramPoint programPoint (op);
1145+ const auto lb =
1146+ getLatticeElementFor (&programPoint, op.getLowerBound ())->getValue ();
1147+ const auto step =
1148+ getLatticeElementFor (&programPoint, op.getStep ())->getValue ();
11491149
11501150 AxisInfo::DimVectorT knownContiguity (1 , 1 );
11511151 AxisInfo::DimVectorT knownDivisibility (1 , 1 );
@@ -1265,7 +1265,7 @@ void AxisInfo::initPessimisticStateFromFunc(int argNumber, T funcOp,
12651265}
12661266
12671267unsigned ModuleAxisInfoAnalysis::getPtrContiguity (Value ptr) {
1268- auto tensorTy = getRankedTensorType (ptr.getType ());
1268+ auto tensorTy = ttgi:: getRankedTensorType (ptr.getType ());
12691269 if (!tensorTy)
12701270 return 1 ;
12711271 auto layout = tensorTy.getEncoding ();
@@ -1287,7 +1287,7 @@ unsigned ModuleAxisInfoAnalysis::getPtrContiguity(Value ptr) {
12871287}
12881288
12891289unsigned ModuleAxisInfoAnalysis::getPtrAlignment (Value ptr) {
1290- auto tensorTy = getRankedTensorType (ptr.getType ());
1290+ auto tensorTy = ttgi:: getRankedTensorType (ptr.getType ());
12911291 if (!tensorTy)
12921292 return 1 ;
12931293 auto *axisInfo = getAxisInfo (ptr);
@@ -1315,7 +1315,7 @@ unsigned ModuleAxisInfoAnalysis::getPtrAlignment(Value ptr) {
13151315}
13161316
13171317unsigned ModuleAxisInfoAnalysis::getMaskAlignment (Value mask) {
1318- auto tensorTy = getRankedTensorType (mask.getType ());
1318+ auto tensorTy = ttgi:: getRankedTensorType (mask.getType ());
13191319 if (!tensorTy)
13201320 return 1 ;
13211321 auto *axisInfo = getAxisInfo (mask);
0 commit comments