@@ -91,23 +91,26 @@ class BinaryOpVisitorImpl : public AxisInfoVisitorImpl<OpTy> {
9191 auto lhsInfo = operands[0 ]->getValue ();
9292 auto rhsInfo = operands[1 ]->getValue ();
9393 auto rank = lhsInfo.getRank ();
94+ assert (isa<RankedTensorType>(op.getType ()) ||
95+ rank == 1 && " Expected ranked tensor or scalar" );
9496 assert (operands.size () == 2 && " Expected two operands" );
97+ auto constantValue = getConstantValue (op, lhsInfo, rhsInfo);
98+ if (constantValue.has_value ()) {
99+ auto resTy = dyn_cast<RankedTensorType>(op.getType ());
100+ AxisInfo::DimVectorT constancy =
101+ resTy ? to_vector (resTy.getShape ()) : AxisInfo::DimVectorT (rank, 1 );
102+ AxisInfo::DimVectorT contiguity (rank, 1 );
103+ AxisInfo::DimVectorT divisibility (
104+ rank, highestPowOf2Divisor<int64_t >(constantValue.value ()));
105+ return AxisInfo (contiguity, divisibility, constancy, constantValue);
106+ }
95107 AxisInfo::DimVectorT contiguity;
96108 AxisInfo::DimVectorT divisibility;
97109 AxisInfo::DimVectorT constancy;
98- auto constantValue = getConstantValue (op, lhsInfo, rhsInfo);
99110 for (auto d = 0 ; d < rank; ++d) {
100- if (constantValue.has_value ()) {
101- contiguity.push_back (1 );
102- constancy.push_back (
103- std::max (lhsInfo.getConstancy (d), rhsInfo.getConstancy (d)));
104- divisibility.push_back (
105- highestPowOf2Divisor<int64_t >(constantValue.value ()));
106- } else {
107- contiguity.push_back (getContiguity (op, lhsInfo, rhsInfo, d));
108- constancy.push_back (getConstancy (op, lhsInfo, rhsInfo, d));
109- divisibility.push_back (getDivisibility (op, lhsInfo, rhsInfo, d));
110- }
111+ contiguity.push_back (getContiguity (op, lhsInfo, rhsInfo, d));
112+ constancy.push_back (getConstancy (op, lhsInfo, rhsInfo, d));
113+ divisibility.push_back (getDivisibility (op, lhsInfo, rhsInfo, d));
111114 }
112115 return AxisInfo (contiguity, divisibility, constancy, constantValue);
113116 }
@@ -125,9 +128,8 @@ class BinaryOpVisitorImpl : public AxisInfoVisitorImpl<OpTy> {
125128
126129 virtual int64_t getConstancy (OpTy op, const AxisInfo &lhs,
127130 const AxisInfo &rhs, int dim) {
128- return 1 ;
131+ return gcd (lhs. getConstancy (dim), rhs. getConstancy (dim)) ;
129132 }
130-
131133 virtual std::optional<int64_t > getConstantValue (OpTy op, const AxisInfo &lhs,
132134 const AxisInfo &rhs) {
133135 return {};
@@ -328,11 +330,6 @@ class AddSubOpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
328330 return gcd (lhs.getDivisibility (dim), rhsDivisibility);
329331 }
330332
331- int64_t getConstancy (OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
332- int dim) override {
333- return gcd (lhs.getConstancy (dim), rhs.getConstancy (dim));
334- }
335-
336333 std::optional<int64_t > getConstantValue (OpTy op, const AxisInfo &lhs,
337334 const AxisInfo &rhs) override {
338335 if (lhs.getConstantValue ().has_value () &&
@@ -375,11 +372,6 @@ class MulIOpAxisInfoVisitor final : public BinaryOpVisitorImpl<arith::MulIOp> {
375372 return std::max (lhsContiguity, rhsContiguity);
376373 }
377374
378- int64_t getConstancy (arith::MulIOp op, const AxisInfo &lhs,
379- const AxisInfo &rhs, int dim) override {
380- return gcd (lhs.getConstancy (dim), rhs.getConstancy (dim));
381- }
382-
383375 int64_t getDivisibility (arith::MulIOp op, const AxisInfo &lhs,
384376 const AxisInfo &rhs, int dim) override {
385377 auto lhsDivisibility = lhs.getDivisibility (dim);
@@ -399,9 +391,13 @@ class MulIOpAxisInfoVisitor final : public BinaryOpVisitorImpl<arith::MulIOp> {
399391
400392 std::optional<int64_t > getConstantValue (arith::MulIOp op, const AxisInfo &lhs,
401393 const AxisInfo &rhs) override {
402- if (lhs.getConstantValue ().has_value () &&
403- rhs.getConstantValue ().has_value ())
404- return {lhs.getConstantValue ().value () * rhs.getConstantValue ().value ()};
394+ auto lhsConst = lhs.getConstantValue ();
395+ auto rhsConst = rhs.getConstantValue ();
396+ if (lhsConst.has_value () && rhsConst.has_value ())
397+ return {lhsConst.value () * rhsConst.value ()};
398+ if ((lhsConst.has_value () && lhsConst.value () == 0 ) ||
399+ (rhsConst.has_value () && rhsConst.value () == 0 ))
400+ return 0 ;
405401 return {};
406402 }
407403};
@@ -424,12 +420,11 @@ class DivOpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
424420 int64_t getConstancy (OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
425421 int dim) override {
426422 auto resTy = dyn_cast<RankedTensorType>(op.getType ());
423+ auto constancy = BinaryOpVisitorImpl<OpTy>::getConstancy (op, lhs, rhs, dim);
427424 if (!resTy)
428- return BinaryOpVisitorImpl<OpTy>:: getConstancy (op, lhs, rhs, dim) ;
425+ return constancy ;
429426 auto shape = resTy.getShape ();
430- // Case 1: both lhs and rhs are constants.
431- auto constancy = gcd (lhs.getConstancy (dim), rhs.getConstancy (dim));
432- // Case 2: lhs contiguous, rhs constant.
427+ // Case: lhs contiguous, rhs constant.
433428 // lhs: d_lhs * k, d_lhs * k + 1, ..., d_lhs * k + n
434429 // rhs: d_rhs * p, d_rhs * p, ..., d_rhs * p
435430 // lhs / rhs = d_lhs * k / (d_rhs * p), (d_lhs * k + 1) / (d_rhs * p),
@@ -526,15 +521,15 @@ class RemOpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
526521
527522 int64_t getConstancy (OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
528523 int dim) override {
524+ auto constancy = BinaryOpVisitorImpl<OpTy>::getConstancy (op, lhs, rhs, dim);
529525 auto resTy = dyn_cast<RankedTensorType>(op.getType ());
530526 if (!resTy)
531- return BinaryOpVisitorImpl<OpTy>::getConstancy (op, lhs, rhs, dim);
532- auto shape = resTy.getShape ();
533- // lhs % 1 = 0
534- return rhs.getConstantValue ().has_value () &&
535- rhs.getConstantValue ().value () == 1
536- ? shape[dim]
537- : gcd (lhs.getConstancy (dim), rhs.getConstancy (dim));
527+ return constancy;
528+ // Case: lhs % 1 = 0
529+ if (rhs.getConstantValue ().has_value () &&
530+ rhs.getConstantValue ().value () == 1 )
531+ return resTy.getDimSize (dim);
532+ return constancy;
538533 }
539534
540535 std::optional<int64_t > getConstantValue (OpTy op, const AxisInfo &lhs,
@@ -689,7 +684,7 @@ class CmpOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
689684 int64_t constHint = 1 ;
690685 if (lhsInfo.getConstantValue ().has_value () &&
691686 rhsInfo.getConstantValue ().has_value ()) {
692- constHint = lhsInfo. getConstancy (d) ;
687+ constHint = shape[d] ;
693688 constantValue =
694689 compare (getPredicate (op), lhsInfo.getConstantValue ().value (),
695690 rhsInfo.getConstantValue ().value ())
@@ -848,6 +843,13 @@ class SelectOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
848843 rhsInfo.getConstantValue ().has_value () &&
849844 lhsInfo.getConstantValue () == rhsInfo.getConstantValue ())
850845 constantValue = lhsInfo.getConstantValue ();
846+
847+ if (constantValue.has_value ()) {
848+ auto resTy = dyn_cast<RankedTensorType>(op.getType ());
849+ assert (resTy || rank == 1 );
850+ constancy =
851+ resTy ? to_vector (resTy.getShape ()) : AxisInfo::DimVectorT (rank, 1 );
852+ }
851853 }
852854
853855 return AxisInfo (contiguity, divisibility, constancy, constantValue);
@@ -860,11 +862,6 @@ class LogicalOpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
860862 using BinaryOpVisitorImpl<OpTy>::BinaryOpVisitorImpl;
861863
862864private:
863- int64_t getConstancy (OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
864- int dim) override {
865- return gcd (lhs.getConstancy (dim), rhs.getConstancy (dim));
866- }
867-
868865 std::optional<int64_t > getConstantValue (OpTy op, const AxisInfo &lhs,
869866 const AxisInfo &rhs) override {
870867 if (lhs.getConstantValue ().has_value () &&
@@ -910,11 +907,6 @@ class ShLIOpAxisInfoVisitor final : public BinaryOpVisitorImpl<arith::ShLIOp> {
910907 return multiplyDivisor (lhsDivisibility, 1ll << shift);
911908 }
912909
913- int64_t getConstancy (arith::ShLIOp op, const AxisInfo &lhs,
914- const AxisInfo &rhs, int dim) override {
915- return gcd (lhs.getConstancy (dim), rhs.getConstancy (dim));
916- }
917-
918910 std::optional<int64_t > getConstantValue (arith::ShLIOp op, const AxisInfo &lhs,
919911 const AxisInfo &rhs) override {
920912 if (lhs.getConstantValue ().has_value () &&
@@ -952,11 +944,6 @@ class ShROpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
952944 return std::max<int64_t >(1 , lhsDivisibility / (int64_t (1 ) << shift));
953945 }
954946
955- int64_t getConstancy (OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
956- int dim) override {
957- return gcd (lhs.getConstancy (dim), rhs.getConstancy (dim));
958- }
959-
960947 std::optional<int64_t > getConstantValue (OpTy op, const AxisInfo &lhs,
961948 const AxisInfo &rhs) override {
962949 if (lhs.getConstantValue ().has_value () &&
@@ -989,9 +976,15 @@ class MaxMinOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
989976 constantValue = {std::min (lhsInfo.getConstantValue ().value (),
990977 rhsInfo.getConstantValue ().value ())};
991978 }
979+ auto resTy = dyn_cast<RankedTensorType>(op.getType ());
980+ assert (resTy || rank == 1 );
981+ AxisInfo::DimVectorT constancy =
982+ resTy ? to_vector (resTy.getShape ()) : AxisInfo::DimVectorT (rank, 1 );
983+ AxisInfo::DimVectorT divisibility (
984+ rank, highestPowOf2Divisor<int64_t >(constantValue.value ()));
992985 return AxisInfo (/* knownContiguity=*/ AxisInfo::DimVectorT (rank, 1 ),
993- /* knownDivisibility=*/ AxisInfo::DimVectorT (rank, 1 ) ,
994- /* knownConstancy=*/ AxisInfo::DimVectorT (rank, 1 ) ,
986+ /* knownDivisibility=*/ divisibility ,
987+ /* knownConstancy=*/ constancy ,
995988 /* constantValue=*/ constantValue);
996989 } else {
997990 AxisInfo::DimVectorT contiguity, divisibility, constancy;
0 commit comments