diff --git a/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h b/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h index d133610ef4f75..8818843a30625 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h +++ b/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h @@ -461,6 +461,66 @@ m_c_BinaryOr(const Op0_t &Op0, const Op1_t &Op1) { return m_BinaryOr(Op0, Op1); } +/// ICmp_match is a variant of BinaryRecipe_match that also binds the comparison +/// predicate. +template struct ICmp_match { + CmpPredicate *Predicate = nullptr; + Op0_t Op0; + Op1_t Op1; + + ICmp_match(CmpPredicate &Pred, const Op0_t &Op0, const Op1_t &Op1) + : Predicate(&Pred), Op0(Op0), Op1(Op1) {} + ICmp_match(const Op0_t &Op0, const Op1_t &Op1) : Op0(Op0), Op1(Op1) {} + + bool match(const VPValue *V) const { + auto *DefR = V->getDefiningRecipe(); + return DefR && match(DefR); + } + + bool match(const VPRecipeBase *V) const { + if (m_Binary(Op0, Op1).match(V)) { + if (Predicate) + *Predicate = cast(V)->getPredicate(); + return true; + } + return false; + } +}; + +/// SpecificICmp_match is a variant of ICmp_match that matches the comparison +/// predicate, instead of binding it. +template struct SpecificICmp_match { + const CmpPredicate Predicate; + Op0_t Op0; + Op1_t Op1; + + SpecificICmp_match(CmpPredicate Pred, const Op0_t &LHS, const Op1_t &RHS) + : Predicate(Pred), Op0(LHS), Op1(RHS) {} + + bool match(const VPValue *V) const { + CmpPredicate CurrentPred; + return ICmp_match(CurrentPred, Op0, Op1).match(V) && + CmpPredicate::getMatching(CurrentPred, Predicate); + } +}; + +template +inline ICmp_match m_ICmp(const Op0_t &Op0, const Op1_t &Op1) { + return ICmp_match(Op0, Op1); +} + +template +inline ICmp_match m_ICmp(CmpPredicate &Pred, const Op0_t &Op0, + const Op1_t &Op1) { + return ICmp_match(Pred, Op0, Op1); +} + +template +inline SpecificICmp_match +m_SpecificICmp(CmpPredicate MatchPred, const Op0_t &Op0, const Op1_t &Op1) { + return SpecificICmp_match(MatchPred, Op0, Op1); +} + template using GEPLikeRecipe_match = BinaryRecipe_matchuser_begin(), - m_Binary( - m_Specific(WideIV), - m_Broadcast(m_Specific(Plan.getOrCreateBackedgeTakenCount()))))) + if (!match(*WideIV->user_begin(), + m_ICmp(m_Specific(WideIV), + m_Broadcast( + m_Specific(Plan.getOrCreateBackedgeTakenCount()))))) continue; // Update IV operands and comparison bound to use new narrower type. @@ -1419,11 +1418,9 @@ static bool isConditionTrueViaVFAndUF(VPValue *Cond, VPlan &Plan, }); auto *CanIV = Plan.getCanonicalIV(); - if (!match(Cond, m_Binary( - m_Specific(CanIV->getBackedgeValue()), - m_Specific(&Plan.getVectorTripCount()))) || - cast(Cond->getDefiningRecipe())->getPredicate() != - CmpInst::ICMP_EQ) + if (!match(Cond, m_SpecificICmp(CmpInst::ICMP_EQ, + m_Specific(CanIV->getBackedgeValue()), + m_Specific(&Plan.getVectorTripCount())))) return false; // The compare checks CanIV + VFxUF == vector trip count. The vector trip @@ -1832,7 +1829,7 @@ void VPlanTransforms::truncateToMinimalBitwidths( VPW->dropPoisonGeneratingFlags(); if (OldResSizeInBits != NewResSizeInBits && - !match(&R, m_Binary(m_VPValue(), m_VPValue()))) { + !match(&R, m_ICmp(m_VPValue(), m_VPValue()))) { // Extend result to original width. auto *Ext = new VPWidenCastRecipe(Instruction::ZExt, ResultVPV, OldResTy); @@ -1841,9 +1838,8 @@ void VPlanTransforms::truncateToMinimalBitwidths( Ext->setOperand(0, ResultVPV); assert(OldResSizeInBits > NewResSizeInBits && "Nothing to shrink?"); } else { - assert( - match(&R, m_Binary(m_VPValue(), m_VPValue())) && - "Only ICmps should not need extending the result."); + assert(match(&R, m_ICmp(m_VPValue(), m_VPValue())) && + "Only ICmps should not need extending the result."); } assert(!isa(&R) && "stores cannot be narrowed");