44#include " llvm/Support/raw_ostream.h"
55
66#include " intel/include/Analysis/AxisInfo.h"
7+ #include " intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h"
78#include " triton/Dialect/Triton/IR/Dialect.h"
89
910#define DEBUG_TYPE " intel-axis-info"
1011#define DBGS () (llvm::dbgs() << " [" DEBUG_TYPE " ]: " )
1112#define LDBG (X ) LLVM_DEBUG(DBGS() << X << " \n " )
1213
1314namespace mlir ::triton::intel {
15+
16+ namespace ttgi = mlir::triton::gpu::intel;
1417namespace {
1518
1619int64_t gcdImpl (int64_t a, int64_t b, int64_t *x, int64_t *y) {
@@ -50,12 +53,6 @@ int64_t multiplyDivisor(int64_t lhs, int64_t rhs) {
5053 return lhs * rhs;
5154}
5255
53- RankedTensorType getRankedTensorType (Type ptrTy) {
54- return isTensorPointerType (ptrTy)
55- ? cast<RankedTensorType>(cast<PointerType>(ptrTy).getPointeeType ())
56- : dyn_cast<RankedTensorType>(ptrTy);
57- }
58-
5956class AxisInfoVisitor {
6057public:
6158 AxisInfoVisitor () = default ;
@@ -415,7 +412,7 @@ class DivOpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
415412
416413 int64_t getConstancy (OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
417414 int dim) override {
418- auto resTy = getRankedTensorType (op.getType ());
415+ auto resTy = ttgi:: getRankedTensorType (op.getType ());
419416 if (!resTy)
420417 return BinaryOpVisitorImpl<OpTy>::getConstancy (op, lhs, rhs, dim);
421418 auto shape = resTy.getShape ();
@@ -470,7 +467,7 @@ class RemOpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
470467private:
471468 int64_t getContiguity (OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
472469 int dim) override {
473- auto resTy = getRankedTensorType (op.getType ());
470+ auto resTy = ttgi:: getRankedTensorType (op.getType ());
474471 if (!resTy)
475472 return BinaryOpVisitorImpl<OpTy>::getContiguity (op, lhs, rhs, dim);
476473 auto shape = resTy.getShape ();
@@ -504,7 +501,7 @@ class RemOpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
504501
505502 int64_t getConstancy (OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
506503 int dim) override {
507- auto resTy = getRankedTensorType (op.getType ());
504+ auto resTy = ttgi:: getRankedTensorType (op.getType ());
508505 if (!resTy)
509506 return BinaryOpVisitorImpl<OpTy>::getConstancy (op, lhs, rhs, dim);
510507 auto shape = resTy.getShape ();
@@ -653,7 +650,7 @@ class CmpOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
653650 AxisInfo
654651 getAxisInfo (OpTy op,
655652 ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
656- auto resTy = getRankedTensorType (op.getType ());
653+ auto resTy = ttgi:: getRankedTensorType (op.getType ());
657654 if (!resTy)
658655 return AxisInfo ();
659656 auto shape = resTy.getShape ();
@@ -1265,7 +1262,7 @@ void AxisInfo::initPessimisticStateFromFunc(int argNumber, T funcOp,
12651262}
12661263
12671264unsigned ModuleAxisInfoAnalysis::getPtrContiguity (Value ptr) {
1268- auto tensorTy = getRankedTensorType (ptr.getType ());
1265+ auto tensorTy = ttgi:: getRankedTensorType (ptr.getType ());
12691266 if (!tensorTy)
12701267 return 1 ;
12711268 auto layout = tensorTy.getEncoding ();
@@ -1287,7 +1284,7 @@ unsigned ModuleAxisInfoAnalysis::getPtrContiguity(Value ptr) {
12871284}
12881285
12891286unsigned ModuleAxisInfoAnalysis::getPtrAlignment (Value ptr) {
1290- auto tensorTy = getRankedTensorType (ptr.getType ());
1287+ auto tensorTy = ttgi:: getRankedTensorType (ptr.getType ());
12911288 if (!tensorTy)
12921289 return 1 ;
12931290 auto *axisInfo = getAxisInfo (ptr);
@@ -1315,7 +1312,7 @@ unsigned ModuleAxisInfoAnalysis::getPtrAlignment(Value ptr) {
13151312}
13161313
13171314unsigned ModuleAxisInfoAnalysis::getMaskAlignment (Value mask) {
1318- auto tensorTy = getRankedTensorType (mask.getType ());
1315+ auto tensorTy = ttgi:: getRankedTensorType (mask.getType ());
13191316 if (!tensorTy)
13201317 return 1 ;
13211318 auto *axisInfo = getAxisInfo (mask);
0 commit comments