44#include " llvm/Support/raw_ostream.h"
55
66#include " intel/include/Analysis/AxisInfo.h"
7+ #include " intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h"
78#include " triton/Dialect/Triton/IR/Dialect.h"
89
910#define DEBUG_TYPE " intel-axis-info"
1011#define DBGS () (llvm::dbgs() << " [" DEBUG_TYPE " ]: " )
1112#define LDBG (X ) LLVM_DEBUG(DBGS() << X << " \n " )
1213
1314namespace mlir ::triton::intel {
15+
16+ namespace ttgi = mlir::triton::gpu::intel;
1417namespace {
1518
1619int64_t gcdImpl (int64_t a, int64_t b, int64_t *x, int64_t *y) {
@@ -50,12 +53,6 @@ int64_t multiplyDivisor(int64_t lhs, int64_t rhs) {
5053 return lhs * rhs;
5154}
5255
53- RankedTensorType getRankedTensorType (Type ptrTy) {
54- return isTensorPointerType (ptrTy)
55- ? cast<RankedTensorType>(cast<PointerType>(ptrTy).getPointeeType ())
56- : dyn_cast<RankedTensorType>(ptrTy);
57- }
58-
5956class AxisInfoVisitor {
6057public:
6158 AxisInfoVisitor () = default ;
@@ -415,7 +412,7 @@ class DivOpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
415412
416413 int64_t getConstancy (OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
417414 int dim) override {
418- auto resTy = getRankedTensorType (op.getType ());
415+ auto resTy = ttgi:: getRankedTensorType (op.getType ());
419416 if (!resTy)
420417 return BinaryOpVisitorImpl<OpTy>::getConstancy (op, lhs, rhs, dim);
421418 auto shape = resTy.getShape ();
@@ -470,7 +467,7 @@ class RemOpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
470467private:
471468 int64_t getContiguity (OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
472469 int dim) override {
473- auto resTy = getRankedTensorType (op.getType ());
470+ auto resTy = ttgi:: getRankedTensorType (op.getType ());
474471 if (!resTy)
475472 return BinaryOpVisitorImpl<OpTy>::getContiguity (op, lhs, rhs, dim);
476473 auto shape = resTy.getShape ();
@@ -504,7 +501,7 @@ class RemOpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
504501
505502 int64_t getConstancy (OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
506503 int dim) override {
507- auto resTy = getRankedTensorType (op.getType ());
504+ auto resTy = ttgi:: getRankedTensorType (op.getType ());
508505 if (!resTy)
509506 return BinaryOpVisitorImpl<OpTy>::getConstancy (op, lhs, rhs, dim);
510507 auto shape = resTy.getShape ();
@@ -653,7 +650,7 @@ class CmpOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
653650 AxisInfo
654651 getAxisInfo (OpTy op,
655652 ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
656- auto resTy = getRankedTensorType (op.getType ());
653+ auto resTy = ttgi:: getRankedTensorType (op.getType ());
657654 if (!resTy)
658655 return AxisInfo ();
659656 auto shape = resTy.getShape ();
@@ -1268,7 +1265,7 @@ void AxisInfo::initPessimisticStateFromFunc(int argNumber, T funcOp,
12681265}
12691266
12701267unsigned ModuleAxisInfoAnalysis::getPtrContiguity (Value ptr) {
1271- auto tensorTy = getRankedTensorType (ptr.getType ());
1268+ auto tensorTy = ttgi:: getRankedTensorType (ptr.getType ());
12721269 if (!tensorTy)
12731270 return 1 ;
12741271 auto layout = tensorTy.getEncoding ();
@@ -1290,7 +1287,7 @@ unsigned ModuleAxisInfoAnalysis::getPtrContiguity(Value ptr) {
12901287}
12911288
12921289unsigned ModuleAxisInfoAnalysis::getPtrAlignment (Value ptr) {
1293- auto tensorTy = getRankedTensorType (ptr.getType ());
1290+ auto tensorTy = ttgi:: getRankedTensorType (ptr.getType ());
12941291 if (!tensorTy)
12951292 return 1 ;
12961293 auto *axisInfo = getAxisInfo (ptr);
@@ -1318,7 +1315,7 @@ unsigned ModuleAxisInfoAnalysis::getPtrAlignment(Value ptr) {
13181315}
13191316
13201317unsigned ModuleAxisInfoAnalysis::getMaskAlignment (Value mask) {
1321- auto tensorTy = getRankedTensorType (mask.getType ());
1318+ auto tensorTy = ttgi:: getRankedTensorType (mask.getType ());
13221319 if (!tensorTy)
13231320 return 1 ;
13241321 auto *axisInfo = getAxisInfo (mask);
0 commit comments