Skip to content

Commit 53497e7

Browse files
author
Aidan
committed
refactored to more closley follow the IR PatternMatch implementation
1 parent 860bf81 commit 53497e7

File tree

2 files changed

+29
-41
lines changed

2 files changed

+29
-41
lines changed

llvm/include/llvm/CodeGen/SDPatternMatch.h

Lines changed: 25 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -548,49 +548,39 @@ struct BinaryOpc_match {
548548
};
549549

550550
/// Matching while capturing mask
551-
template <typename T0, typename T1> struct SDShuffle_match {
551+
template <typename T0, typename T1, typename T2> struct SDShuffle_match {
552552
T0 Op1;
553553
T1 Op2;
554+
T2 Mask;
554555

555-
ArrayRef<int> &CapturedMask;
556-
557-
// capturing mask
558-
SDShuffle_match(const T0 &Op1, const T1 &Op2, ArrayRef<int> &MaskRef)
559-
: Op1(Op1), Op2(Op2), CapturedMask(MaskRef) {}
556+
SDShuffle_match(const T0 &Op1, const T1 &Op2, const T2 &Mask)
557+
: Op1(Op1), Op2(Op2), Mask(Mask) {}
560558

561559
template <typename MatchContext>
562560
bool match(const MatchContext &Ctx, SDValue N) {
563561
if (auto *I = dyn_cast<ShuffleVectorSDNode>(N)) {
564-
if (Op1.match(Ctx, I->getOperand(0)) &&
565-
Op2.match(Ctx, I->getOperand(1))) {
566-
CapturedMask = I->getMask();
567-
return true;
568-
}
562+
return Op1.match(Ctx, I->getOperand(0)) &&
563+
Op2.match(Ctx, I->getOperand(1)) && Mask.match(I->getMask());
569564
}
570565
return false;
571566
}
572567
};
573568

574-
/// Matching against a specific match
575-
template <typename T0, typename T1> struct SDShuffle_maskMatch {
576-
T0 Op1;
577-
T1 Op2;
578-
ArrayRef<int> SpecificMask;
579-
580-
SDShuffle_maskMatch(const T0 &Op1, const T1 &Op2, ArrayRef<int> Mask)
581-
: Op1(Op1), Op2(Op2), SpecificMask(Mask) {}
582-
583-
template <typename MatchContext>
584-
bool match(const MatchContext &Ctx, SDValue N) {
585-
if (auto *I = dyn_cast<ShuffleVectorSDNode>(N)) {
586-
return Op1.match(Ctx, I->getOperand(0)) &&
587-
Op2.match(Ctx, I->getOperand(1)) &&
588-
std::equal(SpecificMask.begin(), SpecificMask.end(),
589-
I->getMask().begin(), I->getMask().end());
590-
}
591-
return false;
569+
struct m_Mask {
570+
ArrayRef<int> &MaskRef;
571+
m_Mask(ArrayRef<int> &MaskRef) : MaskRef(MaskRef) {}
572+
bool match(ArrayRef<int> Mask) {
573+
MaskRef = Mask;
574+
return true;
592575
}
593576
};
577+
578+
struct m_SpecificMask {
579+
ArrayRef<int> &MaskRef;
580+
m_SpecificMask(ArrayRef<int> &MaskRef) : MaskRef(MaskRef) {}
581+
bool match(ArrayRef<int> Mask) { return MaskRef == Mask; }
582+
};
583+
594584
template <typename LHS_P, typename RHS_P, typename Pred_t,
595585
bool Commutable = false, bool ExcludeChain = false>
596586
struct MaxMin_match {
@@ -842,15 +832,14 @@ inline BinaryOpc_match<LHS, RHS> m_FRem(const LHS &L, const RHS &R) {
842832
}
843833

844834
template <typename V1_t, typename V2_t>
845-
inline SDShuffle_match<V1_t, V2_t> m_Shuffle(const V1_t &v1, const V2_t &v2,
846-
ArrayRef<int> &mask) {
847-
return SDShuffle_match<V1_t, V2_t>(v1, v2, mask);
835+
inline BinaryOpc_match<V1_t, V2_t> m_Shuffle(const V1_t &v1, const V2_t &v2) {
836+
return BinaryOpc_match<V1_t, V2_t>(ISD::VECTOR_SHUFFLE, v1, v2);
848837
}
849838

850-
template <typename V1_t, typename V2_t>
851-
inline SDShuffle_maskMatch<V1_t, V2_t>
852-
m_ShuffleSpecificMask(const V1_t &v1, const V2_t &v2, ArrayRef<int> mask) {
853-
return SDShuffle_maskMatch<V1_t, V2_t>(v1, v2, mask);
839+
template <typename V1_t, typename V2_t, typename Mask_t>
840+
inline SDShuffle_match<V1_t, V2_t, Mask_t>
841+
m_Shuffle(const V1_t &v1, const V2_t &v2, const Mask_t &mask) {
842+
return SDShuffle_match<V1_t, V2_t, Mask_t>(v1, v2, mask);
854843
}
855844

856845
template <typename LHS, typename RHS>

llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ TEST_F(SelectionDAGPatternMatchTest, matchVecShuffle) {
125125
auto VInt32VT = EVT::getVectorVT(Context, Int32VT, 4);
126126
const std::array<int, 4> MaskData = {2, 0, 3, 1};
127127
const std::array<int, 4> OtherMaskData = {1, 2, 3, 4};
128+
ArrayRef<int> Mask(MaskData);
128129
ArrayRef<int> CapturedMask;
129130

130131
SDValue V0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, VInt32VT);
@@ -133,13 +134,11 @@ TEST_F(SelectionDAGPatternMatchTest, matchVecShuffle) {
133134
DAG->getVectorShuffle(VInt32VT, DL, V0, V1, MaskData);
134135

135136
using namespace SDPatternMatch;
137+
EXPECT_TRUE(sd_match(VecShuffleWithMask, m_Shuffle(m_Value(), m_Value())));
136138
EXPECT_TRUE(sd_match(VecShuffleWithMask,
137-
m_Shuffle(m_Value(), m_Value(), CapturedMask)));
139+
m_Shuffle(m_Value(), m_Value(), m_Mask(CapturedMask))));
138140
EXPECT_TRUE(sd_match(VecShuffleWithMask,
139-
m_ShuffleSpecificMask(m_Value(), m_Value(), MaskData)));
140-
EXPECT_FALSE(
141-
sd_match(VecShuffleWithMask,
142-
m_ShuffleSpecificMask(m_Value(), m_Value(), OtherMaskData)));
141+
m_Shuffle(m_Value(), m_Value(), m_SpecificMask(Mask))));
143142
EXPECT_TRUE(std::equal(MaskData.begin(), MaskData.end(), CapturedMask.begin(),
144143
CapturedMask.end()));
145144
EXPECT_FALSE(std::equal(OtherMaskData.begin(), OtherMaskData.end(),

0 commit comments

Comments
 (0)