@@ -5443,19 +5443,59 @@ bool SelectionDAG::isKnownNeverZero(SDValue Op, unsigned Depth) const {
54435443 if (ValKnown.isNegative ())
54445444 return true ;
54455445 // If max shift cnt of known ones is non-zero, result is non-zero.
5446- APInt MaxCnt = computeKnownBits (Op.getOperand (1 ), Depth + 1 ).getMaxValue ();
5446+ const KnownBits Shift = computeKnownBits (Op.getOperand (1 ), Depth + 1 );
5447+ APInt MaxCnt = Shift.getMaxValue ();
54475448 if (MaxCnt.ult (ValKnown.getBitWidth ()) &&
54485449 !ValKnown.One .lshr (MaxCnt).isZero ())
54495450 return true ;
5451+ // Similar to udiv but we try to see if we can turn it into a division
5452+ const KnownBits One =
5453+ KnownBits::makeConstant (APInt (ValKnown.getBitWidth (), 1 ));
5454+
5455+ std::optional<bool > uge =
5456+ KnownBits::uge (ValKnown, KnownBits::shl (One, Shift));
5457+ if (uge && *uge)
5458+ return true ;
54505459 break ;
54515460 }
5452- case ISD::UDIV:
5453- case ISD::SDIV:
5461+ case ISD::UDIV: {
5462+ if (Op->getFlags ().hasExact ())
5463+ return isKnownNeverZero (Op.getOperand (0 ), Depth + 1 );
5464+ KnownBits Op0 = computeKnownBits (Op.getOperand (0 ), Depth + 1 );
5465+ KnownBits Op1 = computeKnownBits (Op.getOperand (1 ), Depth + 1 );
5466+ // True if Op0 u>= Op1
5467+
5468+ std::optional<bool > uge = KnownBits::uge (Op0, Op1);
5469+ if (uge && *uge)
5470+ return true ;
5471+ break ;
5472+ }
5473+ case ISD::SDIV: {
54545474 // div exact can only produce a zero if the dividend is zero.
5455- // TODO: For udiv this is also true if Op1 u<= Op0
54565475 if (Op->getFlags ().hasExact ())
54575476 return isKnownNeverZero (Op.getOperand (0 ), Depth + 1 );
5477+ KnownBits Op0 = computeKnownBits (Op.getOperand (0 ), Depth + 1 );
5478+ KnownBits Op1 = computeKnownBits (Op.getOperand (1 ), Depth + 1 );
5479+ if (Op0.isNegative () && Op1.isStrictlyPositive ())
5480+ return true ;
5481+
5482+ if (Op0.isStrictlyPositive () && Op1.isNegative ())
5483+ return true ;
5484+
5485+ // For negative numbers, the comparison is reversed. Op0 <= Op1
5486+ if (Op0.isNegative () && Op1.isNegative ()) {
5487+ std::optional<bool > sle = KnownBits::sle (Op0, Op1);
5488+ if (sle && *sle)
5489+ return true ;
5490+ }
5491+
5492+ if (Op0.isStrictlyPositive () && Op1.isStrictlyPositive ()) {
5493+ std::optional<bool > uge = KnownBits::uge (Op0, Op1);
5494+ if (uge && *uge)
5495+ return true ;
5496+ }
54585497 break ;
5498+ }
54595499
54605500 case ISD::ADD:
54615501 if (Op->getFlags ().hasNoUnsignedWrap ())
0 commit comments