Skip to content

Commit 294266d

Browse files
committed
[SelectionDAG]: Add more cases for UDIV, SDIV, SRA, and SRL
These cases were ported from ValueTracking
1 parent b72e307 commit 294266d

File tree

1 file changed

+87
-6
lines changed

1 file changed

+87
-6
lines changed

llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp

Lines changed: 87 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5580,27 +5580,88 @@ bool SelectionDAG::isKnownNeverZero(SDValue Op, unsigned Depth) const {
55805580
if (ValKnown.isNegative())
55815581
return true;
55825582
// If max shift cnt of known ones is non-zero, result is non-zero.
5583-
APInt MaxCnt = computeKnownBits(Op.getOperand(1), Depth + 1).getMaxValue();
5583+
const KnownBits Shift = computeKnownBits(Op.getOperand(1), Depth + 1);
5584+
APInt MaxCnt = Shift.getMaxValue();
55845585
if (MaxCnt.ult(ValKnown.getBitWidth()) &&
55855586
!ValKnown.One.lshr(MaxCnt).isZero())
55865587
return true;
5588+
// Similar to udiv but we try to see if we can turn it into a division
5589+
const KnownBits One =
5590+
KnownBits::makeConstant(APInt(ValKnown.getBitWidth(), 1));
5591+
5592+
std::optional<bool> uge =
5593+
KnownBits::uge(ValKnown, KnownBits::shl(One, Shift));
5594+
if (uge && *uge)
5595+
return true;
55875596
break;
55885597
}
5589-
case ISD::UDIV:
5590-
case ISD::SDIV:
5598+
case ISD::UDIV: {
5599+
if (Op->getFlags().hasExact())
5600+
return isKnownNeverZero(Op.getOperand(0), Depth + 1);
5601+
KnownBits Op0 = computeKnownBits(Op.getOperand(0), Depth + 1);
5602+
KnownBits Op1 = computeKnownBits(Op.getOperand(1), Depth + 1);
5603+
// True if Op0 u>= Op1
5604+
5605+
std::optional<bool> Uge = KnownBits::uge(Op0, Op1);
5606+
if (Uge && *Uge)
5607+
return true;
5608+
break;
5609+
}
5610+
case ISD::SDIV: {
55915611
// div exact can only produce a zero if the dividend is zero.
5592-
// TODO: For udiv this is also true if Op1 u<= Op0
55935612
if (Op->getFlags().hasExact())
55945613
return isKnownNeverZero(Op.getOperand(0), Depth + 1);
5614+
KnownBits Op0 = computeKnownBits(Op.getOperand(0), Depth + 1);
5615+
KnownBits Op1 = computeKnownBits(Op.getOperand(1), Depth + 1);
5616+
5617+
// For signed division need to compare abs value of the operands.
5618+
Op0 = Op0.abs(/*IntMinIsPoison*/ false);
5619+
Op1 = Op1.abs(/*IntMinIsPoison*/ false);
5620+
5621+
std::optional<bool> Uge = KnownBits::uge(Op0, Op1);
5622+
if (Uge && *Uge)
5623+
return true;
55955624
break;
5625+
}
55965626

5597-
case ISD::ADD:
5627+
case ISD::ADD: {
55985628
if (Op->getFlags().hasNoUnsignedWrap())
55995629
if (isKnownNeverZero(Op.getOperand(1), Depth + 1) ||
56005630
isKnownNeverZero(Op.getOperand(0), Depth + 1))
56015631
return true;
5632+
5633+
KnownBits Op0 = computeKnownBits(Op.getOperand(0), Depth + 1);
5634+
KnownBits Op1 = computeKnownBits(Op.getOperand(1), Depth + 1);
5635+
5636+
// If X and Y are both non-negative (as signed values) then their sum is not
5637+
// zero unless both X and Y are zero.
5638+
if (Op0.isNonNegative() && Op1.isNonNegative())
5639+
if (isKnownNeverZero(Op.getOperand(1), Depth + 1) ||
5640+
isKnownNeverZero(Op.getOperand(0), Depth + 1))
5641+
return true;
5642+
// If X and Y are both negative (as signed values) then their sum is not
5643+
// zero unless both X and Y equal INT_MIN.
5644+
if (Op0.isNegative() && Op1.isNegative()) {
5645+
APInt Mask = APInt::getSignedMaxValue(Op0.getBitWidth());
5646+
// The sign bit of X is set. If some other bit is set then X is not equal
5647+
// to INT_MIN.
5648+
if (Op0.One.intersects(Mask))
5649+
return true;
5650+
// The sign bit of Y is set. If some other bit is set then Y is not equal
5651+
// to INT_MIN.
5652+
if (Op1.One.intersects(Mask))
5653+
return true;
5654+
}
5655+
5656+
if (KnownBits::computeForAddSub(
5657+
/*Add=*/true, Op->getFlags().hasNoSignedWrap(),
5658+
Op->getFlags().hasNoUnsignedWrap(), Op0, Op1)
5659+
.isNonZero())
5660+
return true;
5661+
56025662
// TODO: There are a lot more cases we can prove for add.
56035663
break;
5664+
}
56045665

56055666
case ISD::SUB: {
56065667
if (isNullConstant(Op.getOperand(0)))
@@ -5612,12 +5673,32 @@ bool SelectionDAG::isKnownNeverZero(SDValue Op, unsigned Depth) const {
56125673
return ne && *ne;
56135674
}
56145675

5615-
case ISD::MUL:
5676+
case ISD::MUL: {
56165677
if (Op->getFlags().hasNoSignedWrap() || Op->getFlags().hasNoUnsignedWrap())
56175678
if (isKnownNeverZero(Op.getOperand(1), Depth + 1) &&
56185679
isKnownNeverZero(Op.getOperand(0), Depth + 1))
56195680
return true;
5681+
5682+
KnownBits XKnown = computeKnownBits(Op.getOperand(0), Depth + 1);
5683+
if (XKnown.One[0])
5684+
if (isKnownNeverZero(Op.getOperand(1), Depth + 1))
5685+
return true;
5686+
5687+
KnownBits YKnown = computeKnownBits(Op.getOperand(1), Depth + 1);
5688+
if (YKnown.One[0])
5689+
if (isKnownNeverZero(Op.getOperand(0), Depth + 1))
5690+
return true;
5691+
5692+
// If there exists any subset of X (sX) and subset of Y (sY) s.t sX * sY is
5693+
// non-zero, then X * Y is non-zero. We can find sX and sY by just taking
5694+
// the lowest known One of X and Y. If they are non-zero, the result
5695+
// must be non-zero. We can check if LSB(X) * LSB(Y) != 0 by doing
5696+
// X.CountLeadingZeros + Y.CountLeadingZeros < BitWidth.
5697+
if (XKnown.countMaxTrailingZeros() + YKnown.countMaxTrailingZeros() <
5698+
XKnown.getBitWidth())
5699+
return true;
56205700
break;
5701+
}
56215702

56225703
case ISD::ZERO_EXTEND:
56235704
case ISD::SIGN_EXTEND:

0 commit comments

Comments
 (0)