
Conversation

@wermos
Contributor

@wermos wermos commented Dec 8, 2025

Resolves #170020.

I'm not exactly sure what kind of Alive2 proof is required when the optimization has to do with KnownBits stuff, so I'm copying over the Alive2 proof for the specific case discussed in the issue: https://alive2.llvm.org/ce/z/K59kAt

I followed the suggestion given here:

> I'd suggest reusing computeKnownBitsFromICmpCond to compute known bits inferred from both conditions. If the union of known bits is a constant, convert the and/or into an equality test. It would be a bit tricky to select a suitable X.

To do this, I had to expose computeKnownBitsFromICmpCond through the ValueTracking.h header.

I'm also not sure whether more tests are needed.
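
For illustration, here is a minimal standalone sketch of the KnownBits reasoning for the motivating case (the function name and the hard-coded known bits are made up for this example; only the KnownBits/APInt API is real):

#include "llvm/ADT/APInt.h"
#include "llvm/Support/KnownBits.h"
using namespace llvm;

// For i64 %x:
//   icmp ult i64 %x, 4294967296       ==> the high 32 bits of %x are zero
//   icmp eq i32 (trunc i64 %x), 0     ==> the low 32 bits of %x are zero
// The union of the two facts pins down every bit, so %x must be 0.
static bool foldsToZeroEquality() {
  KnownBits FromRange(64);
  FromRange.Zero.setHighBits(32); // deduced from the range check
  KnownBits FromTrunc(64);
  FromTrunc.Zero.setLowBits(32);  // deduced from the trunc equality
  KnownBits Combined = FromRange.unionWith(FromTrunc);
  return Combined.isConstant() && Combined.getConstant().isZero();
}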

@wermos wermos requested a review from nikic as a code owner December 8, 2025 20:08
@llvmbot llvmbot added llvm:instcombine Covers the InstCombine, InstSimplify and AggressiveInstCombine passes llvm:analysis Includes value tracking, cost tables and constant folding llvm:transforms labels Dec 8, 2025
@llvmbot
Member

llvmbot commented Dec 8, 2025

@llvm/pr-subscribers-llvm-analysis

@llvm/pr-subscribers-llvm-transforms

Author: Tirthankar Mazumder (wermos)

Changes

Addresses #170020.

I'm not exactly sure what kind of Alive2 proof is required when the optimization has to do with KnownBits stuff, so I'm copying over the Alive2 proof for the specific case discussed in the issue: https://alive2.llvm.org/ce/z/K59kAt

I followed the suggestion given here:
> I'd suggest reusing computeKnownBitsFromICmpCond to compute known bits inferred from both conditions. If the union of known bits is a constant, convert the and/or into an equality test. It would be a bit tricky to select a suitable X.

To do this, I had to expose computeKnownBitsFromICmpCond through the ValueTracking.h header.

I'm also not sure whether more tests are needed.


Full diff: https://github.com/llvm/llvm-project/pull/171195.diff

4 Files Affected:

  • (modified) llvm/include/llvm/Analysis/ValueTracking.h (+9)
  • (modified) llvm/lib/Analysis/ValueTracking.cpp (+3-3)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp (+40)
  • (modified) llvm/test/Transforms/InstCombine/and-or-icmps.ll (+34-9)
diff --git a/llvm/include/llvm/Analysis/ValueTracking.h b/llvm/include/llvm/Analysis/ValueTracking.h
index b730a36488780..48cc85e719421 100644
--- a/llvm/include/llvm/Analysis/ValueTracking.h
+++ b/llvm/include/llvm/Analysis/ValueTracking.h
@@ -102,6 +102,15 @@ LLVM_ABI void computeKnownBitsFromContext(const Value *V, KnownBits &Known,
                                           const SimplifyQuery &Q,
                                           unsigned Depth = 0);
 
+/// Update \p Known with bits of \p V that are implied by \p Cmp.
+/// Comparisons involving `trunc V` are handled specially: known
+/// bits are computed for the truncated value and then extended to the bitwidth
+/// of \p V.
+LLVM_ABI void computeKnownBitsFromICmpCond(const Value *V, ICmpInst *Cmp,
+                                           KnownBits &Known,
+                                           const SimplifyQuery &SQ,
+                                           bool Invert);
+
 /// Using KnownBits LHS/RHS produce the known bits for logic op (and/xor/or).
 LLVM_ABI KnownBits analyzeKnownBitsFromAndXorOr(const Operator *I,
                                                 const KnownBits &KnownLHS,
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 9cb6f19b9340c..5ab5f8cfccc7f 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -968,9 +968,9 @@ static void computeKnownBitsFromCmp(const Value *V, CmpInst::Predicate Pred,
   }
 }
 
-static void computeKnownBitsFromICmpCond(const Value *V, ICmpInst *Cmp,
-                                         KnownBits &Known,
-                                         const SimplifyQuery &SQ, bool Invert) {
+void llvm::computeKnownBitsFromICmpCond(const Value *V, ICmpInst *Cmp,
+                                        KnownBits &Known,
+                                        const SimplifyQuery &SQ, bool Invert) {
   ICmpInst::Predicate Pred =
       Invert ? Cmp->getInversePredicate() : Cmp->getPredicate();
   Value *LHS = Cmp->getOperand(0);
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index ba5568b00441b..fa7c66d736c28 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -15,11 +15,13 @@
 #include "llvm/Analysis/CmpInstAnalysis.h"
 #include "llvm/Analysis/FloatingPointPredicateUtils.h"
 #include "llvm/Analysis/InstructionSimplify.h"
+#include "llvm/Analysis/ValueTracking.h"
 #include "llvm/IR/ConstantRange.h"
 #include "llvm/IR/DerivedTypes.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/PatternMatch.h"
+#include "llvm/Support/KnownBits.h"
 #include "llvm/Transforms/InstCombine/InstCombiner.h"
 #include "llvm/Transforms/Utils/Local.h"
 
@@ -3376,9 +3378,13 @@ Value *InstCombinerImpl::foldAndOrOfICmps(ICmpInst *LHS, ICmpInst *RHS,
   Value *LHS0 = LHS->getOperand(0), *RHS0 = RHS->getOperand(0);
   Value *LHS1 = LHS->getOperand(1), *RHS1 = RHS->getOperand(1);
 
+  // dbgs() << "LHS0 = " << *LHS0 << "\nLHS1 = " << *LHS1 << '\n';
+  // dbgs() << "RHS0 = " << *RHS0 << "\nRHS1 = " << *RHS1 << '\n';
+
   const APInt *LHSC = nullptr, *RHSC = nullptr;
   match(LHS1, m_APInt(LHSC));
   match(RHS1, m_APInt(RHSC));
+  // dbgs() << "LHSC = " << *LHSC << "\nRHSC = " << *RHSC << '\n';
 
   // (icmp1 A, B) | (icmp2 A, B) --> (icmp3 A, B)
   // (icmp1 A, B) & (icmp2 A, B) --> (icmp3 A, B)
@@ -3575,6 +3581,40 @@ Value *InstCombinerImpl::foldAndOrOfICmps(ICmpInst *LHS, ICmpInst *RHS,
     return Builder.createIsFPClass(X, IsAnd ? FPClassTest::fcNormal
                                             : ~FPClassTest::fcNormal);
 
+  if (!IsLogical && IsAnd) {
+    auto TryCandidate = [&](Value *X) -> Value * {
+      if (!X->getType()->isIntegerTy())
+        return nullptr;
+
+      Type *Ty = X->getType();
+      unsigned BitWidth = Ty->getScalarSizeInBits();
+
+      // KnownL and KnownR hold information deduced from the LHS and RHS
+      // icmps, respectively.
+      KnownBits KnownL(BitWidth), KnownR(BitWidth);
+
+      computeKnownBitsFromICmpCond(X, LHS, KnownL, Q, /*Invert=*/false);
+      computeKnownBitsFromICmpCond(X, RHS, KnownR, Q, /*Invert=*/false);
+
+      KnownBits Combined = KnownL.unionWith(KnownR);
+
+      // Avoid stomping on cases where one icmp alone determines X. Those are handled by more specific InstCombine folds.
+      if (KnownL.isConstant() || KnownR.isConstant())
+        return nullptr;
+
+      if (!Combined.isConstant())
+        return nullptr;
+
+      APInt ConstVal = Combined.getConstant();
+      return Builder.CreateICmpEQ(X, ConstantInt::get(Ty, ConstVal));
+    };
+
+    if (Value *Res = TryCandidate(LHS0))
+      return Res;
+    if (Value *Res = TryCandidate(RHS0))
+      return Res;
+  }
+
   return foldAndOrOfICmpsUsingRanges(LHS, RHS, IsAnd);
 }
 
diff --git a/llvm/test/Transforms/InstCombine/and-or-icmps.ll b/llvm/test/Transforms/InstCombine/and-or-icmps.ll
index 290e344acb980..9d69fadfa9627 100644
--- a/llvm/test/Transforms/InstCombine/and-or-icmps.ll
+++ b/llvm/test/Transforms/InstCombine/and-or-icmps.ll
@@ -702,9 +702,9 @@ define i1 @PR42691_10_logical(i32 %x) {
 
 define i1 @substitute_constant_and_eq_eq(i8 %x, i8 %y) {
 ; CHECK-LABEL: @substitute_constant_and_eq_eq(
-; CHECK-NEXT:    [[C1:%.*]] = icmp eq i8 [[X:%.*]], 42
 ; CHECK-NEXT:    [[TMP1:%.*]] = icmp eq i8 [[Y:%.*]], 42
-; CHECK-NEXT:    [[R:%.*]] = and i1 [[C1]], [[TMP1]]
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp eq i8 [[Y1:%.*]], 42
+; CHECK-NEXT:    [[R:%.*]] = and i1 [[TMP1]], [[TMP2]]
 ; CHECK-NEXT:    ret i1 [[R]]
 ;
   %c1 = icmp eq i8 %x, 42
@@ -728,9 +728,9 @@ define i1 @substitute_constant_and_eq_eq_logical(i8 %x, i8 %y) {
 
 define i1 @substitute_constant_and_eq_eq_commute(i8 %x, i8 %y) {
 ; CHECK-LABEL: @substitute_constant_and_eq_eq_commute(
-; CHECK-NEXT:    [[C1:%.*]] = icmp eq i8 [[X:%.*]], 42
 ; CHECK-NEXT:    [[TMP1:%.*]] = icmp eq i8 [[Y:%.*]], 42
-; CHECK-NEXT:    [[R:%.*]] = and i1 [[C1]], [[TMP1]]
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp eq i8 [[Y1:%.*]], 42
+; CHECK-NEXT:    [[R:%.*]] = and i1 [[TMP1]], [[TMP2]]
 ; CHECK-NEXT:    ret i1 [[R]]
 ;
   %c1 = icmp eq i8 %x, 42
@@ -741,9 +741,9 @@ define i1 @substitute_constant_and_eq_eq_commute(i8 %x, i8 %y) {
 
 define i1 @substitute_constant_and_eq_eq_commute_logical(i8 %x, i8 %y) {
 ; CHECK-LABEL: @substitute_constant_and_eq_eq_commute_logical(
-; CHECK-NEXT:    [[C1:%.*]] = icmp eq i8 [[X:%.*]], 42
 ; CHECK-NEXT:    [[TMP1:%.*]] = icmp eq i8 [[Y:%.*]], 42
-; CHECK-NEXT:    [[R:%.*]] = and i1 [[C1]], [[TMP1]]
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp eq i8 [[Y1:%.*]], 42
+; CHECK-NEXT:    [[R:%.*]] = and i1 [[TMP1]], [[TMP2]]
 ; CHECK-NEXT:    ret i1 [[R]]
 ;
   %c1 = icmp eq i8 %x, 42
@@ -1392,12 +1392,12 @@ define i1 @bitwise_and_bitwise_and_icmps(i8 %x, i8 %y, i8 %z) {
 
 define i1 @bitwise_and_bitwise_and_icmps_comm1(i8 %x, i8 %y, i8 %z) {
 ; CHECK-LABEL: @bitwise_and_bitwise_and_icmps_comm1(
-; CHECK-NEXT:    [[C1:%.*]] = icmp eq i8 [[Y:%.*]], 42
+; CHECK-NEXT:    [[TMP3:%.*]] = icmp eq i8 [[Y:%.*]], 42
 ; CHECK-NEXT:    [[Z_SHIFT:%.*]] = shl nuw i8 1, [[Z:%.*]]
 ; CHECK-NEXT:    [[TMP1:%.*]] = or i8 [[Z_SHIFT]], 1
 ; CHECK-NEXT:    [[TMP2:%.*]] = and i8 [[X:%.*]], [[TMP1]]
-; CHECK-NEXT:    [[TMP3:%.*]] = icmp eq i8 [[TMP2]], [[TMP1]]
-; CHECK-NEXT:    [[AND2:%.*]] = and i1 [[C1]], [[TMP3]]
+; CHECK-NEXT:    [[TMP4:%.*]] = icmp eq i8 [[TMP2]], [[TMP1]]
+; CHECK-NEXT:    [[AND2:%.*]] = and i1 [[TMP3]], [[TMP4]]
 ; CHECK-NEXT:    ret i1 [[AND2]]
 ;
   %c1 = icmp eq i8 %y, 42
@@ -3721,3 +3721,28 @@ define i1 @merge_range_check_or(i8 %a) {
   %and = or i1 %cmp1, %cmp2
   ret i1 %and
 }
+
+; Just a very complicated way of checking if v1 == 0.
+define i1 @complicated_zero_equality_test(i64 %v1) {
+; CHECK-LABEL: @complicated_zero_equality_test(
+; CHECK-NEXT:    [[V5:%.*]] = icmp eq i64 [[V1:%.*]], 0
+; CHECK-NEXT:    ret i1 [[V5]]
+;
+  %v2 = trunc i64 %v1 to i32
+  %v3 = icmp eq i32 %v2, 0
+  %v4 = icmp ult i64 %v1, 4294967296 ; 2 ^ 32
+  %v5 = and i1 %v4, %v3
+  ret i1 %v5
+}
+
+define i1 @commuted_complicated_zero_equality_test(i64 %v1) {
+; CHECK-LABEL: @commuted_complicated_zero_equality_test(
+; CHECK-NEXT:    [[V5:%.*]] = icmp eq i64 [[V1:%.*]], 0
+; CHECK-NEXT:    ret i1 [[V5]]
+;
+  %v2 = trunc i64 %v1 to i32
+  %v3 = icmp ult i64 %v1, 4294967296 ; 2 ^ 32
+  %v4 = icmp eq i32 %v2, 0
+  %v5 = and i1 %v4, %v3
+  ret i1 %v5
+}

@github-actions

github-actions bot commented Dec 8, 2025

⚠️ We detected that you are using a GitHub private e-mail address to contribute to the repo.
Please turn off Keep my email addresses private setting in your account.
See LLVM Developer Policy and LLVM Discourse for more information.

@github-actions

github-actions bot commented Dec 8, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@wermos
Contributor Author

wermos commented Dec 8, 2025

I've addressed the email thing as well.

@wermos
Contributor Author

wermos commented Dec 8, 2025

Ping @dtcxzyw for review.

Comment on lines +3581 to +3607
    auto TryCandidate = [&](Value *X) -> Value * {
      if (!X->getType()->isIntegerTy())
        return nullptr;

      Type *Ty = X->getType();
      unsigned BitWidth = Ty->getScalarSizeInBits();

      // KnownL and KnownR hold information deduced from the LHS and RHS
      // icmps, respectively.
      KnownBits KnownL(BitWidth), KnownR(BitWidth);

      computeKnownBitsFromICmpCond(X, LHS, KnownL, Q, /*Invert=*/false);
      computeKnownBitsFromICmpCond(X, RHS, KnownR, Q, /*Invert=*/false);

      KnownBits Combined = KnownL.unionWith(KnownR);

      // Avoid stomping on cases where one icmp alone determines X. Those are
      // handled by more specific InstCombine folds.
      if (KnownL.isConstant() || KnownR.isConstant())
        return nullptr;

      if (!Combined.isConstant())
        return nullptr;

      APInt ConstVal = Combined.getConstant();
      return Builder.CreateICmpEQ(X, ConstantInt::get(Ty, ConstVal));
    };
Contributor Author


Is this part fine as a lambda, or should it be its own helper function? If it should be its own helper function, then what should it be called?

@andjo403
Contributor

andjo403 commented Dec 8, 2025

If the trunc is replaced by an and, this is already folded; see https://alive2.llvm.org/ce/z/Whfa65
I assume it is handled by:

/// Handle (icmp(A & B) ==/!= C) &/| (icmp(A & D) ==/!= E).
/// Return the pattern classes (from MaskedICmpType) for the left hand side and
/// the right hand side as a pair.
/// LHS and RHS are the left hand side and the right hand side ICmps and PredL
/// and PredR are their predicates, respectively.
static std::optional<std::pair<unsigned, unsigned>>
getMaskedTypeForICmpPair(Value *&A, Value *&B, Value *&C, Value *&D, Value *&E,
                         Value *LHS, Value *RHS, ICmpInst::Predicate &PredL,
                         ICmpInst::Predicate &PredR) {
Maybe we can add support for treating the trunc as a mask there? I don't know what the best solution is; I only thought of that function when I looked at the described fold.
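
For what it's worth, the underlying equivalence is that a trunc-based equality is just a masked equality on the wide value: icmp eq (trunc iN %x to iM), C is the same test as icmp eq (and iN %x, low-M-bits-mask), zext(C). A minimal sketch using only APInt (the struct and helper are made up for illustration, not existing LLVM API):

#include "llvm/ADT/APInt.h"
using namespace llvm;

struct MaskedEqTest {
  APInt Mask; // which bits of the wide value are tested
  APInt C;    // the value those bits must equal
};

// icmp eq (trunc iN %x to iM), C  <=>  icmp eq (%x & low-M-bits), zext(C to iN)
static MaskedEqTest truncEqAsMaskedEq(unsigned WideBits, const APInt &NarrowC) {
  return {APInt::getLowBitsSet(WideBits, NarrowC.getBitWidth()),
          NarrowC.zext(WideBits)};
}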

@zwuis
Contributor

zwuis commented Dec 9, 2025

> Addresses #170020.

You could use the "Resolves #170020" format so that GitHub automatically closes the linked issue when this PR is merged.

@dtcxzyw
Member

dtcxzyw commented Dec 9, 2025

> Maybe we can add support for treating the trunc as a mask there?

Oh yes, we can simply handle this in decomposeBitTestICmp. The following patch works, but I think it should be moved into decomposeBitTestICmp.

diff --git a/llvm/lib/Analysis/CmpInstAnalysis.cpp b/llvm/lib/Analysis/CmpInstAnalysis.cpp
index a1a79e5685f8..362e7d0508a5 100644
--- a/llvm/lib/Analysis/CmpInstAnalysis.cpp
+++ b/llvm/lib/Analysis/CmpInstAnalysis.cpp
@@ -193,9 +193,25 @@ std::optional<DecomposedBitTest> llvm::decomposeBitTest(Value *Cond,
     // Don't allow pointers. Splat vectors are fine.
     if (!ICmp->getOperand(0)->getType()->isIntOrIntVectorTy())
       return std::nullopt;
-    return decomposeBitTestICmp(ICmp->getOperand(0), ICmp->getOperand(1),
+    if (auto Res = decomposeBitTestICmp(ICmp->getOperand(0), ICmp->getOperand(1),
                                 ICmp->getPredicate(), LookThruTrunc,
-                                AllowNonZeroC, DecomposeAnd);
+                                AllowNonZeroC, DecomposeAnd)) {
+      return Res;
+    }
+
+    CmpPredicate Pred;
+    Value *X;
+    const APInt *RHSC;
+    if (LookThruTrunc && match(Cond, m_ICmp(Pred, m_Trunc(m_Value(X)), 
+                               m_APInt(RHSC))) && (AllowNonZeroC || RHSC->isZero()) && ICmpInst::isEquality(Pred)) {
+      DecomposedBitTest Result;
+      Result.X = X;
+      unsigned BitWidth = X->getType()->getScalarSizeInBits();
+      Result.Mask = APInt::getLowBitsSet(BitWidth, RHSC->getBitWidth());
+      Result.C = RHSC->zext(BitWidth);
+      Result.Pred = Pred;
+      return Result;
+    }
   }
   Value *X;
   if (Cond->getType()->isIntOrIntVectorTy(1) &&

@wermos
Contributor Author

wermos commented Dec 10, 2025

Alright, I'll work on moving the patch you shared into decomposeBitTestICmp.


Development

Successfully merging this pull request may close these issues.

Missed Optimization: Fold (x < 2^32) & (trunc(x to i32) == 0) into x == 0

5 participants