diff --git a/llvm/lib/Support/KnownBits.cpp b/llvm/lib/Support/KnownBits.cpp index 89668af378070..16229598b612a 100644 --- a/llvm/lib/Support/KnownBits.cpp +++ b/llvm/lib/Support/KnownBits.cpp @@ -610,28 +610,82 @@ static KnownBits computeForSatAddSub(bool Add, bool Signed, const KnownBits &RHS) { // We don't see NSW even for sadd/ssub as we want to check if the result has // signed overflow. - KnownBits Res = - KnownBits::computeForAddSub(Add, /*NSW=*/false, /*NUW=*/false, LHS, RHS); - unsigned BitWidth = Res.getBitWidth(); - auto SignBitKnown = [&](const KnownBits &K) { - return K.Zero[BitWidth - 1] || K.One[BitWidth - 1]; - }; - std::optional Overflow; + unsigned BitWidth = LHS.getBitWidth(); + std::optional Overflow; + // Even if we can't entirely rule out overflow, we may be able to rule out + // overflow in one direction. This allows us to potentially keep some of the + // add/sub bits. I.e if we can't overflow in the positive direction we won't + // clamp to INT_MAX so we can keep low 0s from the add/sub result. + bool MayNegClamp = true; + bool MayPosClamp = true; if (Signed) { - // If we can actually detect overflow do so. Otherwise leave Overflow as - // nullopt (we assume it may have happened). - if (SignBitKnown(LHS) && SignBitKnown(RHS) && SignBitKnown(Res)) { + // Easy cases we can rule out any overflow. + if (Add && ((LHS.isNegative() && RHS.isNonNegative()) || + (LHS.isNonNegative() && RHS.isNegative()))) + Overflow = false; + else if (!Add && (((LHS.isNegative() && RHS.isNegative()) || + (LHS.isNonNegative() && RHS.isNonNegative())))) + Overflow = false; + else { + // Check if we may overflow. If we can't rule out overflow then check if + // we can rule out a direction at least. + KnownBits UnsignedLHS = LHS; + KnownBits UnsignedRHS = RHS; + // Get version of LHS/RHS with clearer signbit. This allows us to detect + // how the addition/subtraction might overflow into the signbit. Then + // using the actual known signbits of LHS/RHS, we can figure out which + // overflows are/aren't possible. + UnsignedLHS.One.clearSignBit(); + UnsignedLHS.Zero.setSignBit(); + UnsignedRHS.One.clearSignBit(); + UnsignedRHS.Zero.setSignBit(); + KnownBits Res = + KnownBits::computeForAddSub(Add, /*NSW=*/false, + /*NUW=*/false, UnsignedLHS, UnsignedRHS); if (Add) { - // sadd.sat - Overflow = (LHS.isNonNegative() == RHS.isNonNegative() && - Res.isNonNegative() != LHS.isNonNegative()); + if (Res.isNegative()) { + // Only overflow scenario is Pos + Pos. + MayNegClamp = false; + // Pos + Pos will overflow with extra signbit. + if (LHS.isNonNegative() && RHS.isNonNegative()) + Overflow = true; + } else if (Res.isNonNegative()) { + // Only overflow scenario is Neg + Neg + MayPosClamp = false; + // Neg + Neg will overflow without extra signbit. + if (LHS.isNegative() && RHS.isNegative()) + Overflow = true; + } + // We will never clamp to the opposite sign of N-bit result. + if (LHS.isNegative() || RHS.isNegative()) + MayPosClamp = false; + if (LHS.isNonNegative() || RHS.isNonNegative()) + MayNegClamp = false; } else { - // ssub.sat - Overflow = (LHS.isNonNegative() != RHS.isNonNegative() && - Res.isNonNegative() != LHS.isNonNegative()); + if (Res.isNegative()) { + // Only overflow scenario is Neg - Pos. + MayPosClamp = false; + // Neg - Pos will overflow with extra signbit. + if (LHS.isNegative() && RHS.isNonNegative()) + Overflow = true; + } else if (Res.isNonNegative()) { + // Only overflow scenario is Pos - Neg. + MayNegClamp = false; + // Pos - Neg will overflow without extra signbit. + if (LHS.isNonNegative() && RHS.isNegative()) + Overflow = true; + } + // We will never clamp to the opposite sign of N-bit result. + if (LHS.isNegative() || RHS.isNonNegative()) + MayPosClamp = false; + if (LHS.isNonNegative() || RHS.isNegative()) + MayNegClamp = false; } } + // If we have ruled out all clamping, we will never overflow. + if (!MayNegClamp && !MayPosClamp) + Overflow = false; } else if (Add) { // uadd.sat bool Of; @@ -656,52 +710,8 @@ static KnownBits computeForSatAddSub(bool Add, bool Signed, } } - if (Signed) { - if (Add) { - if (LHS.isNonNegative() && RHS.isNonNegative()) { - // Pos + Pos -> Pos - Res.One.clearSignBit(); - Res.Zero.setSignBit(); - } - if (LHS.isNegative() && RHS.isNegative()) { - // Neg + Neg -> Neg - Res.One.setSignBit(); - Res.Zero.clearSignBit(); - } - } else { - if (LHS.isNegative() && RHS.isNonNegative()) { - // Neg - Pos -> Neg - Res.One.setSignBit(); - Res.Zero.clearSignBit(); - } else if (LHS.isNonNegative() && RHS.isNegative()) { - // Pos - Neg -> Pos - Res.One.clearSignBit(); - Res.Zero.setSignBit(); - } - } - } else { - // Add: Leading ones of either operand are preserved. - // Sub: Leading zeros of LHS and leading ones of RHS are preserved - // as leading zeros in the result. - unsigned LeadingKnown; - if (Add) - LeadingKnown = - std::max(LHS.countMinLeadingOnes(), RHS.countMinLeadingOnes()); - else - LeadingKnown = - std::max(LHS.countMinLeadingZeros(), RHS.countMinLeadingOnes()); - - // We select between the operation result and all-ones/zero - // respectively, so we can preserve known ones/zeros. - APInt Mask = APInt::getHighBitsSet(BitWidth, LeadingKnown); - if (Add) { - Res.One |= Mask; - Res.Zero &= ~Mask; - } else { - Res.Zero |= Mask; - Res.One &= ~Mask; - } - } + KnownBits Res = KnownBits::computeForAddSub(Add, /*NSW=*/Signed, + /*NUW=*/!Signed, LHS, RHS); if (Overflow) { // We know whether or not we overflowed. @@ -714,7 +724,7 @@ static KnownBits computeForSatAddSub(bool Add, bool Signed, APInt C; if (Signed) { // sadd.sat / ssub.sat - assert(SignBitKnown(LHS) && + assert(!LHS.isSignUnknown() && "We somehow know overflow without knowing input sign"); C = LHS.isNegative() ? APInt::getSignedMinValue(BitWidth) : APInt::getSignedMaxValue(BitWidth); @@ -735,8 +745,10 @@ static KnownBits computeForSatAddSub(bool Add, bool Signed, if (Signed) { // sadd.sat/ssub.sat // We can keep our information about the sign bits. - Res.Zero.clearLowBits(BitWidth - 1); - Res.One.clearLowBits(BitWidth - 1); + if (MayPosClamp) + Res.Zero.clearLowBits(BitWidth - 1); + if (MayNegClamp) + Res.One.clearLowBits(BitWidth - 1); } else if (Add) { // uadd.sat // We need to clear all the known zeros as we can only use the leading ones. diff --git a/llvm/test/Analysis/ValueTracking/knownbits-sat-addsub.ll b/llvm/test/Analysis/ValueTracking/knownbits-sat-addsub.ll index c2926eaffa58c..f9618e1ddbc02 100644 --- a/llvm/test/Analysis/ValueTracking/knownbits-sat-addsub.ll +++ b/llvm/test/Analysis/ValueTracking/knownbits-sat-addsub.ll @@ -142,14 +142,7 @@ define i1 @ssub_sat_low_bits(i8 %x, i8 %y) { define i1 @ssub_sat_fail_may_overflow(i8 %x, i8 %y) { ; CHECK-LABEL: @ssub_sat_fail_may_overflow( -; CHECK-NEXT: [[XX:%.*]] = and i8 [[X:%.*]], 15 -; CHECK-NEXT: [[YY:%.*]] = and i8 [[Y:%.*]], 15 -; CHECK-NEXT: [[LHS:%.*]] = or i8 [[XX]], 1 -; CHECK-NEXT: [[RHS:%.*]] = and i8 [[YY]], -2 -; CHECK-NEXT: [[EXP:%.*]] = call i8 @llvm.ssub.sat.i8(i8 [[LHS]], i8 [[RHS]]) -; CHECK-NEXT: [[AND:%.*]] = and i8 [[EXP]], 1 -; CHECK-NEXT: [[R:%.*]] = icmp eq i8 [[AND]], 0 -; CHECK-NEXT: ret i1 [[R]] +; CHECK-NEXT: ret i1 false ; %xx = and i8 %x, 15 %yy = and i8 %y, 15 diff --git a/llvm/unittests/Support/KnownBitsTest.cpp b/llvm/unittests/Support/KnownBitsTest.cpp index b16368de17648..ce0bf86e39dd7 100644 --- a/llvm/unittests/Support/KnownBitsTest.cpp +++ b/llvm/unittests/Support/KnownBitsTest.cpp @@ -383,26 +383,22 @@ TEST(KnownBitsTest, BinaryExhaustive) { "sadd_sat", KnownBits::sadd_sat, [](const APInt &N1, const APInt &N2) -> std::optional { return N1.sadd_sat(N2); - }, - /*CheckOptimality=*/false); + }); testBinaryOpExhaustive( "uadd_sat", KnownBits::uadd_sat, [](const APInt &N1, const APInt &N2) -> std::optional { return N1.uadd_sat(N2); - }, - /*CheckOptimality=*/false); + }); testBinaryOpExhaustive( "ssub_sat", KnownBits::ssub_sat, [](const APInt &N1, const APInt &N2) -> std::optional { return N1.ssub_sat(N2); - }, - /*CheckOptimality=*/false); + }); testBinaryOpExhaustive( "usub_sat", KnownBits::usub_sat, [](const APInt &N1, const APInt &N2) -> std::optional { return N1.usub_sat(N2); - }, - /*CheckOptimality=*/false); + }); testBinaryOpExhaustive( "shl", [](const KnownBits &Known1, const KnownBits &Known2) {