diff --git a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h index 164b46b54890b..07a482d4f166a 100644 --- a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h +++ b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h @@ -182,6 +182,12 @@ m_scev_PtrToInt(const Op0_t &Op0) { return SCEVUnaryExpr_match(Op0); } +template +inline SCEVUnaryExpr_match +m_scev_Trunc(const Op0_t &Op0) { + return m_scev_Unary(Op0); +} + /// Match a binary SCEV. template (Op); - const SCEVZeroExtendExpr *ZExt = dyn_cast(Op); - if (!SExt && !ZExt) - return nullptr; - const SCEVTruncateExpr *Trunc = - SExt ? dyn_cast(SExt->getOperand()) - : dyn_cast(ZExt->getOperand()); - if (!Trunc) - return nullptr; - const SCEV *X = Trunc->getOperand(); - if (X != SymbolicPHI) - return nullptr; - Signed = SExt != nullptr; - return Trunc->getType(); + if (match(Op, m_scev_SExt(m_scev_Trunc(m_scev_Specific(SymbolicPHI))))) { + Signed = true; + return cast(Op)->getOperand()->getType(); + } + if (match(Op, m_scev_ZExt(m_scev_Trunc(m_scev_Specific(SymbolicPHI))))) { + Signed = false; + return cast(Op)->getOperand()->getType(); + } + return nullptr; } static const Loop *isIntegerLoopHeaderPHI(const PHINode *PN, LoopInfo &LI) { @@ -15428,20 +15423,18 @@ bool ScalarEvolution::matchURem(const SCEV *Expr, const SCEV *&LHS, // Try to match 'zext (trunc A to iB) to iY', which is used // for URem with constant power-of-2 second operands. Make sure the size of // the operand A matches the size of the whole expressions. - if (const auto *ZExt = dyn_cast(Expr)) - if (const auto *Trunc = dyn_cast(ZExt->getOperand(0))) { - LHS = Trunc->getOperand(); - // Bail out if the type of the LHS is larger than the type of the - // expression for now. - if (getTypeSizeInBits(LHS->getType()) > - getTypeSizeInBits(Expr->getType())) - return false; - if (LHS->getType() != Expr->getType()) - LHS = getZeroExtendExpr(LHS, Expr->getType()); - RHS = getConstant(APInt(getTypeSizeInBits(Expr->getType()), 1) - << getTypeSizeInBits(Trunc->getType())); - return true; - } + if (match(Expr, m_scev_ZExt(m_scev_Trunc(m_SCEV(LHS))))) { + Type *TruncTy = cast(Expr)->getOperand()->getType(); + // Bail out if the type of the LHS is larger than the type of the + // expression for now. + if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(Expr->getType())) + return false; + if (LHS->getType() != Expr->getType()) + LHS = getZeroExtendExpr(LHS, Expr->getType()); + RHS = getConstant(APInt(getTypeSizeInBits(Expr->getType()), 1) + << getTypeSizeInBits(TruncTy)); + return true; + } const auto *Add = dyn_cast(Expr); if (Add == nullptr || Add->getNumOperands() != 2) return false;