@@ -107,28 +107,37 @@ class BinaryOpVisitorImpl : public AxisInfoVisitorImpl<OpTy> {
107107 const auto &rhsInfo = operands[1 ]->getValue ();
108108 auto rank = lhsInfo.getRank ();
109109 assert (operands.size () == 2 && " Expected two operands" );
110+ AxisInfo::DimVectorT stride;
110111 AxisInfo::DimVectorT contiguity;
111112 AxisInfo::DimVectorT divisibility;
112113 AxisInfo::DimVectorT constancy;
113114 auto constantValue = getConstantValue (op, lhsInfo, rhsInfo);
114115 for (auto d = 0 ; d < rank; ++d) {
115116 if (constantValue.has_value ()) {
117+ stride.push_back (0 );
116118 contiguity.push_back (1 );
117119 constancy.push_back (
118120 std::max (lhsInfo.getConstancy (d), rhsInfo.getConstancy (d)));
119121 divisibility.push_back (
120122 highestPowOf2Divisor<int64_t >(constantValue.value ()));
121123 } else {
124+ stride.push_back (getStride (op, lhsInfo, rhsInfo, d));
122125 contiguity.push_back (getContiguity (op, lhsInfo, rhsInfo, d));
123126 constancy.push_back (getConstancy (op, lhsInfo, rhsInfo, d));
124127 divisibility.push_back (getDivisibility (op, lhsInfo, rhsInfo, d));
125128 }
126129 }
127- return AxisInfo (std::move (contiguity), std::move (divisibility),
128- std::move (constancy), constantValue);
130+ return AxisInfo (std::move (stride), std::move (contiguity),
131+ std::move (divisibility), std::move (constancy),
132+ constantValue);
129133 }
130134
131135protected:
136+ virtual int64_t getStride (OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
137+ int dim) {
138+ return -1 ;
139+ }
140+
132141 virtual int64_t getContiguity (OpTy op, const AxisInfo &lhs,
133142 const AxisInfo &rhs, int dim) {
134143 return 1 ;
@@ -252,7 +261,7 @@ class ConstantOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
252261 value = intAttr.getValue ().getZExtValue ();
253262 else
254263 value = boolAttr.getValue () ? 1 : 0 ;
255- return AxisInfo (/* contiguity=*/ {1 },
264+ return AxisInfo (/* stride= */ { 0 }, /* contiguity=*/ {1 },
256265 /* divisibility=*/ {highestPowOf2Divisor (value)},
257266 /* constancy=*/ {1 },
258267 /* knownConstantValue=*/ {value});
@@ -263,6 +272,7 @@ class ConstantOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
263272 int64_t value = splatAttr.template getSplatValue <APInt>().getZExtValue ();
264273 TensorType ty = cast<TensorType>(splatAttr.getType ());
265274 return AxisInfo (
275+ /* stride=*/ AxisInfo::DimVectorT (ty.getRank (), 0 ),
266276 /* contiguity=*/ AxisInfo::DimVectorT (ty.getRank (), 1 ),
267277 /* divisibility=*/
268278 AxisInfo::DimVectorT (ty.getRank (), highestPowOf2Divisor (value)),
@@ -302,6 +312,15 @@ class AddSubOpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
302312 using BinaryOpVisitorImpl<OpTy>::BinaryOpVisitorImpl;
303313
304314private:
315+ int64_t getStride (OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
316+ int dim) override {
317+ if (lhs.getStride (dim) < 0 || rhs.getStride (dim) < 0 )
318+ return -1 ;
319+ if (isa<arith::SubIOp>(op))
320+ return std::max (lhs.getStride (dim) - rhs.getStride (dim), int64_t (-1 ));
321+ return lhs.getStride (dim) + rhs.getStride (dim);
322+ }
323+
305324 int64_t getContiguity (OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
306325 int dim) override {
307326 // Contiguity assumes an increasing sequence. So for SubIOp contiguous
@@ -373,6 +392,17 @@ class MulIOpAxisInfoVisitor final : public BinaryOpVisitorImpl<arith::MulIOp> {
373392 using BinaryOpVisitorImpl<arith::MulIOp>::BinaryOpVisitorImpl;
374393
375394private:
395+ int64_t getStride (arith::MulIOp op, const AxisInfo &lhs, const AxisInfo &rhs,
396+ int dim) override {
397+ if (lhs.getStride (dim) > 0 && rhs.getConstantValue ().has_value ())
398+ return lhs.getStride (dim) * rhs.getConstantValue ().value ();
399+ if (rhs.getStride (dim) > 0 && lhs.getConstantValue ().has_value ())
400+ return lhs.getConstantValue ().value () * rhs.getStride (dim);
401+ if (lhs.getStride (dim) == 0 || rhs.getStride (dim) == 0 )
402+ return 0 ;
403+ return -1 ;
404+ }
405+
376406 int64_t getContiguity (arith::MulIOp op, const AxisInfo &lhs,
377407 const AxisInfo &rhs, int dim) override {
378408 // lhs * 1 = lhs
@@ -425,6 +455,22 @@ class DivOpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
425455 using BinaryOpVisitorImpl<OpTy>::BinaryOpVisitorImpl;
426456
427457private:
458+ int64_t getStride (OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
459+ int dim) override {
460+ if (getContiguity (op, lhs, rhs, dim) > 1 )
461+ return 1 ;
462+ if (lhs.getStride (dim) > 0 && rhs.getConstantValue ().has_value () &&
463+ rhs.getConstantValue ().has_value () != 0 &&
464+ lhs.getStride (dim) % rhs.getConstantValue ().value () == 0 )
465+ return lhs.getStride (dim) / rhs.getConstantValue ().value ();
466+ if (rhs.getStride (dim) > 0 && lhs.getConstantValue ().has_value () &&
467+ lhs.getConstantValue ().value () % rhs.getStride (dim) == 0 )
468+ return lhs.getConstantValue ().value () / rhs.getStride (dim);
469+ if (lhs.getStride (dim) == 0 )
470+ return 0 ;
471+ return -1 ;
472+ }
473+
428474 int64_t getContiguity (OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
429475 int dim) override {
430476 // lhs / 1 = lhs
@@ -559,16 +605,18 @@ class SplatOpAxisInfoVisitor final
559605 Type _retTy = *op->result_type_begin ();
560606 TensorType retTy = cast<TensorType>(_retTy);
561607 AxisInfo opInfo = operands[0 ]->getValue ();
608+ AxisInfo::DimVectorT stride;
562609 AxisInfo::DimVectorT contiguity;
563610 AxisInfo::DimVectorT divisibility;
564611 AxisInfo::DimVectorT constancy;
565612 for (int d = 0 ; d < retTy.getRank (); ++d) {
613+ stride.push_back (0 );
566614 contiguity.push_back (1 );
567615 divisibility.push_back (opInfo.getDivisibility (0 ));
568616 constancy.push_back (retTy.getShape ()[d]);
569617 }
570- return AxisInfo (std::move (contiguity ), std::move (divisibility ),
571- std::move (constancy),
618+ return AxisInfo (std::move (stride ), std::move (contiguity ),
619+ std::move (divisibility), std::move ( constancy),
572620 operands[0 ]->getValue ().getConstantValue ());
573621 }
574622};
@@ -613,6 +661,7 @@ class ExpandDimsOpAxisInfoVisitor final
613661 getAxisInfo (triton::ExpandDimsOp op,
614662 ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
615663 AxisInfo opInfo = operands[0 ]->getValue ();
664+ AxisInfo::DimVectorT stride = opInfo.getStride ();
616665 AxisInfo::DimVectorT contiguity = opInfo.getContiguity ();
617666 AxisInfo::DimVectorT divisibility = opInfo.getDivisibility ();
618667 AxisInfo::DimVectorT constancy = opInfo.getConstancy ();
@@ -631,11 +680,12 @@ class ExpandDimsOpAxisInfoVisitor final
631680 opInfo.getContiguity (d) > 1 ? 1 : opInfo.getDivisibility (d));
632681 }
633682 }
683+ stride.insert (stride.begin () + op.getAxis (), 0 );
634684 contiguity.insert (contiguity.begin () + op.getAxis (), 1 );
635685 divisibility.insert (divisibility.begin () + op.getAxis (), newDivisibility);
636686 constancy.insert (constancy.begin () + op.getAxis (), 1 );
637- return AxisInfo (std::move (contiguity ), std::move (divisibility ),
638- std::move (constancy),
687+ return AxisInfo (std::move (stride ), std::move (contiguity ),
688+ std::move (divisibility), std::move ( constancy),
639689 operands[0 ]->getValue ().getConstantValue ());
640690 }
641691};
@@ -655,17 +705,19 @@ class BroadcastOpAxisInfoVisitor final
655705 ArrayRef<int64_t > retShape = retTy.getShape ();
656706 ArrayRef<int64_t > opShape = opTy.getShape ();
657707 AxisInfo opInfo = operands[0 ]->getValue ();
708+ AxisInfo::DimVectorT stride;
658709 AxisInfo::DimVectorT contiguity;
659710 AxisInfo::DimVectorT divisibility;
660711 AxisInfo::DimVectorT constancy;
661712 for (int d = 0 ; d < retTy.getRank (); ++d) {
713+ stride.push_back (opInfo.getStride (d));
662714 contiguity.push_back (opShape[d] == 1 ? 1 : opInfo.getContiguity (d));
663715 divisibility.push_back (opInfo.getDivisibility (d));
664716 constancy.push_back (opShape[d] == 1 ? retShape[d]
665717 : opInfo.getConstancy (d));
666718 }
667- return AxisInfo (std::move (contiguity ), std::move (divisibility ),
668- std::move (constancy),
719+ return AxisInfo (std::move (stride ), std::move (contiguity ),
720+ std::move (divisibility), std::move ( constancy),
669721 operands[0 ]->getValue ().getConstantValue ());
670722 }
671723};
@@ -1048,15 +1100,18 @@ class MakeTensorPtrOpAxisInfoVisitor final
10481100 if (rank > 2 )
10491101 return AxisInfo ();
10501102
1051- SmallVector<AxisInfo> strideInfo;
1103+ SmallVector<AxisInfo, 2 > strideInfo;
10521104 for (int i = rank + 1 ; i <= rank * 2 ; ++i)
10531105 strideInfo.emplace_back (operands[i]->getValue ());
10541106
10551107 AxisInfo ptrInfo = operands[0 ]->getValue ();
10561108 int64_t ptrDivisibility = ptrInfo.getDivisibility (0 );
10571109
1058- AxisInfo::DimVectorT contiguity, constancy, divisibility;
1110+ AxisInfo::DimVectorT stride, contiguity, constancy, divisibility;
10591111 for (int dim = 0 ; dim < rank; ++dim) {
1112+ stride.push_back (strideInfo[dim].getConstantValue ().has_value ()
1113+ ? strideInfo[dim].getConstantValue ().value ()
1114+ : -1 );
10601115 contiguity.push_back (
10611116 strideInfo[dim].getConstantValue () == 1 ? blkShape[dim] : 1 );
10621117 divisibility.push_back (
@@ -1069,8 +1124,9 @@ class MakeTensorPtrOpAxisInfoVisitor final
10691124 constancy.push_back (1 );
10701125 }
10711126
1072- auto axisInfo = AxisInfo (std::move (contiguity), std::move (divisibility),
1073- std::move (constancy));
1127+ auto axisInfo =
1128+ AxisInfo (std::move (stride), std::move (contiguity),
1129+ std::move (divisibility), std::move (constancy), std::nullopt );
10741130
10751131 LLVM_DEBUG ({
10761132 std::string axisStr;
@@ -1176,8 +1232,9 @@ LogicalResult AxisInfoAnalysis::visitOperation(
11761232 auto vals = cast<DenseElementsAttr>(attr).getValues <int >();
11771233 newConstancy = AxisInfo::DimVectorT (vals.begin (), vals.end ());
11781234 }
1179- curr = AxisInfo (std::move (newContiguity), std::move (newDivisibility),
1180- std::move (newConstancy), curr.getConstantValue ());
1235+ curr = AxisInfo (curr.getStride (), std::move (newContiguity),
1236+ std::move (newDivisibility), std::move (newConstancy),
1237+ curr.getConstantValue ());
11811238 // join all lattice elements
11821239 for (auto *result : results)
11831240 propagateIfChanged (result, result->join (curr));
0 commit comments