Skip to content

Commit b365d5b

Browse files
author
Aidan
committed
proposed changes. Split functionality of m_shuffle. Fst varient captures mask as &arrayref. Snd matches specific contents as arrayref. Removed m_mask, updated tests
1 parent c577f7b commit b365d5b

File tree

2 files changed

+40
-27
lines changed

2 files changed

+40
-27
lines changed

llvm/include/llvm/CodeGen/SDPatternMatch.h

Lines changed: 34 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -226,15 +226,6 @@ inline SwitchContext<MatchContext, Pattern> m_Context(const MatchContext &Ctx,
226226
return SwitchContext<MatchContext, Pattern>{Ctx, std::move(P)};
227227
}
228228

229-
struct m_Mask {
230-
ArrayRef<int> &MaskRef;
231-
m_Mask(ArrayRef<int> &MaskRef) : MaskRef(MaskRef) {}
232-
bool match(ArrayRef<int> Mask) {
233-
MaskRef = Mask;
234-
return true;
235-
}
236-
};
237-
238229
// === Value type ===
239230
struct ValueType_bind {
240231
EVT &BindVT;
@@ -549,21 +540,41 @@ struct BinaryOpc_match {
549540
}
550541
};
551542

552-
/// Matches shuffle.
543+
/// Matching while capturing mask
553544
template <typename T0, typename T1> struct SDShuffle_match {
554545
T0 Op1;
555546
T1 Op2;
556-
ArrayRef<int> Mask;
557547

558-
SDShuffle_match(const T0 &Op1, const T1 &Op2, const ArrayRef<int> &Mask)
559-
: Op1(Op1), Op2(Op2), Mask(Mask) {}
548+
const ArrayRef<int> *MaskRef;
549+
550+
// capturing mask
551+
SDShuffle_match(const T0 &Op1, const T1 &Op2, const ArrayRef<int> &MaskRef)
552+
: Op1(Op1), Op2(Op2), MaskRef(&MaskRef) {}
553+
554+
template <typename MatchContext>
555+
bool match(const MatchContext &Ctx, SDValue N) {
556+
if (auto *I = dyn_cast<ShuffleVectorSDNode>(N)) {
557+
return Op1.match(Ctx, I->getOperand(0)) &&
558+
Op2.match(Ctx, I->getOperand(1));
559+
}
560+
return false;
561+
}
562+
};
563+
564+
/// Matching against a specific match
565+
template <typename T0, typename T1> struct SDShuffle_maskMatch {
566+
T0 Op1;
567+
T1 Op2;
568+
ArrayRef<int> SpecificMask;
569+
570+
SDShuffle_maskMatch(const T0 &Op1, const T1 &Op2, const ArrayRef<int> Mask)
571+
: Op1(Op1), Op2(Op2), SpecificMask(Mask) {}
560572

561573
template <typename MatchContext>
562574
bool match(const MatchContext &Ctx, SDValue N) {
563575
if (auto *I = dyn_cast<ShuffleVectorSDNode>(N)) {
564576
return Op1.match(Ctx, I->getOperand(0)) &&
565-
Op2.match(Ctx, I->getOperand(1)) &&
566-
std::equal(Mask.begin(), Mask.end(), I->getMask().begin());
577+
Op2.match(Ctx, I->getOperand(1)) && I->getMask() == SpecificMask;
567578
}
568579
return false;
569580
}
@@ -818,15 +829,17 @@ inline BinaryOpc_match<LHS, RHS> m_FRem(const LHS &L, const RHS &R) {
818829
return BinaryOpc_match<LHS, RHS>(ISD::FREM, L, R);
819830
}
820831

821-
template <typename LHS, typename RHS>
822-
inline BinaryOpc_match<LHS, RHS> m_Shuffle(const LHS &v1, const RHS &v2) {
823-
return BinaryOpc_match<LHS, RHS>(ISD::VECTOR_SHUFFLE, v1, v2);
832+
template <typename V1_t, typename V2_t>
833+
inline SDShuffle_match<V1_t, V2_t> m_Shuffle(const V1_t &v1, const V2_t &v2,
834+
const ArrayRef<int> &maskRef) {
835+
return SDShuffle_match<V1_t, V2_t>(v1, v2, maskRef);
824836
}
825837

826838
template <typename V1_t, typename V2_t>
827-
inline SDShuffle_match<V1_t, V2_t> m_Shuffle(const V1_t &v1, const V2_t &v2,
828-
const ArrayRef<int> mask) {
829-
return SDShuffle_match<V1_t, V2_t>(v1, v2, mask);
839+
inline SDShuffle_maskMatch<V1_t, V2_t>
840+
m_ShuffleSpecificMask(const V1_t &v1, const V2_t &v2,
841+
const ArrayRef<int> mask) {
842+
return SDShuffle_maskMatch<V1_t, V2_t>(v1, v2, mask);
830843
}
831844

832845
// === Unary operations ===

llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -124,19 +124,19 @@ TEST_F(SelectionDAGPatternMatchTest, matchVecShuffle) {
124124
auto Int32VT = EVT::getIntegerVT(Context, 32);
125125
auto VInt32VT = EVT::getVectorVT(Context, Int32VT, 4);
126126
SmallVector<int, 4> MaskData = {2, 0, 3, 1};
127-
ArrayRef<int> Mask;
127+
ArrayRef<int> CapturedMask;
128128

129129
SDValue V0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, VInt32VT);
130130
SDValue V1 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 2, VInt32VT);
131-
SDValue VecShuffleWithMask_0 =
131+
SDValue VecShuffleWithMask =
132132
DAG->getVectorShuffle(VInt32VT, DL, V0, V1, MaskData);
133133

134134
using namespace SDPatternMatch;
135+
EXPECT_TRUE(sd_match(VecShuffleWithMask,
136+
m_Shuffle(m_Value(V0), m_Value(V1), CapturedMask)));
135137
EXPECT_TRUE(
136-
sd_match(VecShuffleWithMask_0, m_Shuffle(m_Value(V0), m_Value(V1))));
137-
EXPECT_TRUE(sd_match(VecShuffleWithMask_0,
138-
m_Shuffle(m_Value(V0), m_Value(V1), Mask)));
139-
EXPECT_TRUE(std::equal(Mask.begin(), Mask.end(), MaskData.begin()));
138+
sd_match(VecShuffleWithMask,
139+
m_ShuffleSpecificMask(m_Value(V0), m_Value(V1), MaskData)));
140140
}
141141

142142
TEST_F(SelectionDAGPatternMatchTest, matchTernaryOp) {

0 commit comments

Comments
 (0)