diff --git a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h index 871028de3163c..68198ec9b8a9f 100644 --- a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h +++ b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h @@ -95,6 +95,10 @@ inline bind_ty m_scev_Add(const SCEVAddExpr *&V) { return V; } +inline bind_ty m_scev_Mul(const SCEVMulExpr *&V) { + return V; +} + /// Match a specified const SCEV *. struct specificscev_ty { const SCEV *Expr; @@ -284,14 +288,10 @@ template struct SCEVURem_match { << SE.getTypeSizeInBits(TruncTy)); return Op0.match(LHS) && Op1.match(RHS); } - const auto *Add = dyn_cast(Expr); - if (Add == nullptr || Add->getNumOperands() != 2) - return false; - - const SCEV *A = Add->getOperand(1); - const auto *Mul = dyn_cast(Add->getOperand(0)); - if (Mul == nullptr) + const SCEV *A; + const SCEVMulExpr *Mul; + if (!SCEVPatternMatch::match(Expr, m_scev_Add(m_scev_Mul(Mul), m_SCEV(A)))) return false; const auto MatchURemWithDivisor = [&](const SCEV *B) { diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp index a64b93d541943..442b9d1e8a30e 100644 --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -4623,17 +4623,11 @@ const SCEV *ScalarEvolution::getNegativeSCEV(const SCEV *V, /// If Expr computes ~A, return A else return nullptr static const SCEV *MatchNotExpr(const SCEV *Expr) { - const SCEVAddExpr *Add = dyn_cast(Expr); - if (!Add || Add->getNumOperands() != 2 || - !Add->getOperand(0)->isAllOnesValue()) - return nullptr; - - const SCEVMulExpr *AddRHS = dyn_cast(Add->getOperand(1)); - if (!AddRHS || AddRHS->getNumOperands() != 2 || - !AddRHS->getOperand(0)->isAllOnesValue()) - return nullptr; - - return AddRHS->getOperand(1); + const SCEV *MulOp; + if (match(Expr, m_scev_Add(m_scev_AllOnes(), + m_scev_Mul(m_scev_AllOnes(), m_SCEV(MulOp))))) + return MulOp; + return nullptr; } /// Return a SCEV corresponding to ~V = -1-V @@ -12220,12 +12214,11 @@ ScalarEvolution::computeConstantDifference(const SCEV *More, const SCEV *Less) { // Try to match a common constant multiply. auto MatchConstMul = [](const SCEV *S) -> std::optional> { - auto *M = dyn_cast(S); - if (!M || M->getNumOperands() != 2 || - !isa(M->getOperand(0))) - return std::nullopt; - return { - {M->getOperand(1), cast(M->getOperand(0))->getAPInt()}}; + const APInt *C; + const SCEV *Op; + if (match(S, m_scev_Mul(m_scev_APInt(C), m_SCEV(Op)))) + return {{Op, *C}}; + return std::nullopt; }; if (auto MatchedMore = MatchConstMul(More)) { if (auto MatchedLess = MatchConstMul(Less)) {