diff --git a/llvm/include/llvm/Analysis/CmpInstAnalysis.h b/llvm/include/llvm/Analysis/CmpInstAnalysis.h index aeda58ac7535d..adcac0360632d 100644 --- a/llvm/include/llvm/Analysis/CmpInstAnalysis.h +++ b/llvm/include/llvm/Analysis/CmpInstAnalysis.h @@ -102,17 +102,21 @@ namespace llvm { }; /// Decompose an icmp into the form ((X & Mask) pred C) if possible. - /// Unless \p AllowNonZeroC is true, C will always be 0. + /// Unless \p AllowNonZeroC is true, C will always be 0. If \p + /// DecomposeAnd is specified, then, for equality predicates, this will + /// decompose bitmasking via `and`. std::optional decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred, - bool LookThroughTrunc = true, - bool AllowNonZeroC = false); + bool LookThroughTrunc = true, bool AllowNonZeroC = false, + bool DecomposeAnd = false); /// Decompose an icmp into the form ((X & Mask) pred C) if /// possible. Unless \p AllowNonZeroC is true, C will always be 0. + /// If \p DecomposeAnd is specified, then, for equality predicates, this + /// will decompose bitmasking via `and`. std::optional decomposeBitTest(Value *Cond, bool LookThroughTrunc = true, - bool AllowNonZeroC = false); + bool AllowNonZeroC = false, bool DecomposeAnd = false); } // end namespace llvm diff --git a/llvm/lib/Analysis/CmpInstAnalysis.cpp b/llvm/lib/Analysis/CmpInstAnalysis.cpp index 5c0d1dd1c74b0..a1a79e5685f80 100644 --- a/llvm/lib/Analysis/CmpInstAnalysis.cpp +++ b/llvm/lib/Analysis/CmpInstAnalysis.cpp @@ -75,11 +75,13 @@ Constant *llvm::getPredForFCmpCode(unsigned Code, Type *OpTy, std::optional llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred, - bool LookThruTrunc, bool AllowNonZeroC) { + bool LookThruTrunc, bool AllowNonZeroC, + bool DecomposeAnd) { using namespace PatternMatch; const APInt *OrigC; - if (!ICmpInst::isRelational(Pred) || !match(RHS, m_APIntAllowPoison(OrigC))) + if ((ICmpInst::isEquality(Pred) && !DecomposeAnd) || + !match(RHS, m_APIntAllowPoison(OrigC))) return std::nullopt; bool Inverted = false; @@ -128,7 +130,7 @@ llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred, return std::nullopt; } - case ICmpInst::ICMP_ULT: + case ICmpInst::ICMP_ULT: { // X -llvm::decomposeBitTest(Value *Cond, bool LookThruTrunc, bool AllowNonZeroC) { +std::optional llvm::decomposeBitTest(Value *Cond, + bool LookThruTrunc, + bool AllowNonZeroC, + bool DecomposeAnd) { using namespace PatternMatch; if (auto *ICmp = dyn_cast(Cond)) { // Don't allow pointers. Splat vectors are fine. @@ -175,7 +195,7 @@ llvm::decomposeBitTest(Value *Cond, bool LookThruTrunc, bool AllowNonZeroC) { return std::nullopt; return decomposeBitTestICmp(ICmp->getOperand(0), ICmp->getOperand(1), ICmp->getPredicate(), LookThruTrunc, - AllowNonZeroC); + AllowNonZeroC, DecomposeAnd); } Value *X; if (Cond->getType()->isIntOrIntVectorTy(1) && diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp index 19bf81137aab7..57d4459228d03 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp @@ -875,22 +875,16 @@ static Value *foldSignedTruncationCheck(ICmpInst *ICmp0, ICmpInst *ICmp1, APInt &UnsetBitsMask) -> bool { CmpPredicate Pred = ICmp->getPredicate(); // Can it be decomposed into icmp eq (X & Mask), 0 ? - auto Res = - llvm::decomposeBitTestICmp(ICmp->getOperand(0), ICmp->getOperand(1), - Pred, /*LookThroughTrunc=*/false); + auto Res = llvm::decomposeBitTestICmp( + ICmp->getOperand(0), ICmp->getOperand(1), Pred, + /*LookThroughTrunc=*/false, /*AllowNonZeroC=*/false, + /*DecomposeAnd=*/true); if (Res && Res->Pred == ICmpInst::ICMP_EQ) { X = Res->X; UnsetBitsMask = Res->Mask; return true; } - // Is it icmp eq (X & Mask), 0 already? - const APInt *Mask; - if (match(ICmp, m_ICmp(Pred, m_And(m_Value(X), m_APInt(Mask)), m_Zero())) && - Pred == ICmpInst::ICMP_EQ) { - UnsetBitsMask = *Mask; - return true; - } return false; }; diff --git a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp index eacc67cd5a475..8ddfe46e7aae0 100644 --- a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp +++ b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp @@ -2761,14 +2761,11 @@ static bool detectShiftUntilBitTestIdiom(Loop *CurLoop, Value *&BaseX, m_LoopInvariant(m_Shl(m_One(), m_Value(BitPos)), CurLoop)))); }; - auto MatchConstantBitMask = [&]() { - return ICmpInst::isEquality(Pred) && match(CmpRHS, m_Zero()) && - match(CmpLHS, m_And(m_Value(CurrX), - m_CombineAnd(m_Value(BitMask), m_Power2()))) && - (BitPos = ConstantExpr::getExactLogBase2(cast(BitMask))); - }; + auto MatchDecomposableConstantBitMask = [&]() { - auto Res = llvm::decomposeBitTestICmp(CmpLHS, CmpRHS, Pred); + auto Res = llvm::decomposeBitTestICmp( + CmpLHS, CmpRHS, Pred, /*LookThroughTrunc=*/true, + /*AllowNonZeroC=*/false, /*DecomposeAnd=*/true); if (Res && Res->Mask.isPowerOf2()) { assert(ICmpInst::isEquality(Res->Pred)); Pred = Res->Pred; @@ -2780,8 +2777,7 @@ static bool detectShiftUntilBitTestIdiom(Loop *CurLoop, Value *&BaseX, return false; }; - if (!MatchVariableBitMask() && !MatchConstantBitMask() && - !MatchDecomposableConstantBitMask()) { + if (!MatchVariableBitMask() && !MatchDecomposableConstantBitMask()) { LLVM_DEBUG(dbgs() << DEBUG_TYPE " Bad backedge comparison.\n"); return false; }