diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp index 2d7524e8018b2..4b7793f6e010b 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp @@ -1320,71 +1320,52 @@ static Value *foldAndOrOfICmpsWithConstEq(ICmpInst *Cmp0, ICmpInst *Cmp1, Value *InstCombinerImpl::foldAndOrOfICmpsUsingRanges(ICmpInst *ICmp1, ICmpInst *ICmp2, bool IsAnd) { - CmpPredicate Pred1, Pred2; - Value *V1, *V2; - const APInt *C1, *C2; - if (!match(ICmp1, m_ICmp(Pred1, m_Value(V1), m_APInt(C1))) || - !match(ICmp2, m_ICmp(Pred2, m_Value(V2), m_APInt(C2)))) - return nullptr; - - // Look through add of a constant offset on V1, V2, or both operands. This - // allows us to interpret the V + C' < C'' range idiom into a proper range. - const APInt *Offset1 = nullptr, *Offset2 = nullptr; - if (V1 != V2) { - Value *X; - if (match(V1, m_Add(m_Value(X), m_APInt(Offset1)))) - V1 = X; - if (match(V2, m_Add(m_Value(X), m_APInt(Offset2)))) - V2 = X; - } - - // Look through and with a negative power of 2 mask on V1 or V2. This - // detects idioms of the form `(x == A) || ((x & Mask) == A + 1)` where A + 1 - // is aligned to the mask and A + 1 >= |Mask|. This pattern corresponds to a - // contiguous range check, which can be folded into an addition and compare. - // The same applies for `(x != A) && ((x & Mask) != A + 1)`. - auto AreContiguousRangePredicates = [](CmpPredicate Pred1, CmpPredicate Pred2, - bool IsAnd) { - if (IsAnd) - return Pred1 == ICmpInst::ICMP_NE && Pred2 == ICmpInst::ICMP_NE; - return Pred1 == ICmpInst::ICMP_EQ && Pred2 == ICmpInst::ICMP_EQ; - }; - const APInt *Mask1 = nullptr, *Mask2 = nullptr; - bool MatchedAnd1 = false, MatchedAnd2 = false; - if (V1 != V2 && AreContiguousRangePredicates(Pred1, Pred2, IsAnd)) { + // Return (V, CR) for a range check idiom V in CR. + auto MatchExactRangeCheck = + [](ICmpInst *ICmp) -> std::optional> { + const APInt *C; + if (!match(ICmp->getOperand(1), m_APInt(C))) + return std::nullopt; + Value *LHS = ICmp->getOperand(0); + CmpPredicate Pred = ICmp->getPredicate(); Value *X; - if (match(V1, m_OneUse(m_And(m_Value(X), m_NegatedPower2(Mask1)))) && - C1->getBitWidth() == C2->getBitWidth() && *C1 == *C2 + 1 && - C1->uge(Mask1->abs()) && C1->isPowerOf2()) { - MatchedAnd1 = true; - V1 = X; + // Match (x & NegPow2) ==/!= C + const APInt *Mask; + if (ICmpInst::isEquality(Pred) && + match(LHS, m_OneUse(m_And(m_Value(X), m_NegatedPower2(Mask)))) && + C->countr_zero() >= Mask->countr_zero()) { + ConstantRange CR(*C, *C - *Mask); + if (Pred == ICmpInst::ICMP_NE) + CR = CR.inverse(); + return std::make_pair(X, CR); } - if (match(V2, m_OneUse(m_And(m_Value(X), m_NegatedPower2(Mask2)))) && - C1->getBitWidth() == C2->getBitWidth() && *C2 == *C1 + 1 && - C2->uge(Mask2->abs()) && C2->isPowerOf2()) { - MatchedAnd2 = true; - V2 = X; - } - } + ConstantRange CR = ConstantRange::makeExactICmpRegion(Pred, *C); + // Match (add X, C1) pred C + // TODO: investigate whether we should apply the one-use check on m_AddLike. + const APInt *C1; + if (match(LHS, m_AddLike(m_Value(X), m_APInt(C1)))) + return std::make_pair(X, CR.subtract(*C1)); + return std::make_pair(LHS, CR); + }; + + auto RC1 = MatchExactRangeCheck(ICmp1); + if (!RC1) + return nullptr; + + auto RC2 = MatchExactRangeCheck(ICmp2); + if (!RC2) + return nullptr; + auto &[V1, CR1] = *RC1; + auto &[V2, CR2] = *RC2; if (V1 != V2) return nullptr; - ConstantRange CR1 = - MatchedAnd1 - ? ConstantRange(*C1, *C1 - *Mask1) - : ConstantRange::makeExactICmpRegion( - IsAnd ? ICmpInst::getInverseCmpPredicate(Pred1) : Pred1, *C1); - if (Offset1) - CR1 = CR1.subtract(*Offset1); - - ConstantRange CR2 = - MatchedAnd2 - ? ConstantRange(*C2, *C2 - *Mask2) - : ConstantRange::makeExactICmpRegion( - IsAnd ? ICmpInst::getInverseCmpPredicate(Pred2) : Pred2, *C2); - if (Offset2) - CR2 = CR2.subtract(*Offset2); + // For 'and', we use the De Morgan's Laws to simplify the implementation. + if (IsAnd) { + CR1 = CR1.inverse(); + CR2 = CR2.inverse(); + } Type *Ty = V1->getType(); Value *NewV = V1; diff --git a/llvm/test/Transforms/InstCombine/and-or-icmps.ll b/llvm/test/Transforms/InstCombine/and-or-icmps.ll index 553c7ac5af0e9..290e344acb980 100644 --- a/llvm/test/Transforms/InstCombine/and-or-icmps.ll +++ b/llvm/test/Transforms/InstCombine/and-or-icmps.ll @@ -3672,13 +3672,11 @@ define i1 @neg_or_icmp_eq_double_and_pow2(i32 %x) { ret i1 %ret } -define i1 @neg_select_icmp_eq_and_pow2(i32 %x) { -; CHECK-LABEL: @neg_select_icmp_eq_and_pow2( -; CHECK-NEXT: [[ICMP1:%.*]] = icmp sgt i32 [[X:%.*]], 127 -; CHECK-NEXT: [[AND:%.*]] = and i32 [[X]], -32 -; CHECK-NEXT: [[ICMP2:%.*]] = icmp eq i32 [[AND]], 128 -; CHECK-NEXT: [[TMP1:%.*]] = and i1 [[ICMP1]], [[ICMP2]] -; CHECK-NEXT: ret i1 [[TMP1]] +define i1 @implied_select_icmp_eq_and_pow2(i32 %x) { +; CHECK-LABEL: @implied_select_icmp_eq_and_pow2( +; CHECK-NEXT: [[TMP1:%.*]] = and i32 [[X:%.*]], -32 +; CHECK-NEXT: [[TMP2:%.*]] = icmp eq i32 [[TMP1]], 128 +; CHECK-NEXT: ret i1 [[TMP2]] ; %icmp1 = icmp sgt i32 %x, 127 %and = and i32 %x, -32 @@ -3686,3 +3684,40 @@ define i1 @neg_select_icmp_eq_and_pow2(i32 %x) { %1 = select i1 %icmp1, i1 %icmp2, i1 false ret i1 %1 } + +define i1 @implied_range_check(i8 %a) { +; CHECK-LABEL: @implied_range_check( +; CHECK-NEXT: [[MASKED:%.*]] = and i8 [[A:%.*]], -2 +; CHECK-NEXT: [[CMP2:%.*]] = icmp eq i8 [[MASKED]], 2 +; CHECK-NEXT: ret i1 [[CMP2]] +; + %cmp1 = icmp ult i8 %a, 5 + %masked = and i8 %a, -2 + %cmp2 = icmp eq i8 %masked, 2 + %and = and i1 %cmp1, %cmp2 + ret i1 %and +} + +define i1 @merge_range_check_and(i8 %a) { +; CHECK-LABEL: @merge_range_check_and( +; CHECK-NEXT: [[CMP2:%.*]] = icmp eq i8 [[MASKED:%.*]], 2 +; CHECK-NEXT: ret i1 [[CMP2]] +; + %cmp1 = icmp ult i8 %a, 3 + %masked = and i8 %a, -2 + %cmp2 = icmp eq i8 %masked, 2 + %and = and i1 %cmp1, %cmp2 + ret i1 %and +} + +define i1 @merge_range_check_or(i8 %a) { +; CHECK-LABEL: @merge_range_check_or( +; CHECK-NEXT: [[AND:%.*]] = icmp ult i8 [[A:%.*]], 4 +; CHECK-NEXT: ret i1 [[AND]] +; + %cmp1 = icmp ult i8 %a, 3 + %masked = and i8 %a, -2 + %cmp2 = icmp eq i8 %masked, 2 + %and = or i1 %cmp1, %cmp2 + ret i1 %and +} diff --git a/llvm/test/Transforms/InstCombine/icmp-range.ll b/llvm/test/Transforms/InstCombine/icmp-range.ll index 97ed552b9a6da..1970694cf9c42 100644 --- a/llvm/test/Transforms/InstCombine/icmp-range.ll +++ b/llvm/test/Transforms/InstCombine/icmp-range.ll @@ -1678,10 +1678,7 @@ define i1 @icmp_slt_sext_ne_otherwise_nofold(i32 %a) { ; tests from PR59555 define i1 @isFloat(i64 %0) { ; CHECK-LABEL: @isFloat( -; CHECK-NEXT: [[TMP2:%.*]] = icmp ugt i64 [[TMP0:%.*]], 281474976710655 -; CHECK-NEXT: [[TMP3:%.*]] = and i64 [[TMP0]], -281474976710656 -; CHECK-NEXT: [[TMP4:%.*]] = icmp ne i64 [[TMP3]], 281474976710656 -; CHECK-NEXT: [[TMP5:%.*]] = and i1 [[TMP2]], [[TMP4]] +; CHECK-NEXT: [[TMP5:%.*]] = icmp ugt i64 [[TMP0:%.*]], 562949953421311 ; CHECK-NEXT: ret i1 [[TMP5]] ; %2 = icmp ugt i64 %0, 281474976710655