Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,66 @@ m_c_BinaryOr(const Op0_t &Op0, const Op1_t &Op1) {
return m_BinaryOr<Op0_t, Op1_t, /*Commutative*/ true>(Op0, Op1);
}

/// ICmp_match is a variant of BinaryRecipe_match that also binds the comparison
/// predicate.
template <typename Op0_t, typename Op1_t> struct ICmp_match {
CmpPredicate *Predicate = nullptr;
Op0_t Op0;
Op1_t Op1;

ICmp_match(CmpPredicate &Pred, const Op0_t &Op0, const Op1_t &Op1)
: Predicate(&Pred), Op0(Op0), Op1(Op1) {}
ICmp_match(const Op0_t &Op0, const Op1_t &Op1) : Op0(Op0), Op1(Op1) {}

bool match(const VPValue *V) const {
auto *DefR = V->getDefiningRecipe();
return DefR && match(DefR);
}

bool match(const VPRecipeBase *V) const {
if (m_Binary<Instruction::ICmp>(Op0, Op1).match(V)) {
if (Predicate)
*Predicate = cast<VPRecipeWithIRFlags>(V)->getPredicate();
return true;
}
return false;
}
};

/// SpecificICmp_match is a variant of ICmp_match that matches the comparison
/// predicate, instead of binding it.
template <typename Op0_t, typename Op1_t> struct SpecificICmp_match {
const CmpPredicate Predicate;
Op0_t Op0;
Op1_t Op1;

SpecificICmp_match(CmpPredicate Pred, const Op0_t &LHS, const Op1_t &RHS)
: Predicate(Pred), Op0(LHS), Op1(RHS) {}

bool match(const VPValue *V) const {
CmpPredicate CurrentPred;
return ICmp_match<Op0_t, Op1_t>(CurrentPred, Op0, Op1).match(V) &&
CmpPredicate::getMatching(CurrentPred, Predicate);
}
};

template <typename Op0_t, typename Op1_t>
inline ICmp_match<Op0_t, Op1_t> m_ICmp(const Op0_t &Op0, const Op1_t &Op1) {
return ICmp_match<Op0_t, Op1_t>(Op0, Op1);
}

template <typename Op0_t, typename Op1_t>
inline ICmp_match<Op0_t, Op1_t> m_ICmp(CmpPredicate &Pred, const Op0_t &Op0,
const Op1_t &Op1) {
return ICmp_match<Op0_t, Op1_t>(Pred, Op0, Op1);
}

template <typename Op0_t, typename Op1_t>
inline SpecificICmp_match<Op0_t, Op1_t>
m_SpecificICmp(CmpPredicate MatchPred, const Op0_t &Op0, const Op1_t &Op1) {
return SpecificICmp_match<Op0_t, Op1_t>(MatchPred, Op0, Op1);
}

template <typename Op0_t, typename Op1_t>
using GEPLikeRecipe_match =
BinaryRecipe_match<Op0_t, Op1_t, Instruction::GetElementPtr, false,
Expand Down
24 changes: 10 additions & 14 deletions llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1382,11 +1382,10 @@ static bool optimizeVectorInductionWidthForTCAndVFUF(VPlan &Plan,

// Currently only handle cases where the single user is a header-mask
// comparison with the backedge-taken-count.
if (!match(
*WideIV->user_begin(),
m_Binary<Instruction::ICmp>(
m_Specific(WideIV),
m_Broadcast(m_Specific(Plan.getOrCreateBackedgeTakenCount())))))
if (!match(*WideIV->user_begin(),
m_ICmp(m_Specific(WideIV),
m_Broadcast(
m_Specific(Plan.getOrCreateBackedgeTakenCount())))))
continue;

// Update IV operands and comparison bound to use new narrower type.
Expand Down Expand Up @@ -1419,11 +1418,9 @@ static bool isConditionTrueViaVFAndUF(VPValue *Cond, VPlan &Plan,
});

auto *CanIV = Plan.getCanonicalIV();
if (!match(Cond, m_Binary<Instruction::ICmp>(
m_Specific(CanIV->getBackedgeValue()),
m_Specific(&Plan.getVectorTripCount()))) ||
cast<VPRecipeWithIRFlags>(Cond->getDefiningRecipe())->getPredicate() !=
CmpInst::ICMP_EQ)
if (!match(Cond, m_SpecificICmp(CmpInst::ICMP_EQ,
m_Specific(CanIV->getBackedgeValue()),
m_Specific(&Plan.getVectorTripCount()))))
return false;

// The compare checks CanIV + VFxUF == vector trip count. The vector trip
Expand Down Expand Up @@ -1832,7 +1829,7 @@ void VPlanTransforms::truncateToMinimalBitwidths(
VPW->dropPoisonGeneratingFlags();

if (OldResSizeInBits != NewResSizeInBits &&
!match(&R, m_Binary<Instruction::ICmp>(m_VPValue(), m_VPValue()))) {
!match(&R, m_ICmp(m_VPValue(), m_VPValue()))) {
// Extend result to original width.
auto *Ext =
new VPWidenCastRecipe(Instruction::ZExt, ResultVPV, OldResTy);
Expand All @@ -1841,9 +1838,8 @@ void VPlanTransforms::truncateToMinimalBitwidths(
Ext->setOperand(0, ResultVPV);
assert(OldResSizeInBits > NewResSizeInBits && "Nothing to shrink?");
} else {
assert(
match(&R, m_Binary<Instruction::ICmp>(m_VPValue(), m_VPValue())) &&
"Only ICmps should not need extending the result.");
assert(match(&R, m_ICmp(m_VPValue(), m_VPValue())) &&
"Only ICmps should not need extending the result.");
}

assert(!isa<VPWidenStoreRecipe>(&R) && "stores cannot be narrowed");
Expand Down