diff --git a/llvm/lib/Support/KnownBits.cpp b/llvm/lib/Support/KnownBits.cpp index 6863c5c0af5dc..5b5e6df53d6a1 100644 --- a/llvm/lib/Support/KnownBits.cpp +++ b/llvm/lib/Support/KnownBits.cpp @@ -766,32 +766,61 @@ KnownBits KnownBits::usub_sat(const KnownBits &LHS, const KnownBits &RHS) { static KnownBits avgCompute(KnownBits LHS, KnownBits RHS, bool IsCeil, bool IsSigned) { unsigned BitWidth = LHS.getBitWidth(); - LHS = IsSigned ? LHS.sext(BitWidth + 1) : LHS.zext(BitWidth + 1); - RHS = IsSigned ? RHS.sext(BitWidth + 1) : RHS.zext(BitWidth + 1); - LHS = - computeForAddCarry(LHS, RHS, /*CarryZero*/ !IsCeil, /*CarryOne*/ IsCeil); - LHS = LHS.extractBits(BitWidth, 1); - return LHS; + KnownBits ExtLHS = IsSigned ? LHS.sext(BitWidth + 1) : LHS.zext(BitWidth + 1); + KnownBits ExtRHS = IsSigned ? RHS.sext(BitWidth + 1) : RHS.zext(BitWidth + 1); + KnownBits Res = computeForAddCarry(ExtLHS, ExtRHS, /*CarryZero=*/!IsCeil, + /*CarryOne=*/IsCeil); + Res = Res.extractBits(BitWidth, 1); + + // If we have only 1 known signbit between LHS/RHS we can try to figure + // out result signbit. + // NB: If we know both signbits `computeForAddCarry` gets the optimal result + // already. + if (IsSigned && Res.isSignUnknown() && + LHS.isSignUnknown() != RHS.isSignUnknown()) { + if (LHS.isSignUnknown()) + std::swap(LHS, RHS); + KnownBits UnsignedLHS = LHS; + KnownBits UnsignedRHS = RHS; + UnsignedLHS.One.clearSignBit(); + UnsignedLHS.Zero.setSignBit(); + UnsignedRHS.One.clearSignBit(); + UnsignedRHS.Zero.setSignBit(); + KnownBits ResOf = + computeForAddCarry(UnsignedLHS, UnsignedRHS, /*CarryZero=*/!IsCeil, + /*CarryOne=*/IsCeil); + // Assuming no overflow (which is the case since we extend the addition when + // taking the average): + // Neg + Neg -> Neg + // Neg + Pos -> Neg if the signbit doesn't overflow + if (LHS.isNegative() && ResOf.isNonNegative()) + Res.makeNegative(); + // Pos + Pos -> Pos + // Pos + Neg -> Pos if the signbit does overflow + else if (LHS.isNonNegative() && ResOf.isNegative()) + Res.makeNonNegative(); + } + return Res; } KnownBits KnownBits::avgFloorS(const KnownBits &LHS, const KnownBits &RHS) { - return avgCompute(LHS, RHS, /* IsCeil */ false, - /* IsSigned */ true); + return avgCompute(LHS, RHS, /* IsCeil=*/false, + /* IsSigned=*/true); } KnownBits KnownBits::avgFloorU(const KnownBits &LHS, const KnownBits &RHS) { - return avgCompute(LHS, RHS, /* IsCeil */ false, - /* IsSigned */ false); + return avgCompute(LHS, RHS, /* IsCeil=*/false, + /* IsSigned=*/false); } KnownBits KnownBits::avgCeilS(const KnownBits &LHS, const KnownBits &RHS) { - return avgCompute(LHS, RHS, /* IsCeil */ true, - /* IsSigned */ true); + return avgCompute(LHS, RHS, /* IsCeil=*/true, + /* IsSigned=*/true); } KnownBits KnownBits::avgCeilU(const KnownBits &LHS, const KnownBits &RHS) { - return avgCompute(LHS, RHS, /* IsCeil */ true, - /* IsSigned */ false); + return avgCompute(LHS, RHS, /* IsCeil=*/true, + /* IsSigned=*/false); } KnownBits KnownBits::mul(const KnownBits &LHS, const KnownBits &RHS, diff --git a/llvm/unittests/Support/KnownBitsTest.cpp b/llvm/unittests/Support/KnownBitsTest.cpp index b701757aed5eb..551c1a8107494 100644 --- a/llvm/unittests/Support/KnownBitsTest.cpp +++ b/llvm/unittests/Support/KnownBitsTest.cpp @@ -521,16 +521,15 @@ TEST(KnownBitsTest, BinaryExhaustive) { [](const APInt &N1, const APInt &N2) { return APIntOps::mulhu(N1, N2); }, /*CheckOptimality=*/false); - testBinaryOpExhaustive("avgFloorS", KnownBits::avgFloorS, APIntOps::avgFloorS, - /*CheckOptimality=*/false); + testBinaryOpExhaustive("avgFloorS", KnownBits::avgFloorS, + APIntOps::avgFloorS); testBinaryOpExhaustive("avgFloorU", KnownBits::avgFloorU, APIntOps::avgFloorU); testBinaryOpExhaustive("avgCeilU", KnownBits::avgCeilU, APIntOps::avgCeilU); - testBinaryOpExhaustive("avgCeilS", KnownBits::avgCeilS, APIntOps::avgCeilS, - /*CheckOptimality=*/false); + testBinaryOpExhaustive("avgCeilS", KnownBits::avgCeilS, APIntOps::avgCeilS); } TEST(KnownBitsTest, UnaryExhaustive) {