@@ -5580,27 +5580,98 @@ bool SelectionDAG::isKnownNeverZero(SDValue Op, unsigned Depth) const {
55805580 if (ValKnown.isNegative ())
55815581 return true ;
55825582 // If max shift cnt of known ones is non-zero, result is non-zero.
5583- APInt MaxCnt = computeKnownBits (Op.getOperand (1 ), Depth + 1 ).getMaxValue ();
5583+ const KnownBits Shift = computeKnownBits (Op.getOperand (1 ), Depth + 1 );
5584+ APInt MaxCnt = Shift.getMaxValue ();
55845585 if (MaxCnt.ult (ValKnown.getBitWidth ()) &&
55855586 !ValKnown.One .lshr (MaxCnt).isZero ())
55865587 return true ;
5588+ // Similar to udiv but we try to see if we can turn it into a division
5589+ const KnownBits One =
5590+ KnownBits::makeConstant (APInt (ValKnown.getBitWidth (), 1 ));
5591+
5592+ std::optional<bool > uge =
5593+ KnownBits::uge (ValKnown, KnownBits::shl (One, Shift));
5594+ if (uge && *uge)
5595+ return true ;
55875596 break ;
55885597 }
5589- case ISD::UDIV:
5590- case ISD::SDIV:
5598+ case ISD::UDIV: {
5599+ if (Op->getFlags ().hasExact ())
5600+ return isKnownNeverZero (Op.getOperand (0 ), Depth + 1 );
5601+ KnownBits Op0 = computeKnownBits (Op.getOperand (0 ), Depth + 1 );
5602+ KnownBits Op1 = computeKnownBits (Op.getOperand (1 ), Depth + 1 );
5603+ // True if Op0 u>= Op1
5604+
5605+ std::optional<bool > uge = KnownBits::uge (Op0, Op1);
5606+ if (uge && *uge)
5607+ return true ;
5608+ break ;
5609+ }
5610+ case ISD::SDIV: {
55915611 // div exact can only produce a zero if the dividend is zero.
5592- // TODO: For udiv this is also true if Op1 u<= Op0
55935612 if (Op->getFlags ().hasExact ())
55945613 return isKnownNeverZero (Op.getOperand (0 ), Depth + 1 );
5614+ KnownBits Op0 = computeKnownBits (Op.getOperand (0 ), Depth + 1 );
5615+ KnownBits Op1 = computeKnownBits (Op.getOperand (1 ), Depth + 1 );
5616+ if (Op0.isNegative () && Op1.isStrictlyPositive ())
5617+ return true ;
5618+
5619+ if (Op0.isStrictlyPositive () && Op1.isNegative ())
5620+ return true ;
5621+
5622+ // For negative numbers, the comparison is reversed. Op0 <= Op1
5623+ if (Op0.isNegative () && Op1.isNegative ()) {
5624+ std::optional<bool > sle = KnownBits::sle (Op0, Op1);
5625+ if (sle && *sle)
5626+ return true ;
5627+ }
5628+
5629+ if (Op0.isStrictlyPositive () && Op1.isStrictlyPositive ()) {
5630+ std::optional<bool > uge = KnownBits::uge (Op0, Op1);
5631+ if (uge && *uge)
5632+ return true ;
5633+ }
55955634 break ;
5635+ }
55965636
5597- case ISD::ADD:
5637+ case ISD::ADD: {
55985638 if (Op->getFlags ().hasNoUnsignedWrap ())
55995639 if (isKnownNeverZero (Op.getOperand (1 ), Depth + 1 ) ||
56005640 isKnownNeverZero (Op.getOperand (0 ), Depth + 1 ))
56015641 return true ;
5642+
5643+ KnownBits Op0 = computeKnownBits (Op.getOperand (0 ), Depth + 1 );
5644+ KnownBits Op1 = computeKnownBits (Op.getOperand (1 ), Depth + 1 );
5645+
5646+ // If X and Y are both non-negative (as signed values) then their sum is not
5647+ // zero unless both X and Y are zero.
5648+ if (Op0.isNonNegative () && Op1.isNonNegative ())
5649+ if (isKnownNeverZero (Op.getOperand (1 ), Depth + 1 ) ||
5650+ isKnownNeverZero (Op.getOperand (0 ), Depth + 1 ))
5651+ return true ;
5652+ // If X and Y are both negative (as signed values) then their sum is not
5653+ // zero unless both X and Y equal INT_MIN.
5654+ if (Op0.isNegative () && Op1.isNegative ()) {
5655+ APInt Mask = APInt::getSignedMaxValue (Op0.getBitWidth ());
5656+ // The sign bit of X is set. If some other bit is set then X is not equal
5657+ // to INT_MIN.
5658+ if (Op0.One .intersects (Mask))
5659+ return true ;
5660+ // The sign bit of Y is set. If some other bit is set then Y is not equal
5661+ // to INT_MIN.
5662+ if (Op1.One .intersects (Mask))
5663+ return true ;
5664+ }
5665+
5666+ if (KnownBits::computeForAddSub (
5667+ /* Add=*/ true , Op->getFlags ().hasNoSignedWrap (),
5668+ Op->getFlags ().hasNoUnsignedWrap (), Op0, Op1)
5669+ .isNonZero ())
5670+ return true ;
5671+
56025672 // TODO: There are a lot more cases we can prove for add.
56035673 break ;
5674+ }
56045675
56055676 case ISD::SUB: {
56065677 if (isNullConstant (Op.getOperand (0 )))
@@ -5612,12 +5683,32 @@ bool SelectionDAG::isKnownNeverZero(SDValue Op, unsigned Depth) const {
56125683 return ne && *ne;
56135684 }
56145685
5615- case ISD::MUL:
5686+ case ISD::MUL: {
56165687 if (Op->getFlags ().hasNoSignedWrap () || Op->getFlags ().hasNoUnsignedWrap ())
56175688 if (isKnownNeverZero (Op.getOperand (1 ), Depth + 1 ) &&
56185689 isKnownNeverZero (Op.getOperand (0 ), Depth + 1 ))
56195690 return true ;
5691+
5692+ KnownBits XKnown = computeKnownBits (Op.getOperand (0 ), Depth + 1 );
5693+ if (XKnown.One [0 ])
5694+ if (isKnownNeverZero (Op.getOperand (1 ), Depth + 1 ))
5695+ return true ;
5696+
5697+ KnownBits YKnown = computeKnownBits (Op.getOperand (1 ), Depth + 1 );
5698+ if (YKnown.One [0 ])
5699+ if (XKnown.isNonZero () || isKnownNeverZero (Op.getOperand (0 ), Depth + 1 ))
5700+ return true ;
5701+
5702+ // If there exists any subset of X (sX) and subset of Y (sY) s.t sX * sY is
5703+ // non-zero, then X * Y is non-zero. We can find sX and sY by just taking
5704+ // the lowest known One of X and Y. If they are non-zero, the result
5705+ // must be non-zero. We can check if LSB(X) * LSB(Y) != 0 by doing
5706+ // X.CountLeadingZeros + Y.CountLeadingZeros < BitWidth.
5707+ if (XKnown.countMaxTrailingZeros () + YKnown.countMaxTrailingZeros () <
5708+ XKnown.getBitWidth ())
5709+ return true ;
56205710 break ;
5711+ }
56215712
56225713 case ISD::ZERO_EXTEND:
56235714 case ISD::SIGN_EXTEND:
0 commit comments