Skip to content

Commit d037382

Browse files
committed
[KnownBits] Make avg{Ceil,Floor}S optimal
All we were missing was the signbit in the case where we know the incoming signbit of only one of LHS or RHS. Since the base addition in the average is performed with one extra bit of width, it cannot overflow, so we can figure out the result sign based on the magnitudes of the inputs: if the negative component has the larger magnitude the result is negative, and vice versa for the positive case.
1 parent bbdca53 commit d037382

File tree

2 files changed

+46
-18
lines changed

2 files changed

+46
-18
lines changed

llvm/lib/Support/KnownBits.cpp

Lines changed: 43 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -766,32 +766,61 @@ KnownBits KnownBits::usub_sat(const KnownBits &LHS, const KnownBits &RHS) {
766766
static KnownBits avgCompute(KnownBits LHS, KnownBits RHS, bool IsCeil,
767767
bool IsSigned) {
768768
unsigned BitWidth = LHS.getBitWidth();
769-
LHS = IsSigned ? LHS.sext(BitWidth + 1) : LHS.zext(BitWidth + 1);
770-
RHS = IsSigned ? RHS.sext(BitWidth + 1) : RHS.zext(BitWidth + 1);
771-
LHS =
772-
computeForAddCarry(LHS, RHS, /*CarryZero*/ !IsCeil, /*CarryOne*/ IsCeil);
773-
LHS = LHS.extractBits(BitWidth, 1);
774-
return LHS;
769+
KnownBits ExtLHS = IsSigned ? LHS.sext(BitWidth + 1) : LHS.zext(BitWidth + 1);
770+
KnownBits ExtRHS = IsSigned ? RHS.sext(BitWidth + 1) : RHS.zext(BitWidth + 1);
771+
KnownBits Res = computeForAddCarry(ExtLHS, ExtRHS, /*CarryZero=*/!IsCeil,
772+
/*CarryOne=*/IsCeil);
773+
Res = Res.extractBits(BitWidth, 1);
774+
775+
// If we have only 1 known signbit between LHS/RHS we can try to figure
776+
// out result signbit.
777+
// NB: If we know both signbits `computeForAddCarry` gets the optimal result
778+
// already.
779+
if (IsSigned && Res.isSignUnknown() &&
780+
LHS.isSignUnknown() != RHS.isSignUnknown()) {
781+
if (LHS.isSignUnknown())
782+
std::swap(LHS, RHS);
783+
KnownBits UnsignedLHS = LHS;
784+
KnownBits UnsignedRHS = RHS;
785+
UnsignedLHS.One.clearSignBit();
786+
UnsignedLHS.Zero.setSignBit();
787+
UnsignedRHS.One.clearSignBit();
788+
UnsignedRHS.Zero.setSignBit();
789+
KnownBits ResOf =
790+
computeForAddCarry(UnsignedLHS, UnsignedRHS, /*CarryZero=*/!IsCeil,
791+
/*CarryOne=*/IsCeil);
792+
// Assuming no overflow (which is the case since we extend the addition when
793+
// taking the average):
794+
// Neg + Neg -> Neg
795+
// Neg + Pos -> Neg if the signbit doesn't overflow
796+
if (LHS.isNegative() && ResOf.isNonNegative())
797+
Res.makeNegative();
798+
// Pos + Pos -> Pos
799+
// Pos + Neg -> Pos if the signbit does overflow
800+
else if (LHS.isNonNegative() && ResOf.isNegative())
801+
Res.makeNonNegative();
802+
}
803+
return Res;
775804
}
776805

777806
KnownBits KnownBits::avgFloorS(const KnownBits &LHS, const KnownBits &RHS) {
778-
return avgCompute(LHS, RHS, /* IsCeil */ false,
779-
/* IsSigned */ true);
807+
return avgCompute(LHS, RHS, /* IsCeil=*/false,
808+
/* IsSigned=*/true);
780809
}
781810

782811
KnownBits KnownBits::avgFloorU(const KnownBits &LHS, const KnownBits &RHS) {
783-
return avgCompute(LHS, RHS, /* IsCeil */ false,
784-
/* IsSigned */ false);
812+
return avgCompute(LHS, RHS, /* IsCeil=*/false,
813+
/* IsSigned=*/false);
785814
}
786815

787816
KnownBits KnownBits::avgCeilS(const KnownBits &LHS, const KnownBits &RHS) {
788-
return avgCompute(LHS, RHS, /* IsCeil */ true,
789-
/* IsSigned */ true);
817+
return avgCompute(LHS, RHS, /* IsCeil=*/true,
818+
/* IsSigned=*/true);
790819
}
791820

792821
KnownBits KnownBits::avgCeilU(const KnownBits &LHS, const KnownBits &RHS) {
793-
return avgCompute(LHS, RHS, /* IsCeil */ true,
794-
/* IsSigned */ false);
822+
return avgCompute(LHS, RHS, /* IsCeil=*/true,
823+
/* IsSigned=*/false);
795824
}
796825

797826
KnownBits KnownBits::mul(const KnownBits &LHS, const KnownBits &RHS,

llvm/unittests/Support/KnownBitsTest.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -521,16 +521,15 @@ TEST(KnownBitsTest, BinaryExhaustive) {
521521
[](const APInt &N1, const APInt &N2) { return APIntOps::mulhu(N1, N2); },
522522
/*CheckOptimality=*/false);
523523

524-
testBinaryOpExhaustive("avgFloorS", KnownBits::avgFloorS, APIntOps::avgFloorS,
525-
/*CheckOptimality=*/false);
524+
testBinaryOpExhaustive("avgFloorS", KnownBits::avgFloorS,
525+
APIntOps::avgFloorS);
526526

527527
testBinaryOpExhaustive("avgFloorU", KnownBits::avgFloorU,
528528
APIntOps::avgFloorU);
529529

530530
testBinaryOpExhaustive("avgCeilU", KnownBits::avgCeilU, APIntOps::avgCeilU);
531531

532-
testBinaryOpExhaustive("avgCeilS", KnownBits::avgCeilS, APIntOps::avgCeilS,
533-
/*CheckOptimality=*/false);
532+
testBinaryOpExhaustive("avgCeilS", KnownBits::avgCeilS, APIntOps::avgCeilS);
534533
}
535534

536535
TEST(KnownBitsTest, UnaryExhaustive) {

0 commit comments

Comments
 (0)