diff --git a/llvm/include/llvm/Analysis/ValueTracking.h b/llvm/include/llvm/Analysis/ValueTracking.h index b730a36488780..48cc85e719421 100644 --- a/llvm/include/llvm/Analysis/ValueTracking.h +++ b/llvm/include/llvm/Analysis/ValueTracking.h @@ -102,6 +102,15 @@ LLVM_ABI void computeKnownBitsFromContext(const Value *V, KnownBits &Known, const SimplifyQuery &Q, unsigned Depth = 0); +/// Update \p Known with bits of \p V that are implied by \p Cmp. +/// Comparisons involving `trunc V` are handled specially: known +/// bits are computed for the truncated value and then extended to the bitwidth +/// of \p V. +LLVM_ABI void computeKnownBitsFromICmpCond(const Value *V, ICmpInst *Cmp, + KnownBits &Known, + const SimplifyQuery &SQ, + bool Invert); + /// Using KnownBits LHS/RHS produce the known bits for logic op (and/xor/or). LLVM_ABI KnownBits analyzeKnownBitsFromAndXorOr(const Operator *I, const KnownBits &KnownLHS, diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp index 9cb6f19b9340c..5ab5f8cfccc7f 100644 --- a/llvm/lib/Analysis/ValueTracking.cpp +++ b/llvm/lib/Analysis/ValueTracking.cpp @@ -968,9 +968,9 @@ static void computeKnownBitsFromCmp(const Value *V, CmpInst::Predicate Pred, } } -static void computeKnownBitsFromICmpCond(const Value *V, ICmpInst *Cmp, - KnownBits &Known, - const SimplifyQuery &SQ, bool Invert) { +void llvm::computeKnownBitsFromICmpCond(const Value *V, ICmpInst *Cmp, + KnownBits &Known, + const SimplifyQuery &SQ, bool Invert) { ICmpInst::Predicate Pred = Invert ? Cmp->getInversePredicate() : Cmp->getPredicate(); Value *LHS = Cmp->getOperand(0); diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp index ba5568b00441b..bf4d4170ad464 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp @@ -15,11 +15,13 @@ #include "llvm/Analysis/CmpInstAnalysis.h" #include "llvm/Analysis/FloatingPointPredicateUtils.h" #include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/ConstantRange.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Intrinsics.h" #include "llvm/IR/PatternMatch.h" +#include "llvm/Support/KnownBits.h" #include "llvm/Transforms/InstCombine/InstCombiner.h" #include "llvm/Transforms/Utils/Local.h" @@ -3575,6 +3577,41 @@ Value *InstCombinerImpl::foldAndOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, return Builder.createIsFPClass(X, IsAnd ? FPClassTest::fcNormal : ~FPClassTest::fcNormal); + if (!IsLogical && IsAnd) { + auto TryCandidate = [&](Value *X) -> Value * { + if (!X->getType()->isIntegerTy()) + return nullptr; + + Type *Ty = X->getType(); + unsigned BitWidth = Ty->getScalarSizeInBits(); + + // KnownL and KnownR hold information deduced from the LHS icmp and RHS + // icmps, respectively + KnownBits KnownL(BitWidth), KnownR(BitWidth); + + computeKnownBitsFromICmpCond(X, LHS, KnownL, Q, /*Invert=*/false); + computeKnownBitsFromICmpCond(X, RHS, KnownR, Q, /*Invert=*/false); + + KnownBits Combined = KnownL.unionWith(KnownR); + + // Avoid stomping on cases where one icmp alone determines X. Those are + // handled by more specific InstCombine folds. + if (KnownL.isConstant() || KnownR.isConstant()) + return nullptr; + + if (!Combined.isConstant()) + return nullptr; + + APInt ConstVal = Combined.getConstant(); + return Builder.CreateICmpEQ(X, ConstantInt::get(Ty, ConstVal)); + }; + + if (Value *Res = TryCandidate(LHS0)) + return Res; + if (Value *Res = TryCandidate(RHS0)) + return Res; + } + return foldAndOrOfICmpsUsingRanges(LHS, RHS, IsAnd); } diff --git a/llvm/test/Transforms/InstCombine/and-or-icmps.ll b/llvm/test/Transforms/InstCombine/and-or-icmps.ll index 290e344acb980..9d69fadfa9627 100644 --- a/llvm/test/Transforms/InstCombine/and-or-icmps.ll +++ b/llvm/test/Transforms/InstCombine/and-or-icmps.ll @@ -702,9 +702,9 @@ define i1 @PR42691_10_logical(i32 %x) { define i1 @substitute_constant_and_eq_eq(i8 %x, i8 %y) { ; CHECK-LABEL: @substitute_constant_and_eq_eq( -; CHECK-NEXT: [[C1:%.*]] = icmp eq i8 [[X:%.*]], 42 ; CHECK-NEXT: [[TMP1:%.*]] = icmp eq i8 [[Y:%.*]], 42 -; CHECK-NEXT: [[R:%.*]] = and i1 [[C1]], [[TMP1]] +; CHECK-NEXT: [[TMP2:%.*]] = icmp eq i8 [[Y1:%.*]], 42 +; CHECK-NEXT: [[R:%.*]] = and i1 [[TMP1]], [[TMP2]] ; CHECK-NEXT: ret i1 [[R]] ; %c1 = icmp eq i8 %x, 42 @@ -728,9 +728,9 @@ define i1 @substitute_constant_and_eq_eq_logical(i8 %x, i8 %y) { define i1 @substitute_constant_and_eq_eq_commute(i8 %x, i8 %y) { ; CHECK-LABEL: @substitute_constant_and_eq_eq_commute( -; CHECK-NEXT: [[C1:%.*]] = icmp eq i8 [[X:%.*]], 42 ; CHECK-NEXT: [[TMP1:%.*]] = icmp eq i8 [[Y:%.*]], 42 -; CHECK-NEXT: [[R:%.*]] = and i1 [[C1]], [[TMP1]] +; CHECK-NEXT: [[TMP2:%.*]] = icmp eq i8 [[Y1:%.*]], 42 +; CHECK-NEXT: [[R:%.*]] = and i1 [[TMP1]], [[TMP2]] ; CHECK-NEXT: ret i1 [[R]] ; %c1 = icmp eq i8 %x, 42 @@ -741,9 +741,9 @@ define i1 @substitute_constant_and_eq_eq_commute(i8 %x, i8 %y) { define i1 @substitute_constant_and_eq_eq_commute_logical(i8 %x, i8 %y) { ; CHECK-LABEL: @substitute_constant_and_eq_eq_commute_logical( -; CHECK-NEXT: [[C1:%.*]] = icmp eq i8 [[X:%.*]], 42 ; CHECK-NEXT: [[TMP1:%.*]] = icmp eq i8 [[Y:%.*]], 42 -; CHECK-NEXT: [[R:%.*]] = and i1 [[C1]], [[TMP1]] +; CHECK-NEXT: [[TMP2:%.*]] = icmp eq i8 [[Y1:%.*]], 42 +; CHECK-NEXT: [[R:%.*]] = and i1 [[TMP1]], [[TMP2]] ; CHECK-NEXT: ret i1 [[R]] ; %c1 = icmp eq i8 %x, 42 @@ -1392,12 +1392,12 @@ define i1 @bitwise_and_bitwise_and_icmps(i8 %x, i8 %y, i8 %z) { define i1 @bitwise_and_bitwise_and_icmps_comm1(i8 %x, i8 %y, i8 %z) { ; CHECK-LABEL: @bitwise_and_bitwise_and_icmps_comm1( -; CHECK-NEXT: [[C1:%.*]] = icmp eq i8 [[Y:%.*]], 42 +; CHECK-NEXT: [[TMP3:%.*]] = icmp eq i8 [[Y:%.*]], 42 ; CHECK-NEXT: [[Z_SHIFT:%.*]] = shl nuw i8 1, [[Z:%.*]] ; CHECK-NEXT: [[TMP1:%.*]] = or i8 [[Z_SHIFT]], 1 ; CHECK-NEXT: [[TMP2:%.*]] = and i8 [[X:%.*]], [[TMP1]] -; CHECK-NEXT: [[TMP3:%.*]] = icmp eq i8 [[TMP2]], [[TMP1]] -; CHECK-NEXT: [[AND2:%.*]] = and i1 [[C1]], [[TMP3]] +; CHECK-NEXT: [[TMP4:%.*]] = icmp eq i8 [[TMP2]], [[TMP1]] +; CHECK-NEXT: [[AND2:%.*]] = and i1 [[TMP3]], [[TMP4]] ; CHECK-NEXT: ret i1 [[AND2]] ; %c1 = icmp eq i8 %y, 42 @@ -3721,3 +3721,28 @@ define i1 @merge_range_check_or(i8 %a) { %and = or i1 %cmp1, %cmp2 ret i1 %and } + +; Just a very complicated way of checking if v1 == 0. +define i1 @complicated_zero_equality_test(i64 %v1) { +; CHECK-LABEL: @complicated_zero_equality_test( +; CHECK-NEXT: [[V5:%.*]] = icmp eq i64 [[V1:%.*]], 0 +; CHECK-NEXT: ret i1 [[V5]] +; + %v2 = trunc i64 %v1 to i32 + %v3 = icmp eq i32 %v2, 0 + %v4 = icmp ult i64 %v1, 4294967296 ; 2 ^ 32 + %v5 = and i1 %v4, %v3 + ret i1 %v5 +} + +define i1 @commuted_complicated_zero_equality_test(i64 %v1) { +; CHECK-LABEL: @commuted_complicated_zero_equality_test( +; CHECK-NEXT: [[V5:%.*]] = icmp eq i64 [[V1:%.*]], 0 +; CHECK-NEXT: ret i1 [[V5]] +; + %v2 = trunc i64 %v1 to i32 + %v3 = icmp ult i64 %v1, 4294967296 ; 2 ^ 32 + %v4 = icmp eq i32 %v2, 0 + %v5 = and i1 %v4, %v3 + ret i1 %v5 +}