Skip to content

Commit 769c4f1

Browse files
committed
match select icmp
1 parent 923566a commit 769c4f1

File tree

1 file changed

+86
-8
lines changed

1 file changed

+86
-8
lines changed

llvm/include/llvm/CodeGen/SDPatternMatch.h

Lines changed: 86 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -542,6 +542,80 @@ struct BinaryOpc_match {
542542
}
543543
};
544544

545+
template <typename LHS_P, typename RHS_P, typename Pred_t,
546+
bool Commutable = false, bool ExcludeChain = false>
547+
struct MaxMin_match {
548+
using PredType = Pred_t;
549+
LHS_P LHS;
550+
RHS_P RHS;
551+
552+
MaxMin_match(const LHS_P &L, const RHS_P &R) : LHS(L), RHS(R) {}
553+
554+
template <typename MatchContext>
555+
bool match(const MatchContext &Ctx, SDValue N) {
556+
if (sd_context_match(N, Ctx, m_Opc(ISD::SELECT))) {
557+
EffectiveOperands<ExcludeChain> EO_SELECT(N, Ctx);
558+
assert(EO_SELECT.Size == 3);
559+
SDValue Cond = N->getOperand(EO_SELECT.FirstIndex);
560+
SDValue TrueValue = N->getOperand(EO_SELECT.FirstIndex + 1);
561+
SDValue FalseValue = N->getOperand(EO_SELECT.FirstIndex + 2);
562+
563+
if (sd_context_match(Cond, Ctx, m_Opc(ISD::SETCC))) {
564+
EffectiveOperands<ExcludeChain> EO_SETCC(Cond, Ctx);
565+
assert(EO_SETCC.Size == 3);
566+
SDValue L = Cond->getOperand(EO_SETCC.FirstIndex);
567+
SDValue R = Cond->getOperand(EO_SETCC.FirstIndex + 1);
568+
CondCodeSDNode *CondNode =
569+
cast<CondCodeSDNode>(Cond->getOperand(EO_SETCC.FirstIndex + 2));
570+
571+
if ((TrueValue != L || FalseValue != R) &&
572+
(TrueValue != R || FalseValue != L)) {
573+
return false;
574+
}
575+
576+
ISD::CondCode Cond =
577+
TrueValue == L ? CondNode->get()
578+
: getSetCCInverse(CondNode->get(), L.getValueType());
579+
if (!Pred_t::match(Cond)) {
580+
return false;
581+
}
582+
return (LHS.match(Ctx, L) && RHS.match(Ctx, R)) ||
583+
(Commutable && LHS.match(Ctx, R) && RHS.match(Ctx, L));
584+
}
585+
}
586+
587+
return false;
588+
}
589+
};
590+
591+
// Helper class for identifying signed max predicates.
592+
struct smax_pred_ty {
593+
static bool match(ISD::CondCode Cond) {
594+
return Cond == ISD::CondCode::SETGT || Cond == ISD::CondCode::SETGE;
595+
}
596+
};
597+
598+
// Helper class for identifying unsigned max predicates.
599+
struct umax_pred_ty {
600+
static bool match(ISD::CondCode Cond) {
601+
return Cond == ISD::CondCode::SETUGT || Cond == ISD::CondCode::SETUGE;
602+
}
603+
};
604+
605+
// Helper class for identifying signed min predicates.
606+
struct smin_pred_ty {
607+
static bool match(ISD::CondCode Cond) {
608+
return Cond == ISD::CondCode::SETLT || Cond == ISD::CondCode::SETLE;
609+
}
610+
};
611+
612+
// Helper class for identifying unsigned min predicates.
613+
struct umin_pred_ty {
614+
static bool match(ISD::CondCode Cond) {
615+
return Cond == ISD::CondCode::SETULT || Cond == ISD::CondCode::SETULE;
616+
}
617+
};
618+
545619
template <typename LHS, typename RHS>
546620
inline BinaryOpc_match<LHS, RHS> m_BinOp(unsigned Opc, const LHS &L,
547621
const RHS &R) {
@@ -609,23 +683,27 @@ inline BinaryOpc_match<LHS, RHS, true> m_Xor(const LHS &L, const RHS &R) {
609683
}
610684

611685
template <typename LHS, typename RHS>
612-
inline BinaryOpc_match<LHS, RHS, true> m_SMin(const LHS &L, const RHS &R) {
613-
return BinaryOpc_match<LHS, RHS, true>(ISD::SMIN, L, R);
686+
inline auto m_SMin(const LHS &L, const RHS &R) {
687+
return m_AnyOf(BinaryOpc_match<LHS, RHS, true>(ISD::SMIN, L, R),
688+
MaxMin_match<LHS, RHS, smin_pred_ty, true>(L, R));
614689
}
615690

616691
template <typename LHS, typename RHS>
617-
inline BinaryOpc_match<LHS, RHS, true> m_SMax(const LHS &L, const RHS &R) {
618-
return BinaryOpc_match<LHS, RHS, true>(ISD::SMAX, L, R);
692+
inline auto m_SMax(const LHS &L, const RHS &R) {
693+
return m_AnyOf(BinaryOpc_match<LHS, RHS, true>(ISD::SMAX, L, R),
694+
MaxMin_match<LHS, RHS, smax_pred_ty, true>(L, R));
619695
}
620696

621697
template <typename LHS, typename RHS>
622-
inline BinaryOpc_match<LHS, RHS, true> m_UMin(const LHS &L, const RHS &R) {
623-
return BinaryOpc_match<LHS, RHS, true>(ISD::UMIN, L, R);
698+
inline auto m_UMin(const LHS &L, const RHS &R) {
699+
return m_AnyOf(BinaryOpc_match<LHS, RHS, true>(ISD::UMIN, L, R),
700+
MaxMin_match<LHS, RHS, umin_pred_ty, true>(L, R));
624701
}
625702

626703
template <typename LHS, typename RHS>
627-
inline BinaryOpc_match<LHS, RHS, true> m_UMax(const LHS &L, const RHS &R) {
628-
return BinaryOpc_match<LHS, RHS, true>(ISD::UMAX, L, R);
704+
inline auto m_UMax(const LHS &L, const RHS &R) {
705+
return m_AnyOf(BinaryOpc_match<LHS, RHS, true>(ISD::UMAX, L, R),
706+
MaxMin_match<LHS, RHS, umax_pred_ty, true>(L, R));
629707
}
630708

631709
template <typename LHS, typename RHS>

0 commit comments

Comments
 (0)