@@ -50,6 +50,12 @@ int64_t multiplyDivisor(int64_t lhs, int64_t rhs) {
5050 return lhs * rhs;
5151}
5252
53+ RankedTensorType getRankedTensorType (Type ptrTy) {
54+ return isTensorPointerType (ptrTy)
55+ ? cast<RankedTensorType>(cast<PointerType>(ptrTy).getPointeeType ())
56+ : dyn_cast<RankedTensorType>(ptrTy);
57+ }
58+
5359class AxisInfoVisitor {
5460public:
5561 AxisInfoVisitor () = default ;
@@ -409,7 +415,7 @@ class DivOpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
409415
410416 int64_t getConstancy (OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
411417 int dim) override {
412- auto resTy = dyn_cast<RankedTensorType> (op.getType ());
418+ auto resTy = getRankedTensorType (op.getType ());
413419 if (!resTy)
414420 return BinaryOpVisitorImpl<OpTy>::getConstancy (op, lhs, rhs, dim);
415421 auto shape = resTy.getShape ();
@@ -464,7 +470,7 @@ class RemOpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
464470private:
465471 int64_t getContiguity (OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
466472 int dim) override {
467- auto resTy = dyn_cast<RankedTensorType> (op.getType ());
473+ auto resTy = getRankedTensorType (op.getType ());
468474 if (!resTy)
469475 return BinaryOpVisitorImpl<OpTy>::getContiguity (op, lhs, rhs, dim);
470476 auto shape = resTy.getShape ();
@@ -498,7 +504,7 @@ class RemOpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
498504
499505 int64_t getConstancy (OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
500506 int dim) override {
501- auto resTy = dyn_cast<RankedTensorType> (op.getType ());
507+ auto resTy = getRankedTensorType (op.getType ());
502508 if (!resTy)
503509 return BinaryOpVisitorImpl<OpTy>::getConstancy (op, lhs, rhs, dim);
504510 auto shape = resTy.getShape ();
@@ -647,7 +653,7 @@ class CmpOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
647653 AxisInfo
648654 getAxisInfo (OpTy op,
649655 ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
650- auto resTy = dyn_cast<RankedTensorType> (op.getType ());
656+ auto resTy = getRankedTensorType (op.getType ());
651657 if (!resTy)
652658 return AxisInfo ();
653659 auto shape = resTy.getShape ();
@@ -995,6 +1001,55 @@ class MaxMinOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
9951001 }
9961002};
9971003
1004+ class MakeTensorPtrOpAxisInfoVisitor final
1005+ : public AxisInfoVisitorImpl<triton::MakeTensorPtrOp> {
1006+ public:
1007+ using AxisInfoVisitorImpl<triton::MakeTensorPtrOp>::AxisInfoVisitorImpl;
1008+
1009+ AxisInfo
1010+ getAxisInfo (triton::MakeTensorPtrOp op,
1011+ ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
1012+ LDBG (" MakeTensorPtrOpAxisInfoVisitor: " << *op);
1013+ assert (op.getShape ().size () == 2 && operands.size () == 7 &&
1014+ " MakeTensorPtrOp should have 2D shape" );
1015+
1016+ AxisInfo ptrInfo = operands[0 ]->getValue ();
1017+ AxisInfo shapeInfo0 = operands[1 ]->getValue ();
1018+ AxisInfo shapeInfo1 = operands[2 ]->getValue ();
1019+ AxisInfo strideInfo0 = operands[3 ]->getValue ();
1020+ AxisInfo strideInfo1 = operands[4 ]->getValue ();
1021+
1022+ std::optional<int64_t > shape0 = shapeInfo0.getConstantValue ();
1023+ std::optional<int64_t > shape1 = shapeInfo1.getConstantValue ();
1024+ std::optional<int64_t > stride0 = strideInfo0.getConstantValue ();
1025+ std::optional<int64_t > stride1 = strideInfo1.getConstantValue ();
1026+
1027+ AxisInfo::DimVectorT contiguity{
1028+ shape0.has_value () && (stride0 == 1 ) ? shape0.value () : 1 ,
1029+ shape1.has_value () && (stride1 == 1 ) ? shape1.value () : 1 };
1030+
1031+ int64_t ptrDivisibility = ptrInfo.getDivisibility ()[0 ];
1032+ int64_t strideDivisibility0 = strideInfo0.getDivisibility ()[0 ];
1033+ int64_t strideDivisibility1 = strideInfo1.getDivisibility ()[0 ];
1034+
1035+ LDBG (" ptrDivisibility: " << ptrDivisibility);
1036+ LDBG (" strideDivisibility0: " << strideDivisibility0);
1037+ LDBG (" strideDivisibility1: " << strideDivisibility1);
1038+
1039+ AxisInfo::DimVectorT divisibility{1 , 1 };
1040+ if (ptrDivisibility > 1 ) {
1041+ if (contiguity[0 ] > 1 )
1042+ divisibility[0 ] = std::min (ptrDivisibility, strideDivisibility1);
1043+ if (contiguity[1 ] > 1 )
1044+ divisibility[1 ] = std::min (ptrDivisibility, strideDivisibility0);
1045+ }
1046+
1047+ AxisInfo::DimVectorT constancy{1 , 1 };
1048+
1049+ return AxisInfo (contiguity, divisibility, constancy);
1050+ }
1051+ };
1052+
9981053// ===----------------------------------------------------------------------===//
9991054// AxisInfoAnalysis
10001055// ===----------------------------------------------------------------------===//
@@ -1042,11 +1097,13 @@ AxisInfoAnalysis::AxisInfoAnalysis(DataFlowSolver &solver)
10421097 MaxMinOpAxisInfoVisitor<arith::MinSIOp>,
10431098 MaxMinOpAxisInfoVisitor<arith::MinUIOp>>();
10441099 visitors.append <LoadOpAxisInfoVisitor>();
1100+ visitors.append <MakeTensorPtrOpAxisInfoVisitor>();
10451101}
10461102
10471103LogicalResult AxisInfoAnalysis::visitOperation (
10481104 Operation *op, ArrayRef<const dataflow::Lattice<AxisInfo> *> operands,
10491105 ArrayRef<dataflow::Lattice<AxisInfo> *> results) {
1106+ LDBG (" visitOperation: << " << *op);
10501107 // TODO: For sure not the right way to do this
10511108 // but why is scf.if not initialized otherwise?
10521109 for (auto op : operands)
@@ -1204,7 +1261,7 @@ void AxisInfo::initPessimisticStateFromFunc(int argNumber, T funcOp,
12041261}
12051262
12061263unsigned ModuleAxisInfoAnalysis::getPtrContiguity (Value ptr) {
1207- auto tensorTy = dyn_cast<RankedTensorType> (ptr.getType ());
1264+ auto tensorTy = getRankedTensorType (ptr.getType ());
12081265 if (!tensorTy)
12091266 return 1 ;
12101267 auto layout = tensorTy.getEncoding ();
@@ -1226,7 +1283,7 @@ unsigned ModuleAxisInfoAnalysis::getPtrContiguity(Value ptr) {
12261283}
12271284
12281285unsigned ModuleAxisInfoAnalysis::getPtrAlignment (Value ptr) {
1229- auto tensorTy = dyn_cast<RankedTensorType> (ptr.getType ());
1286+ auto tensorTy = getRankedTensorType (ptr.getType ());
12301287 if (!tensorTy)
12311288 return 1 ;
12321289 auto *axisInfo = getAxisInfo (ptr);
@@ -1254,7 +1311,7 @@ unsigned ModuleAxisInfoAnalysis::getPtrAlignment(Value ptr) {
12541311}
12551312
12561313unsigned ModuleAxisInfoAnalysis::getMaskAlignment (Value mask) {
1257- auto tensorTy = dyn_cast<RankedTensorType> (mask.getType ());
1314+ auto tensorTy = getRankedTensorType (mask.getType ());
12581315 if (!tensorTy)
12591316 return 1 ;
12601317 auto *axisInfo = getAxisInfo (mask);
0 commit comments