Skip to content

Commit a5d26f2

Browse files
committed
[SelectionDAG]: Add more cases for UDIV, SDIV, SRA, and SRL
1 parent e79a8ff commit a5d26f2

File tree

2 files changed

+47
-4
lines changed

2 files changed

+47
-4
lines changed

llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5443,19 +5443,59 @@ bool SelectionDAG::isKnownNeverZero(SDValue Op, unsigned Depth) const {
54435443
if (ValKnown.isNegative())
54445444
return true;
54455445
// If max shift cnt of known ones is non-zero, result is non-zero.
5446-
APInt MaxCnt = computeKnownBits(Op.getOperand(1), Depth + 1).getMaxValue();
5446+
const KnownBits Shift = computeKnownBits(Op.getOperand(1), Depth + 1);
5447+
APInt MaxCnt = Shift.getMaxValue();
54475448
if (MaxCnt.ult(ValKnown.getBitWidth()) &&
54485449
!ValKnown.One.lshr(MaxCnt).isZero())
54495450
return true;
5451+
// Similar to udiv but we try to see if we can turn it into a division
5452+
const KnownBits One =
5453+
KnownBits::makeConstant(APInt(ValKnown.getBitWidth(), 1));
5454+
5455+
std::optional<bool> uge =
5456+
KnownBits::uge(ValKnown, KnownBits::shl(One, Shift));
5457+
if (uge && *uge)
5458+
return true;
54505459
break;
54515460
}
5452-
case ISD::UDIV:
5453-
case ISD::SDIV:
5461+
case ISD::UDIV: {
5462+
if (Op->getFlags().hasExact())
5463+
return isKnownNeverZero(Op.getOperand(0), Depth + 1);
5464+
KnownBits Op0 = computeKnownBits(Op.getOperand(0), Depth + 1);
5465+
KnownBits Op1 = computeKnownBits(Op.getOperand(1), Depth + 1);
5466+
// True if Op0 u>= Op1
5467+
5468+
std::optional<bool> uge = KnownBits::uge(Op0, Op1);
5469+
if (uge && *uge)
5470+
return true;
5471+
break;
5472+
}
5473+
case ISD::SDIV: {
54545474
// div exact can only produce a zero if the dividend is zero.
5455-
// TODO: For udiv this is also true if Op1 u<= Op0
54565475
if (Op->getFlags().hasExact())
54575476
return isKnownNeverZero(Op.getOperand(0), Depth + 1);
5477+
KnownBits Op0 = computeKnownBits(Op.getOperand(0), Depth + 1);
5478+
KnownBits Op1 = computeKnownBits(Op.getOperand(1), Depth + 1);
5479+
if (Op0.isNegative() && Op1.isStrictlyPositive())
5480+
return true;
5481+
5482+
if (Op0.isStrictlyPositive() && Op1.isNegative())
5483+
return true;
5484+
5485+
// For negative numbers, the comparison is reversed. Op0 <= Op1
5486+
if (Op0.isNegative() && Op1.isNegative()) {
5487+
std::optional<bool> sle = KnownBits::sle(Op0, Op1);
5488+
if (sle && *sle)
5489+
return true;
5490+
}
5491+
5492+
if (Op0.isStrictlyPositive() && Op1.isStrictlyPositive()) {
5493+
std::optional<bool> uge = KnownBits::uge(Op0, Op1);
5494+
if (uge && *uge)
5495+
return true;
5496+
}
54585497
break;
5498+
}
54595499

54605500
case ISD::ADD:
54615501
if (Op->getFlags().hasNoUnsignedWrap())

llvm/lib/Support/KnownBits.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1009,6 +1009,9 @@ KnownBits KnownBits::udiv(const KnownBits &LHS, const KnownBits &RHS,
10091009

10101010
Known.Zero.setHighBits(LeadZ);
10111011
Known = divComputeLowBit(Known, LHS, RHS, Exact);
1012+
std::optional<bool> uge = KnownBits::uge(LHS, RHS);
1013+
if (uge && *uge)
1014+
Known.makeGE(APInt(BitWidth, 1));
10121015

10131016
assert(!Known.hasConflict() && "Bad Output");
10141017
return Known;

0 commit comments

Comments
 (0)