diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp index 01b0a089aab71..1a5bbbc7dfceb 100644 --- a/llvm/lib/Analysis/InstructionSimplify.cpp +++ b/llvm/lib/Analysis/InstructionSimplify.cpp @@ -3070,6 +3070,69 @@ static Value *simplifyICmpWithConstant(CmpInst::Predicate Pred, Value *LHS, return nullptr; } +enum class MonotonicType { GreaterEq, LowerEq }; + +/// Get values V_i such that V uge V_i (GreaterEq) or V ule V_i (LowerEq). +static void getUnsignedMonotonicValues(SmallPtrSetImpl<Value *> &Res, Value *V, + MonotonicType Type, unsigned Depth = 0) { + if (!Res.insert(V).second) + return; + + // Can be increased if useful. + if (++Depth > 1) + return; + + auto *I = dyn_cast<Instruction>(V); + if (!I) + return; + + Value *X, *Y; + if (Type == MonotonicType::GreaterEq) { + if (match(I, m_Or(m_Value(X), m_Value(Y))) || + match(I, m_Intrinsic<Intrinsic::uadd_sat>(m_Value(X), m_Value(Y)))) { + getUnsignedMonotonicValues(Res, X, Type, Depth); + getUnsignedMonotonicValues(Res, Y, Type, Depth); + } + } else { + assert(Type == MonotonicType::LowerEq); + switch (I->getOpcode()) { + case Instruction::And: + getUnsignedMonotonicValues(Res, I->getOperand(0), Type, Depth); + getUnsignedMonotonicValues(Res, I->getOperand(1), Type, Depth); + break; + case Instruction::URem: + case Instruction::UDiv: + case Instruction::LShr: + getUnsignedMonotonicValues(Res, I->getOperand(0), Type, Depth); + break; + case Instruction::Call: + if (match(I, m_Intrinsic<Intrinsic::usub_sat>(m_Value(X)))) + getUnsignedMonotonicValues(Res, X, Type, Depth); + break; + default: + break; + } + } +} + +static Value *simplifyICmpUsingMonotonicValues(ICmpInst::Predicate Pred, + Value *LHS, Value *RHS) { + if (Pred != ICmpInst::ICMP_UGE && Pred != ICmpInst::ICMP_ULT) + return nullptr; + + // We have LHS uge GreaterValues and LowerValues uge RHS. If any of the + // GreaterValues and LowerValues are the same, it follows that LHS uge RHS. 
+ SmallPtrSet<Value *, 4> GreaterValues; + SmallPtrSet<Value *, 4> LowerValues; + getUnsignedMonotonicValues(GreaterValues, LHS, MonotonicType::GreaterEq); + getUnsignedMonotonicValues(LowerValues, RHS, MonotonicType::LowerEq); + for (Value *GV : GreaterValues) + if (LowerValues.contains(GV)) + return ConstantInt::getBool(getCompareTy(LHS), + Pred == ICmpInst::ICMP_UGE); + return nullptr; +} + static Value *simplifyICmpWithBinOpOnLHS(CmpInst::Predicate Pred, BinaryOperator *LBO, Value *RHS, const SimplifyQuery &Q, @@ -3079,11 +3142,6 @@ static Value *simplifyICmpWithBinOpOnLHS(CmpInst::Predicate Pred, Value *Y = nullptr; // icmp pred (or X, Y), X if (match(LBO, m_c_Or(m_Value(Y), m_Specific(RHS)))) { - if (Pred == ICmpInst::ICMP_ULT) - return getFalse(ITy); - if (Pred == ICmpInst::ICMP_UGE) - return getTrue(ITy); - if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SGE) { KnownBits RHSKnown = computeKnownBits(RHS, /* Depth */ 0, Q); KnownBits YKnown = computeKnownBits(Y, /* Depth */ 0, Q); @@ -3094,14 +3152,6 @@ static Value *simplifyICmpWithBinOpOnLHS(CmpInst::Predicate Pred, } } - // icmp pred (and X, Y), X - if (match(LBO, m_c_And(m_Value(), m_Specific(RHS)))) { - if (Pred == ICmpInst::ICMP_UGT) - return getFalse(ITy); - if (Pred == ICmpInst::ICMP_ULE) - return getTrue(ITy); - } - // icmp pred (urem X, Y), Y if (match(LBO, m_URem(m_Value(), m_Specific(RHS)))) { switch (Pred) { @@ -3132,27 +3182,6 @@ static Value *simplifyICmpWithBinOpOnLHS(CmpInst::Predicate Pred, } } - // icmp pred (urem X, Y), X - if (match(LBO, m_URem(m_Specific(RHS), m_Value()))) { - if (Pred == ICmpInst::ICMP_ULE) - return getTrue(ITy); - if (Pred == ICmpInst::ICMP_UGT) - return getFalse(ITy); - } - - // x >>u y <=u x --> true. - // x >>u y >u x --> false. - // x udiv y <=u x --> true. - // x udiv y >u x --> false. 
- if (match(LBO, m_LShr(m_Specific(RHS), m_Value())) || - match(LBO, m_UDiv(m_Specific(RHS), m_Value()))) { - // icmp pred (X op Y), X - if (Pred == ICmpInst::ICMP_UGT) - return getFalse(ITy); - if (Pred == ICmpInst::ICMP_ULE) - return getTrue(ITy); - } - // If x is nonzero: // x >>u C <u x --> true for C != 0. // x >>u C != x --> true for C != 0. @@ -3172,14 +3201,12 @@ static Value *simplifyICmpWithBinOpOnLHS(CmpInst::Predicate Pred, break; case ICmpInst::ICMP_EQ: case ICmpInst::ICMP_UGE: + case ICmpInst::ICMP_UGT: return getFalse(ITy); case ICmpInst::ICMP_NE: case ICmpInst::ICMP_ULT: - return getTrue(ITy); - case ICmpInst::ICMP_UGT: case ICmpInst::ICMP_ULE: - // UGT/ULE are handled by the more general case just above - llvm_unreachable("Unexpected UGT/ULE, should have been handled"); + return getTrue(ITy); } } } @@ -3702,13 +3729,6 @@ static Value *simplifyICmpWithIntrinsicOnLHS(CmpInst::Predicate Pred, switch (II->getIntrinsicID()) { case Intrinsic::uadd_sat: - // uadd.sat(X, Y) uge X, uadd.sat(X, Y) uge Y - if (II->getArgOperand(0) == RHS || II->getArgOperand(1) == RHS) { - if (Pred == ICmpInst::ICMP_UGE) - return ConstantInt::getTrue(getCompareTy(II)); - if (Pred == ICmpInst::ICMP_ULT) - return ConstantInt::getFalse(getCompareTy(II)); - } // uadd.sat(X, Y) uge X + Y if (match(RHS, m_c_Add(m_Specific(II->getArgOperand(0)), m_Specific(II->getArgOperand(1))))) { @@ -3719,13 +3739,6 @@ static Value *simplifyICmpWithIntrinsicOnLHS(CmpInst::Predicate Pred, } return nullptr; case Intrinsic::usub_sat: - // usub.sat(X, Y) ule X - if (II->getArgOperand(0) == RHS) { - if (Pred == ICmpInst::ICMP_ULE) - return ConstantInt::getTrue(getCompareTy(II)); - if (Pred == ICmpInst::ICMP_UGT) - return ConstantInt::getFalse(getCompareTy(II)); - } // usub.sat(X, Y) ule X - Y if (match(RHS, m_Sub(m_Specific(II->getArgOperand(0)), m_Specific(II->getArgOperand(1))))) { @@ -4030,6 +4043,12 @@ static Value *simplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, 
ICmpInst::getSwappedPredicate(Pred), RHS, LHS)) return V; + if (Value *V = simplifyICmpUsingMonotonicValues(Pred, LHS, RHS)) + return V; + if (Value *V = simplifyICmpUsingMonotonicValues( + ICmpInst::getSwappedPredicate(Pred), RHS, LHS)) + return V; + if (Value *V = simplifyICmpWithDominatingAssume(Pred, LHS, RHS, Q)) return V; diff --git a/llvm/test/Transforms/InstSimplify/compare.ll b/llvm/test/Transforms/InstSimplify/compare.ll index 21653d800dce2..5e2b15720e464 100644 --- a/llvm/test/Transforms/InstSimplify/compare.ll +++ b/llvm/test/Transforms/InstSimplify/compare.ll @@ -780,6 +780,32 @@ define i1 @lshr_nonzero_ult(i32 %x) { ret i1 %cmp } +define i1 @lshr_nonzero_ugt(i32 %x) { +; CHECK-LABEL: @lshr_nonzero_ugt( +; CHECK-NEXT: [[X_NE_0:%.*]] = icmp ne i32 [[X:%.*]], 0 +; CHECK-NEXT: call void @llvm.assume(i1 [[X_NE_0]]) +; CHECK-NEXT: ret i1 false +; + %x_ne_0 = icmp ne i32 %x, 0 + call void @llvm.assume(i1 %x_ne_0) + %lhs = lshr i32 %x, 1 + %cmp = icmp ugt i32 %lhs, %x + ret i1 %cmp +} + +define i1 @lshr_nonzero_ule(i32 %x) { +; CHECK-LABEL: @lshr_nonzero_ule( +; CHECK-NEXT: [[X_NE_0:%.*]] = icmp ne i32 [[X:%.*]], 0 +; CHECK-NEXT: call void @llvm.assume(i1 [[X_NE_0]]) +; CHECK-NEXT: ret i1 true +; + %x_ne_0 = icmp ne i32 %x, 0 + call void @llvm.assume(i1 %x_ne_0) + %lhs = lshr i32 %x, 1 + %cmp = icmp ule i32 %lhs, %x + ret i1 %cmp +} + ; Negative test - unknown shift amount define i1 @lshr_nonzero_neg_unknown(i32 %x, i32 %c) { ; CHECK-LABEL: @lshr_nonzero_neg_unknown( diff --git a/llvm/test/Transforms/InstSimplify/icmp-monotonic.ll b/llvm/test/Transforms/InstSimplify/icmp-monotonic.ll index a1daa6bd7b402..e1a4ee91bd15c 100644 --- a/llvm/test/Transforms/InstSimplify/icmp-monotonic.ll +++ b/llvm/test/Transforms/InstSimplify/icmp-monotonic.ll @@ -4,10 +4,7 @@ define i1 @lshr_or_ule(i32 %x, i32 %y, i32 %z) { ; CHECK-LABEL: define i1 @lshr_or_ule( ; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]], i32 [[Z:%.*]]) { -; CHECK-NEXT: [[OP1:%.*]] = lshr i32 [[X]], [[Y]] -; 
CHECK-NEXT: [[OP2:%.*]] = or i32 [[X]], [[Z]] -; CHECK-NEXT: [[CMP:%.*]] = icmp ule i32 [[OP1]], [[OP2]] -; CHECK-NEXT: ret i1 [[CMP]] +; CHECK-NEXT: ret i1 true ; %op1 = lshr i32 %x, %y %op2 = or i32 %x, %z @@ -18,10 +15,7 @@ define i1 @lshr_or_ule(i32 %x, i32 %y, i32 %z) { define i1 @lshr_or_uge_swapped(i32 %x, i32 %y, i32 %z) { ; CHECK-LABEL: define i1 @lshr_or_uge_swapped( ; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]], i32 [[Z:%.*]]) { -; CHECK-NEXT: [[OP1:%.*]] = lshr i32 [[X]], [[Y]] -; CHECK-NEXT: [[OP2:%.*]] = or i32 [[X]], [[Z]] -; CHECK-NEXT: [[CMP:%.*]] = icmp uge i32 [[OP2]], [[OP1]] -; CHECK-NEXT: ret i1 [[CMP]] +; CHECK-NEXT: ret i1 true ; %op1 = lshr i32 %x, %y %op2 = or i32 %x, %z @@ -32,10 +26,7 @@ define i1 @lshr_or_uge_swapped(i32 %x, i32 %y, i32 %z) { define i1 @lshr_or_ugt(i32 %x, i32 %y, i32 %z) { ; CHECK-LABEL: define i1 @lshr_or_ugt( ; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]], i32 [[Z:%.*]]) { -; CHECK-NEXT: [[OP1:%.*]] = lshr i32 [[X]], [[Y]] -; CHECK-NEXT: [[OP2:%.*]] = or i32 [[X]], [[Z]] -; CHECK-NEXT: [[CMP:%.*]] = icmp ugt i32 [[OP1]], [[OP2]] -; CHECK-NEXT: ret i1 [[CMP]] +; CHECK-NEXT: ret i1 false ; %op1 = lshr i32 %x, %y %op2 = or i32 %x, %z @@ -74,10 +65,7 @@ define i1 @lshr_or_sle_wrong_pred(i32 %x, i32 %y, i32 %z) { define i1 @lshr_or_swapped_ule(i32 %x, i32 %y, i32 %z) { ; CHECK-LABEL: define i1 @lshr_or_swapped_ule( ; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]], i32 [[Z:%.*]]) { -; CHECK-NEXT: [[OP1:%.*]] = lshr i32 [[X]], [[Y]] -; CHECK-NEXT: [[OP2:%.*]] = or i32 [[Z]], [[X]] -; CHECK-NEXT: [[CMP:%.*]] = icmp ule i32 [[OP1]], [[OP2]] -; CHECK-NEXT: ret i1 [[CMP]] +; CHECK-NEXT: ret i1 true ; %op1 = lshr i32 %x, %y %op2 = or i32 %z, %x @@ -102,10 +90,7 @@ define i1 @lshr_or_ule_invalid_swapped(i32 %x, i32 %y, i32 %z) { define i1 @and_uadd_sat_ule(i32 %x, i32 %y, i32 %z) { ; CHECK-LABEL: define i1 @and_uadd_sat_ule( ; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]], i32 [[Z:%.*]]) { -; CHECK-NEXT: [[OP1:%.*]] = and i32 [[X]], [[Y]] -; 
CHECK-NEXT: [[OP2:%.*]] = call i32 @llvm.uadd.sat.i32(i32 [[X]], i32 [[Z]]) -; CHECK-NEXT: [[CMP:%.*]] = icmp ule i32 [[OP1]], [[OP2]] -; CHECK-NEXT: ret i1 [[CMP]] +; CHECK-NEXT: ret i1 true ; %op1 = and i32 %x, %y %op2 = call i32 @llvm.uadd.sat(i32 %x, i32 %z) @@ -116,10 +101,7 @@ define i1 @and_uadd_sat_ule(i32 %x, i32 %y, i32 %z) { define i1 @urem_or_ule(i32 %x, i32 %y, i32 %z) { ; CHECK-LABEL: define i1 @urem_or_ule( ; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]], i32 [[Z:%.*]]) { -; CHECK-NEXT: [[OP1:%.*]] = urem i32 [[X]], [[Y]] -; CHECK-NEXT: [[OP2:%.*]] = or i32 [[X]], [[Z]] -; CHECK-NEXT: [[CMP:%.*]] = icmp ule i32 [[OP1]], [[OP2]] -; CHECK-NEXT: ret i1 [[CMP]] +; CHECK-NEXT: ret i1 true ; %op1 = urem i32 %x, %y %op2 = or i32 %x, %z @@ -144,10 +126,7 @@ define i1 @urem_or_ule_invalid_swapped(i32 %x, i32 %y, i32 %z) { define i1 @udiv_or_ule(i32 %x, i32 %y, i32 %z) { ; CHECK-LABEL: define i1 @udiv_or_ule( ; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]], i32 [[Z:%.*]]) { -; CHECK-NEXT: [[OP1:%.*]] = udiv i32 [[X]], [[Y]] -; CHECK-NEXT: [[OP2:%.*]] = or i32 [[X]], [[Z]] -; CHECK-NEXT: [[CMP:%.*]] = icmp ule i32 [[OP1]], [[OP2]] -; CHECK-NEXT: ret i1 [[CMP]] +; CHECK-NEXT: ret i1 true ; %op1 = udiv i32 %x, %y %op2 = or i32 %x, %z @@ -172,10 +151,7 @@ define i1 @udiv_or_ule_invalid_swapped(i32 %x, i32 %y, i32 %z) { define i1 @usub_sat_uadd_sat_ule(i32 %x, i32 %y, i32 %z) { ; CHECK-LABEL: define i1 @usub_sat_uadd_sat_ule( ; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]], i32 [[Z:%.*]]) { -; CHECK-NEXT: [[OP1:%.*]] = call i32 @llvm.usub.sat.i32(i32 [[X]], i32 [[Y]]) -; CHECK-NEXT: [[OP2:%.*]] = call i32 @llvm.uadd.sat.i32(i32 [[X]], i32 [[Z]]) -; CHECK-NEXT: [[CMP:%.*]] = icmp ule i32 [[OP1]], [[OP2]] -; CHECK-NEXT: ret i1 [[CMP]] +; CHECK-NEXT: ret i1 true ; %op1 = call i32 @llvm.usub.sat(i32 %x, i32 %y) %op2 = call i32 @llvm.uadd.sat(i32 %x, i32 %z)