@@ -91,23 +91,26 @@ class BinaryOpVisitorImpl : public AxisInfoVisitorImpl<OpTy> {
9191 auto lhsInfo = operands[0 ]->getValue ();
9292 auto rhsInfo = operands[1 ]->getValue ();
9393 auto rank = lhsInfo.getRank ();
94+ assert (isa<RankedTensorType>(op.getType ()) ||
95+ rank == 1 && " Expected ranked tensor or scalar" );
9496 assert (operands.size () == 2 && " Expected two operands" );
97+ auto constantValue = getConstantValue (op, lhsInfo, rhsInfo);
98+ if (constantValue.has_value ()) {
99+ auto resTy = dyn_cast<RankedTensorType>(op.getType ());
100+ AxisInfo::DimVectorT constancy =
101+ resTy ? to_vector (resTy.getShape ()) : AxisInfo::DimVectorT (rank, 1 );
102+ AxisInfo::DimVectorT contiguity (rank, 1 );
103+ AxisInfo::DimVectorT divisibility (
104+ rank, highestPowOf2Divisor<int64_t >(constantValue.value ()));
105+ return AxisInfo (contiguity, divisibility, constancy, constantValue);
106+ }
95107 AxisInfo::DimVectorT contiguity;
96108 AxisInfo::DimVectorT divisibility;
97109 AxisInfo::DimVectorT constancy;
98- auto constantValue = getConstantValue (op, lhsInfo, rhsInfo);
99110 for (auto d = 0 ; d < rank; ++d) {
100- if (constantValue.has_value ()) {
101- contiguity.push_back (1 );
102- constancy.push_back (
103- std::max (lhsInfo.getConstancy (d), rhsInfo.getConstancy (d)));
104- divisibility.push_back (
105- highestPowOf2Divisor<int64_t >(constantValue.value ()));
106- } else {
107- contiguity.push_back (getContiguity (op, lhsInfo, rhsInfo, d));
108- constancy.push_back (getConstancy (op, lhsInfo, rhsInfo, d));
109- divisibility.push_back (getDivisibility (op, lhsInfo, rhsInfo, d));
110- }
111+ contiguity.push_back (getContiguity (op, lhsInfo, rhsInfo, d));
112+ constancy.push_back (getConstancy (op, lhsInfo, rhsInfo, d));
113+ divisibility.push_back (getDivisibility (op, lhsInfo, rhsInfo, d));
111114 }
112115 return AxisInfo (contiguity, divisibility, constancy, constantValue);
113116 }
@@ -125,9 +128,8 @@ class BinaryOpVisitorImpl : public AxisInfoVisitorImpl<OpTy> {
125128
126129 virtual int64_t getConstancy (OpTy op, const AxisInfo &lhs,
127130 const AxisInfo &rhs, int dim) {
128- return 1 ;
131+ return gcd (lhs. getConstancy (dim), rhs. getConstancy (dim)) ;
129132 }
130-
131133 virtual std::optional<int64_t > getConstantValue (OpTy op, const AxisInfo &lhs,
132134 const AxisInfo &rhs) {
133135 return {};
@@ -192,6 +194,26 @@ class CastOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
192194 }
193195};
194196
197+ class UnrealizedConversionCastOpAxisInfoVisitor final
198+ : public AxisInfoVisitorImpl<mlir::UnrealizedConversionCastOp> {
199+ public:
200+ using AxisInfoVisitorImpl<
201+ mlir::UnrealizedConversionCastOp>::AxisInfoVisitorImpl;
202+
203+ AxisInfo
204+ getAxisInfo (mlir::UnrealizedConversionCastOp op,
205+ ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
206+ auto tensorType = dyn_cast<RankedTensorType>(op.getResultTypes ()[0 ]);
207+ if (tensorType &&
208+ tensorType.getRank () != operands[0 ]->getValue ().getRank ()) {
209+ // Do not propagate AxisInfo with incorrect rank. This can cause a crash
210+ // in future visitor applications.
211+ return AxisInfo::getPessimisticValueState (op->getResult (0 ));
212+ }
213+ return operands[0 ]->getValue ();
214+ }
215+ };
216+
195217class MakeRangeOpAxisInfoVisitor final
196218 : public AxisInfoVisitorImpl<triton::MakeRangeOp> {
197219public:
@@ -308,11 +330,6 @@ class AddSubOpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
308330 return gcd (lhs.getDivisibility (dim), rhsDivisibility);
309331 }
310332
311- int64_t getConstancy (OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
312- int dim) override {
313- return gcd (lhs.getConstancy (dim), rhs.getConstancy (dim));
314- }
315-
316333 std::optional<int64_t > getConstantValue (OpTy op, const AxisInfo &lhs,
317334 const AxisInfo &rhs) override {
318335 if (lhs.getConstantValue ().has_value () &&
@@ -355,11 +372,6 @@ class MulIOpAxisInfoVisitor final : public BinaryOpVisitorImpl<arith::MulIOp> {
355372 return std::max (lhsContiguity, rhsContiguity);
356373 }
357374
358- int64_t getConstancy (arith::MulIOp op, const AxisInfo &lhs,
359- const AxisInfo &rhs, int dim) override {
360- return gcd (lhs.getConstancy (dim), rhs.getConstancy (dim));
361- }
362-
363375 int64_t getDivisibility (arith::MulIOp op, const AxisInfo &lhs,
364376 const AxisInfo &rhs, int dim) override {
365377 auto lhsDivisibility = lhs.getDivisibility (dim);
@@ -379,9 +391,13 @@ class MulIOpAxisInfoVisitor final : public BinaryOpVisitorImpl<arith::MulIOp> {
379391
380392 std::optional<int64_t > getConstantValue (arith::MulIOp op, const AxisInfo &lhs,
381393 const AxisInfo &rhs) override {
382- if (lhs.getConstantValue ().has_value () &&
383- rhs.getConstantValue ().has_value ())
384- return {lhs.getConstantValue ().value () * rhs.getConstantValue ().value ()};
394+ auto lhsConst = lhs.getConstantValue ();
395+ auto rhsConst = rhs.getConstantValue ();
396+ if (lhsConst.has_value () && rhsConst.has_value ())
397+ return {lhsConst.value () * rhsConst.value ()};
398+ if ((lhsConst.has_value () && lhsConst.value () == 0 ) ||
399+ (rhsConst.has_value () && rhsConst.value () == 0 ))
400+ return 0 ;
385401 return {};
386402 }
387403};
@@ -404,12 +420,11 @@ class DivOpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
404420 int64_t getConstancy (OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
405421 int dim) override {
406422 auto resTy = dyn_cast<RankedTensorType>(op.getType ());
423+ auto constancy = BinaryOpVisitorImpl<OpTy>::getConstancy (op, lhs, rhs, dim);
407424 if (!resTy)
408- return BinaryOpVisitorImpl<OpTy>:: getConstancy (op, lhs, rhs, dim) ;
425+ return constancy ;
409426 auto shape = resTy.getShape ();
410- // Case 1: both lhs and rhs are constants.
411- auto constancy = gcd (lhs.getConstancy (dim), rhs.getConstancy (dim));
412- // Case 2: lhs contiguous, rhs constant.
427+ // Case: lhs contiguous, rhs constant.
413428 // lhs: d_lhs * k, d_lhs * k + 1, ..., d_lhs * k + n
414429 // rhs: d_rhs * p, d_rhs * p, ..., d_rhs * p
415430 // lhs / rhs = d_lhs * k / (d_rhs * p), (d_lhs * k + 1) / (d_rhs * p),
@@ -506,15 +521,15 @@ class RemOpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
506521
507522 int64_t getConstancy (OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
508523 int dim) override {
524+ auto constancy = BinaryOpVisitorImpl<OpTy>::getConstancy (op, lhs, rhs, dim);
509525 auto resTy = dyn_cast<RankedTensorType>(op.getType ());
510526 if (!resTy)
511- return BinaryOpVisitorImpl<OpTy>::getConstancy (op, lhs, rhs, dim);
512- auto shape = resTy.getShape ();
513- // lhs % 1 = 0
514- return rhs.getConstantValue ().has_value () &&
515- rhs.getConstantValue ().value () == 1
516- ? shape[dim]
517- : gcd (lhs.getConstancy (dim), rhs.getConstancy (dim));
527+ return constancy;
528+ // Case: lhs % 1 = 0
529+ if (rhs.getConstantValue ().has_value () &&
530+ rhs.getConstantValue ().value () == 1 )
531+ return resTy.getDimSize (dim);
532+ return constancy;
518533 }
519534
520535 std::optional<int64_t > getConstantValue (OpTy op, const AxisInfo &lhs,
@@ -669,7 +684,7 @@ class CmpOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
669684 int64_t constHint = 1 ;
670685 if (lhsInfo.getConstantValue ().has_value () &&
671686 rhsInfo.getConstantValue ().has_value ()) {
672- constHint = lhsInfo. getConstancy (d) ;
687+ constHint = shape[d] ;
673688 constantValue =
674689 compare (getPredicate (op), lhsInfo.getConstantValue ().value (),
675690 rhsInfo.getConstantValue ().value ())
@@ -828,6 +843,13 @@ class SelectOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
828843 rhsInfo.getConstantValue ().has_value () &&
829844 lhsInfo.getConstantValue () == rhsInfo.getConstantValue ())
830845 constantValue = lhsInfo.getConstantValue ();
846+
847+ if (constantValue.has_value ()) {
848+ auto resTy = dyn_cast<RankedTensorType>(op.getType ());
849+ assert (resTy || rank == 1 );
850+ constancy =
851+ resTy ? to_vector (resTy.getShape ()) : AxisInfo::DimVectorT (rank, 1 );
852+ }
831853 }
832854
833855 return AxisInfo (contiguity, divisibility, constancy, constantValue);
@@ -840,11 +862,6 @@ class LogicalOpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
840862 using BinaryOpVisitorImpl<OpTy>::BinaryOpVisitorImpl;
841863
842864private:
843- int64_t getConstancy (OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
844- int dim) override {
845- return gcd (lhs.getConstancy (dim), rhs.getConstancy (dim));
846- }
847-
848865 std::optional<int64_t > getConstantValue (OpTy op, const AxisInfo &lhs,
849866 const AxisInfo &rhs) override {
850867 if (lhs.getConstantValue ().has_value () &&
@@ -890,11 +907,6 @@ class ShLIOpAxisInfoVisitor final : public BinaryOpVisitorImpl<arith::ShLIOp> {
890907 return multiplyDivisor (lhsDivisibility, 1ll << shift);
891908 }
892909
893- int64_t getConstancy (arith::ShLIOp op, const AxisInfo &lhs,
894- const AxisInfo &rhs, int dim) override {
895- return gcd (lhs.getConstancy (dim), rhs.getConstancy (dim));
896- }
897-
898910 std::optional<int64_t > getConstantValue (arith::ShLIOp op, const AxisInfo &lhs,
899911 const AxisInfo &rhs) override {
900912 if (lhs.getConstantValue ().has_value () &&
@@ -932,11 +944,6 @@ class ShROpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
932944 return std::max<int64_t >(1 , lhsDivisibility / (int64_t (1 ) << shift));
933945 }
934946
935- int64_t getConstancy (OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
936- int dim) override {
937- return gcd (lhs.getConstancy (dim), rhs.getConstancy (dim));
938- }
939-
940947 std::optional<int64_t > getConstantValue (OpTy op, const AxisInfo &lhs,
941948 const AxisInfo &rhs) override {
942949 if (lhs.getConstantValue ().has_value () &&
@@ -969,9 +976,15 @@ class MaxMinOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
969976 constantValue = {std::min (lhsInfo.getConstantValue ().value (),
970977 rhsInfo.getConstantValue ().value ())};
971978 }
979+ auto resTy = dyn_cast<RankedTensorType>(op.getType ());
980+ assert (resTy || rank == 1 );
981+ AxisInfo::DimVectorT constancy =
982+ resTy ? to_vector (resTy.getShape ()) : AxisInfo::DimVectorT (rank, 1 );
983+ AxisInfo::DimVectorT divisibility (
984+ rank, highestPowOf2Divisor<int64_t >(constantValue.value ()));
972985 return AxisInfo (/* knownContiguity=*/ AxisInfo::DimVectorT (rank, 1 ),
973- /* knownDivisibility=*/ AxisInfo::DimVectorT (rank, 1 ) ,
974- /* knownConstancy=*/ AxisInfo::DimVectorT (rank, 1 ) ,
986+ /* knownDivisibility=*/ divisibility ,
987+ /* knownConstancy=*/ constancy ,
975988 /* constantValue=*/ constantValue);
976989 } else {
977990 AxisInfo::DimVectorT contiguity, divisibility, constancy;
@@ -1029,11 +1042,11 @@ AxisInfoAnalysis::AxisInfoAnalysis(DataFlowSolver &solver,
10291042 // This is needed by TritonGPUToLLVM, to get AxisInfo when the graph is
10301043 // in the process of a PartialConversion, where UnrealizedConversionCast
10311044 // may exist
1045+ visitors.append <UnrealizedConversionCastOpAxisInfoVisitor>();
10321046 visitors.append <CastOpAxisInfoVisitor<arith::ExtSIOp>,
10331047 CastOpAxisInfoVisitor<arith::ExtUIOp>,
10341048 CastOpAxisInfoVisitor<arith::TruncIOp>,
10351049 CastOpAxisInfoVisitor<triton::gpu::ConvertLayoutOp>,
1036- CastOpAxisInfoVisitor<mlir::UnrealizedConversionCastOp>,
10371050 CastOpAxisInfoVisitor<triton::BitcastOp>>();
10381051 visitors.append <MakeRangeOpAxisInfoVisitor>();
10391052 visitors.append <PoisonOpAxisInfoVisitor>();
@@ -1384,7 +1397,10 @@ void ModuleAxisInfoAnalysis::update(CallOpInterface callOp,
13841397 callee.setArgAttr (index, attrName, attr);
13851398 };
13861399 auto axisInfo = axisInfoMap->lookup (value);
1387- assert (axisInfo.getRank () == 1 && " only scalar arguments are supported" );
1400+ // Only scalar arguments are supported. Do not forward multi-dimensional
1401+ // AxisInfo to the callee.
1402+ if (axisInfo.getRank () != 1 )
1403+ continue ;
13881404 setAttrFn (" tt.contiguity" , axisInfo.getContiguity (0 ));
13891405 setAttrFn (" tt.divisibility" , axisInfo.getDivisibility (0 ));
13901406 setAttrFn (" tt.constancy" , axisInfo.getConstancy (0 ));
0 commit comments