From d037382b86f0fd05d165657857cc76a3a4a28615 Mon Sep 17 00:00:00 2001 From: Noah Goldstein Date: Fri, 27 Sep 2024 15:39:03 -0500 Subject: [PATCH] [KnownBits] Make `avg{Ceil,Floor}S` optimal All we where missing was the signbit if we knew the incoming signbit of either LHS or RHS. Since the base addition in the average is with an extra bit width it cannot overflow, we figure out the result sign based on the magnitude of the input. If the negative component has a larger magnitude the result is negative and vice versa for the positive case. --- llvm/lib/Support/KnownBits.cpp | 57 ++++++++++++++++++------ llvm/unittests/Support/KnownBitsTest.cpp | 7 ++- 2 files changed, 46 insertions(+), 18 deletions(-) diff --git a/llvm/lib/Support/KnownBits.cpp b/llvm/lib/Support/KnownBits.cpp index 6863c5c0af5dc..5b5e6df53d6a1 100644 --- a/llvm/lib/Support/KnownBits.cpp +++ b/llvm/lib/Support/KnownBits.cpp @@ -766,32 +766,61 @@ KnownBits KnownBits::usub_sat(const KnownBits &LHS, const KnownBits &RHS) { static KnownBits avgCompute(KnownBits LHS, KnownBits RHS, bool IsCeil, bool IsSigned) { unsigned BitWidth = LHS.getBitWidth(); - LHS = IsSigned ? LHS.sext(BitWidth + 1) : LHS.zext(BitWidth + 1); - RHS = IsSigned ? RHS.sext(BitWidth + 1) : RHS.zext(BitWidth + 1); - LHS = - computeForAddCarry(LHS, RHS, /*CarryZero*/ !IsCeil, /*CarryOne*/ IsCeil); - LHS = LHS.extractBits(BitWidth, 1); - return LHS; + KnownBits ExtLHS = IsSigned ? LHS.sext(BitWidth + 1) : LHS.zext(BitWidth + 1); + KnownBits ExtRHS = IsSigned ? RHS.sext(BitWidth + 1) : RHS.zext(BitWidth + 1); + KnownBits Res = computeForAddCarry(ExtLHS, ExtRHS, /*CarryZero=*/!IsCeil, + /*CarryOne=*/IsCeil); + Res = Res.extractBits(BitWidth, 1); + + // If we have only 1 known signbit between LHS/RHS we can try to figure + // out result signbit. + // NB: If we know both signbits `computeForAddCarry` gets the optimal result + // already. + if (IsSigned && Res.isSignUnknown() && + LHS.isSignUnknown() != RHS.isSignUnknown()) { + if (LHS.isSignUnknown()) + std::swap(LHS, RHS); + KnownBits UnsignedLHS = LHS; + KnownBits UnsignedRHS = RHS; + UnsignedLHS.One.clearSignBit(); + UnsignedLHS.Zero.setSignBit(); + UnsignedRHS.One.clearSignBit(); + UnsignedRHS.Zero.setSignBit(); + KnownBits ResOf = + computeForAddCarry(UnsignedLHS, UnsignedRHS, /*CarryZero=*/!IsCeil, + /*CarryOne=*/IsCeil); + // Assuming no overflow (which is the case since we extend the addition when + // taking the average): + // Neg + Neg -> Neg + // Neg + Pos -> Neg if the signbit doesn't overflow + if (LHS.isNegative() && ResOf.isNonNegative()) + Res.makeNegative(); + // Pos + Pos -> Pos + // Pos + Neg -> Pos if the signbit does overflow + else if (LHS.isNonNegative() && ResOf.isNegative()) + Res.makeNonNegative(); + } + return Res; } KnownBits KnownBits::avgFloorS(const KnownBits &LHS, const KnownBits &RHS) { - return avgCompute(LHS, RHS, /* IsCeil */ false, - /* IsSigned */ true); + return avgCompute(LHS, RHS, /* IsCeil=*/false, + /* IsSigned=*/true); } KnownBits KnownBits::avgFloorU(const KnownBits &LHS, const KnownBits &RHS) { - return avgCompute(LHS, RHS, /* IsCeil */ false, - /* IsSigned */ false); + return avgCompute(LHS, RHS, /* IsCeil=*/false, + /* IsSigned=*/false); } KnownBits KnownBits::avgCeilS(const KnownBits &LHS, const KnownBits &RHS) { - return avgCompute(LHS, RHS, /* IsCeil */ true, - /* IsSigned */ true); + return avgCompute(LHS, RHS, /* IsCeil=*/true, + /* IsSigned=*/true); } KnownBits KnownBits::avgCeilU(const KnownBits &LHS, const KnownBits &RHS) { - return avgCompute(LHS, RHS, /* IsCeil */ true, - /* IsSigned */ false); + return avgCompute(LHS, RHS, /* IsCeil=*/true, + /* IsSigned=*/false); } KnownBits KnownBits::mul(const KnownBits &LHS, const KnownBits &RHS, diff --git a/llvm/unittests/Support/KnownBitsTest.cpp b/llvm/unittests/Support/KnownBitsTest.cpp index b701757aed5eb..551c1a8107494 100644 --- a/llvm/unittests/Support/KnownBitsTest.cpp +++ b/llvm/unittests/Support/KnownBitsTest.cpp @@ -521,16 +521,15 @@ TEST(KnownBitsTest, BinaryExhaustive) { [](const APInt &N1, const APInt &N2) { return APIntOps::mulhu(N1, N2); }, /*CheckOptimality=*/false); - testBinaryOpExhaustive("avgFloorS", KnownBits::avgFloorS, APIntOps::avgFloorS, - /*CheckOptimality=*/false); + testBinaryOpExhaustive("avgFloorS", KnownBits::avgFloorS, + APIntOps::avgFloorS); testBinaryOpExhaustive("avgFloorU", KnownBits::avgFloorU, APIntOps::avgFloorU); testBinaryOpExhaustive("avgCeilU", KnownBits::avgCeilU, APIntOps::avgCeilU); - testBinaryOpExhaustive("avgCeilS", KnownBits::avgCeilS, APIntOps::avgCeilS, - /*CheckOptimality=*/false); + testBinaryOpExhaustive("avgCeilS", KnownBits::avgCeilS, APIntOps::avgCeilS); } TEST(KnownBitsTest, UnaryExhaustive) {