Contributor

@goldsteinn goldsteinn commented Sep 27, 2024

  • [KnownBits] Make avg{Ceil,Floor}S optimal

All we were missing was the sign bit of the result when we know the
incoming sign bit of only one of LHS or RHS.

Since the base addition in the average is done with one extra bit of
width, it cannot overflow, so we can work out the result's sign from
the magnitudes of the inputs. If the negative operand has the larger
magnitude the result is negative, and vice versa for the positive case.
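
The magnitude argument can be checked by brute force on 8-bit values. The following is a minimal sketch, not code from the patch: `avgFloorS8` and `signFollowsLargerMagnitude` are hypothetical helper names, and only the floor variant is covered. The sum is formed in `int`, which stands in for the extra bit of width.

```cpp
#include <cassert>
#include <cstdint>

// Floor average of two i8 values. The sum is formed in `int`, which plays
// the role of the "extra bit" of width, so the addition cannot overflow.
static int8_t avgFloorS8(int8_t A, int8_t B) {
  int Sum = int(A) + int(B);
  // Round toward negative infinity without relying on shifts of negatives.
  int Floor = Sum >= 0 ? Sum / 2 : -((-Sum + 1) / 2);
  assert(Floor >= INT8_MIN && Floor <= INT8_MAX);
  return static_cast<int8_t>(Floor);
}

// For every (negative, non-negative) pair, the result is negative exactly
// when the negative operand has the strictly larger magnitude.
static bool signFollowsLargerMagnitude() {
  for (int Neg = -128; Neg < 0; ++Neg) {
    for (int Pos = 0; Pos <= 127; ++Pos) {
      bool NegIsLarger = -Neg > Pos;
      bool ResultIsNegative =
          avgFloorS8(static_cast<int8_t>(Neg), static_cast<int8_t>(Pos)) < 0;
      if (NegIsLarger != ResultIsNegative)
        return false;
    }
  }
  return true;
}
```

Calling `signFollowsLargerMagnitude()` from a test or `main` exercises all 128 x 128 mixed-sign pairs. The ceil variant rounds the other way, so the boundary case where the sum is -1 behaves differently there.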

@llvmbot llvmbot added the llvm:support and llvm:analysis (Includes value tracking, cost tables and constant folding) labels Sep 27, 2024
Member

llvmbot commented Sep 27, 2024

@llvm/pr-subscribers-llvm-analysis

@llvm/pr-subscribers-llvm-support

Author: None (goldsteinn)

Changes
  • [KnownBits] Mark avg{Ceil,Floor}U as optimal in exhaustive test; NFC
  • [KnownBits] Make avg{Ceil,Floor}S optimal
  • [KnownBits] Make {s,u}{add,sub}_sat optimal

Full diff: https://github.com/llvm/llvm-project/pull/110329.diff

3 Files Affected:

  • (modified) llvm/lib/Support/KnownBits.cpp (+113-72)
  • (modified) llvm/test/Analysis/ValueTracking/knownbits-sat-addsub.ll (+1-8)
  • (modified) llvm/unittests/Support/KnownBitsTest.cpp (+12-18)
diff --git a/llvm/lib/Support/KnownBits.cpp b/llvm/lib/Support/KnownBits.cpp
index 6863c5c0af5dca..89e4b108b83e55 100644
--- a/llvm/lib/Support/KnownBits.cpp
+++ b/llvm/lib/Support/KnownBits.cpp
@@ -610,28 +610,78 @@ static KnownBits computeForSatAddSub(bool Add, bool Signed,
                                      const KnownBits &RHS) {
   // We don't see NSW even for sadd/ssub as we want to check if the result has
   // signed overflow.
-  KnownBits Res =
-      KnownBits::computeForAddSub(Add, /*NSW=*/false, /*NUW=*/false, LHS, RHS);
-  unsigned BitWidth = Res.getBitWidth();
-  auto SignBitKnown = [&](const KnownBits &K) {
-    return K.Zero[BitWidth - 1] || K.One[BitWidth - 1];
-  };
-  std::optional<bool> Overflow;
+  unsigned BitWidth = LHS.getBitWidth();
 
+  std::optional<bool> Overflow;
+  // Even if we can't entirely rule out overflow, we may be able to rule out
+  // overflow in one direction. This allows us to potentially keep some of the
+  // add/sub bits. I.e if we can't overflow in the positive direction we won't
+  // clamp to INT_MAX so we can keep low 0s from the add/sub result.
+  bool MayNegClamp = true;
+  bool MayPosClamp = true;
   if (Signed) {
-    // If we can actually detect overflow do so. Otherwise leave Overflow as
-    // nullopt (we assume it may have happened).
-    if (SignBitKnown(LHS) && SignBitKnown(RHS) && SignBitKnown(Res)) {
+    // Easy cases we can rule out any overflow.
+    if (Add && ((LHS.isNegative() && RHS.isNonNegative()) ||
+                (LHS.isNonNegative() && RHS.isNegative())))
+      Overflow = false;
+    else if (!Add && (((LHS.isNegative() && RHS.isNegative()) ||
+                       (LHS.isNonNegative() && RHS.isNonNegative()))))
+      Overflow = false;
+    else {
+      // Check if we may overflow. If we can't rule out overflow then check if
+      // we can rule out a direction at least.
+      KnownBits UnsignedLHS = LHS;
+      KnownBits UnsignedRHS = RHS;
+      UnsignedLHS.One.clearSignBit();
+      UnsignedLHS.Zero.setSignBit();
+      UnsignedRHS.One.clearSignBit();
+      UnsignedRHS.Zero.setSignBit();
+      KnownBits Res =
+          KnownBits::computeForAddSub(Add, /*NSW=*/false,
+                                      /*NUW=*/false, UnsignedLHS, UnsignedRHS);
       if (Add) {
-        // sadd.sat
-        Overflow = (LHS.isNonNegative() == RHS.isNonNegative() &&
-                    Res.isNonNegative() != LHS.isNonNegative());
+        if (Res.isNegative()) {
+          // Only overflow scenario is Pos + Pos.
+          MayNegClamp = false;
+          // Pos + Pos will overflow with extra signbit.
+          if (LHS.isNonNegative() && RHS.isNonNegative())
+            Overflow = true;
+        } else if (Res.isNonNegative()) {
+          // Only overflow scenario is Neg + Neg
+          MayPosClamp = false;
+          // Neg + Neg will overflow without extra signbit.
+          if (LHS.isNegative() && RHS.isNegative())
+            Overflow = true;
+        }
+        // We will never clamp to the opposite sign of N-bit result.
+        if (LHS.isNegative() || RHS.isNegative())
+          MayPosClamp = false;
+        if (LHS.isNonNegative() || RHS.isNonNegative())
+          MayNegClamp = false;
       } else {
-        // ssub.sat
-        Overflow = (LHS.isNonNegative() != RHS.isNonNegative() &&
-                    Res.isNonNegative() != LHS.isNonNegative());
+        if (Res.isNegative()) {
+          // Only overflow scenario is Neg - Pos.
+          MayPosClamp = false;
+          // Neg - Pos will overflow with extra signbit.
+          if (LHS.isNegative() && RHS.isNonNegative())
+            Overflow = true;
+        } else if (Res.isNonNegative()) {
+          // Only overflow scenario is Pos - Neg.
+          MayNegClamp = false;
+          // Pos - Neg will overflow without extra signbit.
+          if (LHS.isNonNegative() && RHS.isNegative())
+            Overflow = true;
+        }
+        // We will never clamp to the opposite sign of N-bit result.
+        if (LHS.isNegative() || RHS.isNonNegative())
+          MayPosClamp = false;
+        if (LHS.isNonNegative() || RHS.isNegative())
+          MayNegClamp = false;
       }
     }
+    // If we have ruled out all clamping, we will never overflow.
+    if (!MayNegClamp && !MayPosClamp)
+      Overflow = false;
   } else if (Add) {
     // uadd.sat
     bool Of;
@@ -656,52 +706,8 @@ static KnownBits computeForSatAddSub(bool Add, bool Signed,
     }
   }
 
-  if (Signed) {
-    if (Add) {
-      if (LHS.isNonNegative() && RHS.isNonNegative()) {
-        // Pos + Pos -> Pos
-        Res.One.clearSignBit();
-        Res.Zero.setSignBit();
-      }
-      if (LHS.isNegative() && RHS.isNegative()) {
-        // Neg + Neg -> Neg
-        Res.One.setSignBit();
-        Res.Zero.clearSignBit();
-      }
-    } else {
-      if (LHS.isNegative() && RHS.isNonNegative()) {
-        // Neg - Pos -> Neg
-        Res.One.setSignBit();
-        Res.Zero.clearSignBit();
-      } else if (LHS.isNonNegative() && RHS.isNegative()) {
-        // Pos - Neg -> Pos
-        Res.One.clearSignBit();
-        Res.Zero.setSignBit();
-      }
-    }
-  } else {
-    // Add: Leading ones of either operand are preserved.
-    // Sub: Leading zeros of LHS and leading ones of RHS are preserved
-    // as leading zeros in the result.
-    unsigned LeadingKnown;
-    if (Add)
-      LeadingKnown =
-          std::max(LHS.countMinLeadingOnes(), RHS.countMinLeadingOnes());
-    else
-      LeadingKnown =
-          std::max(LHS.countMinLeadingZeros(), RHS.countMinLeadingOnes());
-
-    // We select between the operation result and all-ones/zero
-    // respectively, so we can preserve known ones/zeros.
-    APInt Mask = APInt::getHighBitsSet(BitWidth, LeadingKnown);
-    if (Add) {
-      Res.One |= Mask;
-      Res.Zero &= ~Mask;
-    } else {
-      Res.Zero |= Mask;
-      Res.One &= ~Mask;
-    }
-  }
+  KnownBits Res = KnownBits::computeForAddSub(Add, /*NSW=*/Signed,
+                                              /*NUW*/ !Signed, LHS, RHS);
 
   if (Overflow) {
     // We know whether or not we overflowed.
@@ -714,8 +720,9 @@ static KnownBits computeForSatAddSub(bool Add, bool Signed,
     APInt C;
     if (Signed) {
       // sadd.sat / ssub.sat
-      assert(SignBitKnown(LHS) &&
-             "We somehow know overflow without knowing input sign");
+      assert((LHS.isNegative() ||
+              LHS.isNonNegative()) &&
+             "We somehow know overflow without knowing input sign");
       C = LHS.isNegative() ? APInt::getSignedMinValue(BitWidth)
                            : APInt::getSignedMaxValue(BitWidth);
     } else if (Add) {
@@ -735,8 +742,10 @@ static KnownBits computeForSatAddSub(bool Add, bool Signed,
   if (Signed) {
     // sadd.sat/ssub.sat
     // We can keep our information about the sign bits.
-    Res.Zero.clearLowBits(BitWidth - 1);
-    Res.One.clearLowBits(BitWidth - 1);
+    if (MayPosClamp)
+      Res.Zero.clearLowBits(BitWidth - 1);
+    if (MayNegClamp)
+      Res.One.clearLowBits(BitWidth - 1);
   } else if (Add) {
     // uadd.sat
     // We need to clear all the known zeros as we can only use the leading ones.
@@ -766,12 +775,44 @@ KnownBits KnownBits::usub_sat(const KnownBits &LHS, const KnownBits &RHS) {
 static KnownBits avgCompute(KnownBits LHS, KnownBits RHS, bool IsCeil,
                             bool IsSigned) {
   unsigned BitWidth = LHS.getBitWidth();
-  LHS = IsSigned ? LHS.sext(BitWidth + 1) : LHS.zext(BitWidth + 1);
-  RHS = IsSigned ? RHS.sext(BitWidth + 1) : RHS.zext(BitWidth + 1);
-  LHS =
-      computeForAddCarry(LHS, RHS, /*CarryZero*/ !IsCeil, /*CarryOne*/ IsCeil);
-  LHS = LHS.extractBits(BitWidth, 1);
-  return LHS;
+  KnownBits ExtLHS = IsSigned ? LHS.sext(BitWidth + 1) : LHS.zext(BitWidth + 1);
+  KnownBits ExtRHS = IsSigned ? RHS.sext(BitWidth + 1) : RHS.zext(BitWidth + 1);
+  KnownBits Res = computeForAddCarry(ExtLHS, ExtRHS, /*CarryZero*/ !IsCeil,
+                                     /*CarryOne*/ IsCeil);
+  Res = Res.extractBits(BitWidth, 1);
+
+  auto SignBitKnown = [BitWidth](KnownBits KB) {
+    return KB.One[BitWidth - 1] || KB.Zero[BitWidth - 1];
+  };
+  // If we have only 1 known signbit between LHS/RHS we can try to figure
+  // out result signbit.
+  // NB: If we know both signbits `computeForAddCarry` gets the optimal result
+  // already.
+  if (IsSigned && !SignBitKnown(Res) &&
+      SignBitKnown(LHS) != SignBitKnown(RHS)) {
+    if (SignBitKnown(RHS))
+      std::swap(LHS, RHS);
+    KnownBits UnsignedLHS = LHS;
+    KnownBits UnsignedRHS = RHS;
+    UnsignedLHS.One.clearSignBit();
+    UnsignedLHS.Zero.setSignBit();
+    UnsignedRHS.One.clearSignBit();
+    UnsignedRHS.Zero.setSignBit();
+    KnownBits ResOf =
+        computeForAddCarry(UnsignedLHS, UnsignedRHS, /*CarryZero*/ !IsCeil,
+                           /*CarryOne*/ IsCeil);
+    // Assuming no overflow (which is the case since we extend the addition when
+    // taking the average):
+    // Neg + Neg -> Neg
+    // Neg + Pos -> Neg if the signbit doesn't overflow
+    if (LHS.isNegative() && ResOf.isNonNegative())
+      Res.makeNegative();
+    // Pos + Pos -> Pos
+    // Pos + Neg -> Pos if the signbit does overflow
+    else if (LHS.isNonNegative() && ResOf.isNegative())
+      Res.makeNonNegative();
+  }
+  return Res;
 }
 
 KnownBits KnownBits::avgFloorS(const KnownBits &LHS, const KnownBits &RHS) {
diff --git a/llvm/test/Analysis/ValueTracking/knownbits-sat-addsub.ll b/llvm/test/Analysis/ValueTracking/knownbits-sat-addsub.ll
index c2926eaffa58c5..f9618e1ddbc022 100644
--- a/llvm/test/Analysis/ValueTracking/knownbits-sat-addsub.ll
+++ b/llvm/test/Analysis/ValueTracking/knownbits-sat-addsub.ll
@@ -142,14 +142,7 @@ define i1 @ssub_sat_low_bits(i8 %x, i8 %y) {
 
 define i1 @ssub_sat_fail_may_overflow(i8 %x, i8 %y) {
 ; CHECK-LABEL: @ssub_sat_fail_may_overflow(
-; CHECK-NEXT:    [[XX:%.*]] = and i8 [[X:%.*]], 15
-; CHECK-NEXT:    [[YY:%.*]] = and i8 [[Y:%.*]], 15
-; CHECK-NEXT:    [[LHS:%.*]] = or i8 [[XX]], 1
-; CHECK-NEXT:    [[RHS:%.*]] = and i8 [[YY]], -2
-; CHECK-NEXT:    [[EXP:%.*]] = call i8 @llvm.ssub.sat.i8(i8 [[LHS]], i8 [[RHS]])
-; CHECK-NEXT:    [[AND:%.*]] = and i8 [[EXP]], 1
-; CHECK-NEXT:    [[R:%.*]] = icmp eq i8 [[AND]], 0
-; CHECK-NEXT:    ret i1 [[R]]
+; CHECK-NEXT:    ret i1 false
 ;
   %xx = and i8 %x, 15
   %yy = and i8 %y, 15
diff --git a/llvm/unittests/Support/KnownBitsTest.cpp b/llvm/unittests/Support/KnownBitsTest.cpp
index b6e16f809ea779..e8be41519c5b95 100644
--- a/llvm/unittests/Support/KnownBitsTest.cpp
+++ b/llvm/unittests/Support/KnownBitsTest.cpp
@@ -306,14 +306,14 @@ TEST(KnownBitsTest, BinaryExhaustive) {
         return KnownBits::add(Known1, Known2);
       },
       [](const APInt &N1, const APInt &N2) { return N1 + N2; },
-      /*CheckOptimality=*/false);
+      /*CheckOptimality=*/true);
   testBinaryOpExhaustive(
       "sub",
       [](const KnownBits &Known1, const KnownBits &Known2) {
         return KnownBits::sub(Known1, Known2);
       },
       [](const APInt &N1, const APInt &N2) { return N1 - N2; },
-      /*CheckOptimality=*/false);
+      /*CheckOptimality=*/true);
   testBinaryOpExhaustive("umax", KnownBits::umax, APIntOps::umax);
   testBinaryOpExhaustive("umin", KnownBits::umin, APIntOps::umin);
   testBinaryOpExhaustive("smax", KnownBits::smax, APIntOps::smax);
@@ -385,26 +385,22 @@ TEST(KnownBitsTest, BinaryExhaustive) {
       "sadd_sat", KnownBits::sadd_sat,
       [](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
         return N1.sadd_sat(N2);
-      },
-      /*CheckOptimality=*/false);
+      });
   testBinaryOpExhaustive(
       "uadd_sat", KnownBits::uadd_sat,
       [](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
         return N1.uadd_sat(N2);
-      },
-      /*CheckOptimality=*/false);
+      });
   testBinaryOpExhaustive(
       "ssub_sat", KnownBits::ssub_sat,
       [](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
         return N1.ssub_sat(N2);
-      },
-      /*CheckOptimality=*/false);
+      });
   testBinaryOpExhaustive(
       "usub_sat", KnownBits::usub_sat,
       [](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
         return N1.usub_sat(N2);
-      },
-      /*CheckOptimality=*/false);
+      });
   testBinaryOpExhaustive(
       "shl",
       [](const KnownBits &Known1, const KnownBits &Known2) {
@@ -523,17 +519,15 @@ TEST(KnownBitsTest, BinaryExhaustive) {
       [](const APInt &N1, const APInt &N2) { return APIntOps::mulhu(N1, N2); },
       /*CheckOptimality=*/false);
 
-  testBinaryOpExhaustive("avgFloorS", KnownBits::avgFloorS, APIntOps::avgFloorS,
-                         false);
+  testBinaryOpExhaustive("avgFloorS", KnownBits::avgFloorS,
+                         APIntOps::avgFloorS);
 
-  testBinaryOpExhaustive("avgFloorU", KnownBits::avgFloorU, APIntOps::avgFloorU,
-                         false);
+  testBinaryOpExhaustive("avgFloorU", KnownBits::avgFloorU,
+                         APIntOps::avgFloorU);
 
-  testBinaryOpExhaustive("avgCeilU", KnownBits::avgCeilU, APIntOps::avgCeilU,
-                         false);
+  testBinaryOpExhaustive("avgCeilU", KnownBits::avgCeilU, APIntOps::avgCeilU);
 
-  testBinaryOpExhaustive("avgCeilS", KnownBits::avgCeilS, APIntOps::avgCeilS,
-                         false);
+  testBinaryOpExhaustive("avgCeilS", KnownBits::avgCeilS, APIntOps::avgCeilS);
 }
 
 TEST(KnownBitsTest, UnaryExhaustive) {
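
The updated `ssub_sat_fail_may_overflow` check in the diff above folds the whole function to `ret i1 false`. The following is a minimal sketch that replays the removed CHECK lines on plain integers to confirm the fold is sound. `ssubSat8` and `alwaysFoldsToFalse` are hypothetical names, and `ssubSat8` is a hand-written stand-in for `llvm.ssub.sat.i8`, not the LLVM implementation.

```cpp
#include <cassert>
#include <cstdint>

// Saturating signed subtraction on i8, computed in `int` and then clamped.
static int8_t ssubSat8(int8_t A, int8_t B) {
  int D = int(A) - int(B);
  if (D > INT8_MAX)
    return INT8_MAX;
  if (D < INT8_MIN)
    return INT8_MIN;
  return static_cast<int8_t>(D);
}

// Mirrors the IR of the test for every i8 pair (x, y):
//   lhs = (x & 15) | 1    -> odd, in [1, 15]
//   rhs = (y & 15) & -2   -> even, in [0, 14]
//   r   = (ssub.sat(lhs, rhs) & 1) == 0
// Both operands are non-negative, so the subtraction never saturates and
// the low bit of the (odd) difference stays 1, making r always false.
static bool alwaysFoldsToFalse() {
  for (int X = -128; X <= 127; ++X) {
    for (int Y = -128; Y <= 127; ++Y) {
      int8_t LHS = static_cast<int8_t>((X & 15) | 1);
      int8_t RHS = static_cast<int8_t>((Y & 15) & -2);
      int8_t Exp = ssubSat8(LHS, RHS);
      bool R = (Exp & 1) == 0;
      if (R)
        return false;
    }
  }
  return true;
}
```

This matches the "easy case" in the patch: for a subtraction where both operands are known non-negative, overflow is ruled out up front, so the low known bits of the plain subtraction can be kept.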

@goldsteinn goldsteinn changed the title from "goldsteinn/avg add sub opt knownbits" to "[KnownBits] Make avg{Ceil,Floor}S and {s,u}{add,sub}_sat optimal" Sep 27, 2024

github-actions bot commented Sep 27, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

Collaborator

I was going to suggest using KnownBits.makeNonNegative() - but for some reason it doesn't clear the sign bit in One (i.e. One.clearSignBit())?

Contributor Author

It's been an annoyance for a while tbh, I'll add a patch with a forceNegative and forceNonNegative API

Contributor Author

See: #110389
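
The distinction raised in this thread can be illustrated with a toy model. This is a minimal sketch under stated assumptions: `ToyKnownBits` is a hypothetical 8-bit stand-in and not `llvm::KnownBits`, and `forceNonNegative` here is only a guess at what a "force" variant might do, not the API from the referenced patch.

```cpp
#include <cassert>
#include <cstdint>

constexpr uint8_t ToySignBit = 0x80;

// Toy 8-bit known-bits record: a bit set in Zero is known to be 0, a bit
// set in One is known to be 1. Both set for the same bit is a conflict.
struct ToyKnownBits {
  uint8_t Zero = 0;
  uint8_t One = 0;

  bool hasConflict() const { return (Zero & One) != 0; }
  bool isNonNegative() const { return (Zero & ToySignBit) != 0; }
  bool isNegative() const { return (One & ToySignBit) != 0; }

  // Only records "sign bit is 0". If the sign bit was already known to be
  // 1, the record ends up in a conflicting state.
  void makeNonNegative() { Zero |= ToySignBit; }

  // Hypothetical "force" variant: records "sign bit is 0" and also drops
  // any earlier claim that it is 1, so no conflict can be left behind.
  void forceNonNegative() {
    Zero |= ToySignBit;
    One &= static_cast<uint8_t>(~ToySignBit);
  }
};
```

In this model, calling `makeNonNegative()` on a value whose sign bit is already in `One` leaves `hasConflict()` true, which is why the diff clears and sets the sign bit by hand via `One.clearSignBit()` and `Zero.setSignBit()`.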

Collaborator

Is this related to the patch?

Contributor Author

No, this was just motivated by the fact that I noticed we had incorrectly labelled a lot of the tests. I pushed an NFC change to correctly label them.

Contributor

jayfoad commented Sep 30, 2024

There's a lot to understand here. Could you split it into two patches, for avg and add/sub sat? Or are they inextricably intertwined?

@goldsteinn
Contributor Author

> There's a lot to understand here. Could you split it into two patches, for avg and add/sub sat? Or are they inextricably intertwined?

They are not. Will split.

@goldsteinn goldsteinn force-pushed the goldsteinn/avg-add-sub-opt-knownbits branch from de85c14 to d037382 Compare September 30, 2024 18:45
@goldsteinn goldsteinn changed the title from "[KnownBits] Make avg{Ceil,Floor}S and {s,u}{add,sub}_sat optimal" to "[KnownBits] Make avg{Ceil,Floor}S optimal" Sep 30, 2024
@goldsteinn
Contributor Author

> There's a lot to understand here. Could you split it into two patches, for avg and add/sub sat? Or are they inextricably intertwined?

Split, just avg for now. I hadn't realized there was already an outstanding issue assigned for addsub, so I'm going to give the original assignee a few days to take over the addsub patches if they want to. If they don't, I'll post the addsub stuff after that.

Collaborator

@RKSimon RKSimon left a comment

LGTM

@goldsteinn
Contributor Author

@jayfoad has a much cleaner impl

@goldsteinn goldsteinn closed this Oct 1, 2024