Skip to content

Conversation

@jayfoad
Copy link
Contributor

@jayfoad jayfoad commented Oct 1, 2024

Rewrite the signed functions in terms of the unsigned ones which are
already optimal.

Rewrite the signed functions in terms of the unsigned ones which are
already optimal.
@llvmbot
Copy link
Member

llvmbot commented Oct 1, 2024

@llvm/pr-subscribers-llvm-support

Author: Jay Foad (jayfoad)

Changes

Rewrite the signed functions in terms of the unsigned ones which are
already optimal.


Full diff: https://github.com/llvm/llvm-project/pull/110688.diff

3 Files Affected:

  • (modified) llvm/include/llvm/Support/KnownBits.h (+3)
  • (modified) llvm/lib/Support/KnownBits.cpp (+17-22)
  • (modified) llvm/unittests/Support/KnownBitsTest.cpp (+3-4)
diff --git a/llvm/include/llvm/Support/KnownBits.h b/llvm/include/llvm/Support/KnownBits.h
index e4ec202f36aae0..a4b554fa2a0b72 100644
--- a/llvm/include/llvm/Support/KnownBits.h
+++ b/llvm/include/llvm/Support/KnownBits.h
@@ -29,6 +29,9 @@ struct KnownBits {
   KnownBits(APInt Zero, APInt One)
       : Zero(std::move(Zero)), One(std::move(One)) {}
 
+  // Flip the range of values: [-0x80000000, 0x7FFFFFFF] <-> [0, 0xFFFFFFFF]
+  static KnownBits flipSignBit(const KnownBits &Val);
+
 public:
   // Default construct Zero and One.
   KnownBits() = default;
diff --git a/llvm/lib/Support/KnownBits.cpp b/llvm/lib/Support/KnownBits.cpp
index 6863c5c0af5dca..a7801aa950cad3 100644
--- a/llvm/lib/Support/KnownBits.cpp
+++ b/llvm/lib/Support/KnownBits.cpp
@@ -18,6 +18,15 @@
 
 using namespace llvm;
 
+KnownBits KnownBits::flipSignBit(const KnownBits &Val) {
+  unsigned SignBitPosition = Val.getBitWidth() - 1;
+  APInt Zero = Val.Zero;
+  APInt One = Val.One;
+  Zero.setBitVal(SignBitPosition, Val.One[SignBitPosition]);
+  One.setBitVal(SignBitPosition, Val.Zero[SignBitPosition]);
+  return KnownBits(Zero, One);
+}
+
 static KnownBits computeForAddCarry(const KnownBits &LHS, const KnownBits &RHS,
                                     bool CarryZero, bool CarryOne) {
 
@@ -200,16 +209,7 @@ KnownBits KnownBits::umin(const KnownBits &LHS, const KnownBits &RHS) {
 }
 
 KnownBits KnownBits::smax(const KnownBits &LHS, const KnownBits &RHS) {
-  // Flip the range of values: [-0x80000000, 0x7FFFFFFF] <-> [0, 0xFFFFFFFF]
-  auto Flip = [](const KnownBits &Val) {
-    unsigned SignBitPosition = Val.getBitWidth() - 1;
-    APInt Zero = Val.Zero;
-    APInt One = Val.One;
-    Zero.setBitVal(SignBitPosition, Val.One[SignBitPosition]);
-    One.setBitVal(SignBitPosition, Val.Zero[SignBitPosition]);
-    return KnownBits(Zero, One);
-  };
-  return Flip(umax(Flip(LHS), Flip(RHS)));
+  return flipSignBit(umax(flipSignBit(LHS), flipSignBit(RHS)));
 }
 
 KnownBits KnownBits::smin(const KnownBits &LHS, const KnownBits &RHS) {
@@ -763,11 +763,10 @@ KnownBits KnownBits::usub_sat(const KnownBits &LHS, const KnownBits &RHS) {
   return computeForSatAddSub(/*Add*/ false, /*Signed*/ false, LHS, RHS);
 }
 
-static KnownBits avgCompute(KnownBits LHS, KnownBits RHS, bool IsCeil,
-                            bool IsSigned) {
+static KnownBits avgComputeU(KnownBits LHS, KnownBits RHS, bool IsCeil) {
   unsigned BitWidth = LHS.getBitWidth();
-  LHS = IsSigned ? LHS.sext(BitWidth + 1) : LHS.zext(BitWidth + 1);
-  RHS = IsSigned ? RHS.sext(BitWidth + 1) : RHS.zext(BitWidth + 1);
+  LHS = LHS.zext(BitWidth + 1);
+  RHS = RHS.zext(BitWidth + 1);
   LHS =
       computeForAddCarry(LHS, RHS, /*CarryZero*/ !IsCeil, /*CarryOne*/ IsCeil);
   LHS = LHS.extractBits(BitWidth, 1);
@@ -775,23 +774,19 @@ static KnownBits avgCompute(KnownBits LHS, KnownBits RHS, bool IsCeil,
 }
 
 KnownBits KnownBits::avgFloorS(const KnownBits &LHS, const KnownBits &RHS) {
-  return avgCompute(LHS, RHS, /* IsCeil */ false,
-                    /* IsSigned */ true);
+  return flipSignBit(avgFloorU(flipSignBit(LHS), flipSignBit(RHS)));
 }
 
 KnownBits KnownBits::avgFloorU(const KnownBits &LHS, const KnownBits &RHS) {
-  return avgCompute(LHS, RHS, /* IsCeil */ false,
-                    /* IsSigned */ false);
+  return avgComputeU(LHS, RHS, /* IsCeil */ false);
 }
 
 KnownBits KnownBits::avgCeilS(const KnownBits &LHS, const KnownBits &RHS) {
-  return avgCompute(LHS, RHS, /* IsCeil */ true,
-                    /* IsSigned */ true);
+  return flipSignBit(avgCeilU(flipSignBit(LHS), flipSignBit(RHS)));
 }
 
 KnownBits KnownBits::avgCeilU(const KnownBits &LHS, const KnownBits &RHS) {
-  return avgCompute(LHS, RHS, /* IsCeil */ true,
-                    /* IsSigned */ false);
+  return avgComputeU(LHS, RHS, /* IsCeil */ true);
 }
 
 KnownBits KnownBits::mul(const KnownBits &LHS, const KnownBits &RHS,
diff --git a/llvm/unittests/Support/KnownBitsTest.cpp b/llvm/unittests/Support/KnownBitsTest.cpp
index b701757aed5eb9..551c1a8107494b 100644
--- a/llvm/unittests/Support/KnownBitsTest.cpp
+++ b/llvm/unittests/Support/KnownBitsTest.cpp
@@ -521,16 +521,15 @@ TEST(KnownBitsTest, BinaryExhaustive) {
       [](const APInt &N1, const APInt &N2) { return APIntOps::mulhu(N1, N2); },
       /*CheckOptimality=*/false);
 
-  testBinaryOpExhaustive("avgFloorS", KnownBits::avgFloorS, APIntOps::avgFloorS,
-                         /*CheckOptimality=*/false);
+  testBinaryOpExhaustive("avgFloorS", KnownBits::avgFloorS,
+                         APIntOps::avgFloorS);
 
   testBinaryOpExhaustive("avgFloorU", KnownBits::avgFloorU,
                          APIntOps::avgFloorU);
 
   testBinaryOpExhaustive("avgCeilU", KnownBits::avgCeilU, APIntOps::avgCeilU);
 
-  testBinaryOpExhaustive("avgCeilS", KnownBits::avgCeilS, APIntOps::avgCeilS,
-                         /*CheckOptimality=*/false);
+  testBinaryOpExhaustive("avgCeilS", KnownBits::avgCeilS, APIntOps::avgCeilS);
 }
 
 TEST(KnownBitsTest, UnaryExhaustive) {

@jayfoad jayfoad requested review from RKSimon and goldsteinn October 1, 2024 15:28
@jayfoad
Copy link
Contributor Author

jayfoad commented Oct 1, 2024

This is an alternative to #110329. I think it's simpler. I don't know if it's any slower or faster.

KnownBits KnownBits::avgCeilU(const KnownBits &LHS, const KnownBits &RHS) {
return avgCompute(LHS, RHS, /* IsCeil */ true,
/* IsSigned */ false);
return avgComputeU(LHS, RHS, /* IsCeil */ true);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit, /*IsCeil=*/ for the comments.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

Copy link
Contributor

@goldsteinn goldsteinn left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, this is much cleaner

Copy link
Collaborator

@RKSimon RKSimon left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM - cheers

@jayfoad jayfoad merged commit 5cabf15 into llvm:main Oct 1, 2024
4 checks passed
@jayfoad jayfoad deleted the knownbits-avgs-optimal branch October 1, 2024 18:34
Sterling-Augustine pushed a commit to Sterling-Augustine/llvm-project that referenced this pull request Oct 3, 2024
Rewrite the signed functions in terms of the unsigned ones which are
already optimal.
@jayfoad jayfoad restored the knownbits-avgs-optimal branch January 14, 2025 08:58
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants