Skip to content

Commit 360ed74

Browse files
committed
[SelectionDAG]: Add more cases for UDIV, SDIV, SRA, and SRL
1 parent f11a726 commit 360ed74

File tree

2 files changed

+103
-6
lines changed

2 files changed

+103
-6
lines changed

llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp

Lines changed: 97 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5580,27 +5580,98 @@ bool SelectionDAG::isKnownNeverZero(SDValue Op, unsigned Depth) const {
55805580
if (ValKnown.isNegative())
55815581
return true;
55825582
// If max shift cnt of known ones is non-zero, result is non-zero.
5583-
APInt MaxCnt = computeKnownBits(Op.getOperand(1), Depth + 1).getMaxValue();
5583+
const KnownBits Shift = computeKnownBits(Op.getOperand(1), Depth + 1);
5584+
APInt MaxCnt = Shift.getMaxValue();
55845585
if (MaxCnt.ult(ValKnown.getBitWidth()) &&
55855586
!ValKnown.One.lshr(MaxCnt).isZero())
55865587
return true;
5588+
// Similar to udiv but we try to see if we can turn it into a division
5589+
const KnownBits One =
5590+
KnownBits::makeConstant(APInt(ValKnown.getBitWidth(), 1));
5591+
5592+
std::optional<bool> uge =
5593+
KnownBits::uge(ValKnown, KnownBits::shl(One, Shift));
5594+
if (uge && *uge)
5595+
return true;
55875596
break;
55885597
}
5589-
case ISD::UDIV:
5590-
case ISD::SDIV:
5598+
case ISD::UDIV: {
5599+
if (Op->getFlags().hasExact())
5600+
return isKnownNeverZero(Op.getOperand(0), Depth + 1);
5601+
KnownBits Op0 = computeKnownBits(Op.getOperand(0), Depth + 1);
5602+
KnownBits Op1 = computeKnownBits(Op.getOperand(1), Depth + 1);
5603+
// True if Op0 u>= Op1
5604+
5605+
std::optional<bool> uge = KnownBits::uge(Op0, Op1);
5606+
if (uge && *uge)
5607+
return true;
5608+
break;
5609+
}
5610+
case ISD::SDIV: {
55915611
// div exact can only produce a zero if the dividend is zero.
5592-
// TODO: For udiv this is also true if Op1 u<= Op0
55935612
if (Op->getFlags().hasExact())
55945613
return isKnownNeverZero(Op.getOperand(0), Depth + 1);
5614+
KnownBits Op0 = computeKnownBits(Op.getOperand(0), Depth + 1);
5615+
KnownBits Op1 = computeKnownBits(Op.getOperand(1), Depth + 1);
5616+
if (Op0.isNegative() && Op1.isStrictlyPositive())
5617+
return true;
5618+
5619+
if (Op0.isStrictlyPositive() && Op1.isNegative())
5620+
return true;
5621+
5622+
// For negative numbers, the comparison is reversed. Op0 <= Op1
5623+
if (Op0.isNegative() && Op1.isNegative()) {
5624+
std::optional<bool> sle = KnownBits::sle(Op0, Op1);
5625+
if (sle && *sle)
5626+
return true;
5627+
}
5628+
5629+
if (Op0.isStrictlyPositive() && Op1.isStrictlyPositive()) {
5630+
std::optional<bool> uge = KnownBits::uge(Op0, Op1);
5631+
if (uge && *uge)
5632+
return true;
5633+
}
55955634
break;
5635+
}
55965636

5597-
case ISD::ADD:
5637+
case ISD::ADD: {
55985638
if (Op->getFlags().hasNoUnsignedWrap())
55995639
if (isKnownNeverZero(Op.getOperand(1), Depth + 1) ||
56005640
isKnownNeverZero(Op.getOperand(0), Depth + 1))
56015641
return true;
5642+
5643+
KnownBits Op0 = computeKnownBits(Op.getOperand(0), Depth + 1);
5644+
KnownBits Op1 = computeKnownBits(Op.getOperand(1), Depth + 1);
5645+
5646+
// If X and Y are both non-negative (as signed values) then their sum is not
5647+
// zero unless both X and Y are zero.
5648+
if (Op0.isNonNegative() && Op1.isNonNegative())
5649+
if (isKnownNeverZero(Op.getOperand(1), Depth + 1) ||
5650+
isKnownNeverZero(Op.getOperand(0), Depth + 1))
5651+
return true;
5652+
// If X and Y are both negative (as signed values) then their sum is not
5653+
// zero unless both X and Y equal INT_MIN.
5654+
if (Op0.isNegative() && Op1.isNegative()) {
5655+
APInt Mask = APInt::getSignedMaxValue(Op0.getBitWidth());
5656+
// The sign bit of X is set. If some other bit is set then X is not equal
5657+
// to INT_MIN.
5658+
if (Op0.One.intersects(Mask))
5659+
return true;
5660+
// The sign bit of Y is set. If some other bit is set then Y is not equal
5661+
// to INT_MIN.
5662+
if (Op1.One.intersects(Mask))
5663+
return true;
5664+
}
5665+
5666+
if (KnownBits::computeForAddSub(
5667+
/*Add=*/true, Op->getFlags().hasNoSignedWrap(),
5668+
Op->getFlags().hasNoUnsignedWrap(), Op0, Op1)
5669+
.isNonZero())
5670+
return true;
5671+
56025672
// TODO: There are a lot more cases we can prove for add.
56035673
break;
5674+
}
56045675

56055676
case ISD::SUB: {
56065677
if (isNullConstant(Op.getOperand(0)))
@@ -5612,12 +5683,32 @@ bool SelectionDAG::isKnownNeverZero(SDValue Op, unsigned Depth) const {
56125683
return ne && *ne;
56135684
}
56145685

5615-
case ISD::MUL:
5686+
case ISD::MUL: {
56165687
if (Op->getFlags().hasNoSignedWrap() || Op->getFlags().hasNoUnsignedWrap())
56175688
if (isKnownNeverZero(Op.getOperand(1), Depth + 1) &&
56185689
isKnownNeverZero(Op.getOperand(0), Depth + 1))
56195690
return true;
5691+
5692+
KnownBits XKnown = computeKnownBits(Op.getOperand(0), Depth + 1);
5693+
if (XKnown.One[0])
5694+
if (isKnownNeverZero(Op.getOperand(1), Depth + 1))
5695+
return true;
5696+
5697+
KnownBits YKnown = computeKnownBits(Op.getOperand(1), Depth + 1);
5698+
if (YKnown.One[0])
5699+
if (XKnown.isNonZero() || isKnownNeverZero(Op.getOperand(0), Depth + 1))
5700+
return true;
5701+
5702+
// If there exists any subset of X (sX) and subset of Y (sY) s.t sX * sY is
5703+
// non-zero, then X * Y is non-zero. We can find sX and sY by just taking
5704+
// the lowest known One of X and Y. If they are non-zero, the result
5705+
// must be non-zero. We can check if LSB(X) * LSB(Y) != 0 by doing
5706+
// X.CountLeadingZeros + Y.CountLeadingZeros < BitWidth.
5707+
if (XKnown.countMaxTrailingZeros() + YKnown.countMaxTrailingZeros() <
5708+
XKnown.getBitWidth())
5709+
return true;
56205710
break;
5711+
}
56215712

56225713
case ISD::ZERO_EXTEND:
56235714
case ISD::SIGN_EXTEND:

llvm/lib/Support/KnownBits.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -969,6 +969,9 @@ KnownBits KnownBits::sdiv(const KnownBits &LHS, const KnownBits &RHS,
969969
Res = (Num.isMinSignedValue() && Denom.isAllOnes())
970970
? APInt::getSignedMaxValue(BitWidth)
971971
: Num.sdiv(Denom);
972+
std::optional<bool> sle = KnownBits::sle(LHS, RHS);
973+
if (sle && *sle)
974+
Known.makeGE(APInt(BitWidth, 1));
972975
} else if (LHS.isNegative() && RHS.isNonNegative()) {
973976
// Result is negative if Exact OR -LHS u>= RHS.
974977
if (Exact || (-LHS.getSignedMaxValue()).uge(RHS.getSignedMaxValue())) {
@@ -1022,6 +1025,9 @@ KnownBits KnownBits::udiv(const KnownBits &LHS, const KnownBits &RHS,
10221025

10231026
Known.Zero.setHighBits(LeadZ);
10241027
Known = divComputeLowBit(Known, LHS, RHS, Exact);
1028+
std::optional<bool> uge = KnownBits::uge(LHS, RHS);
1029+
if (uge && *uge)
1030+
Known.makeGE(APInt(BitWidth, 1));
10251031

10261032
return Known;
10271033
}

0 commit comments

Comments
 (0)