@@ -226,15 +226,6 @@ inline SwitchContext<MatchContext, Pattern> m_Context(const MatchContext &Ctx,
226226 return SwitchContext<MatchContext, Pattern>{Ctx, std::move (P)};
227227}
228228
229- struct m_Mask {
230- ArrayRef<int > &MaskRef;
231- m_Mask (ArrayRef<int > &MaskRef) : MaskRef(MaskRef) {}
232- bool match (ArrayRef<int > Mask) {
233- MaskRef = Mask;
234- return true ;
235- }
236- };
237-
238229// === Value type ===
239230struct ValueType_bind {
240231 EVT &BindVT;
@@ -549,21 +540,41 @@ struct BinaryOpc_match {
549540 }
550541};
551542
552- // / Matches shuffle.
543+ // / Matching while capturing mask
553544template <typename T0, typename T1> struct SDShuffle_match {
554545 T0 Op1;
555546 T1 Op2;
556- ArrayRef<int > Mask;
557547
558- SDShuffle_match (const T0 &Op1, const T1 &Op2, const ArrayRef<int > &Mask)
559- : Op1(Op1), Op2(Op2), Mask(Mask) {}
548+ const ArrayRef<int > *MaskRef;
549+
550+ // capturing mask
551+ SDShuffle_match (const T0 &Op1, const T1 &Op2, const ArrayRef<int > &MaskRef)
552+ : Op1(Op1), Op2(Op2), MaskRef(&MaskRef) {}
553+
554+ template <typename MatchContext>
555+ bool match (const MatchContext &Ctx, SDValue N) {
556+ if (auto *I = dyn_cast<ShuffleVectorSDNode>(N)) {
557+ return Op1.match (Ctx, I->getOperand (0 )) &&
558+ Op2.match (Ctx, I->getOperand (1 ));
559+ }
560+ return false ;
561+ }
562+ };
563+
564+ // / Matching against a specific match
565+ template <typename T0, typename T1> struct SDShuffle_maskMatch {
566+ T0 Op1;
567+ T1 Op2;
568+ ArrayRef<int > SpecificMask;
569+
570+ SDShuffle_maskMatch (const T0 &Op1, const T1 &Op2, const ArrayRef<int > Mask)
571+ : Op1(Op1), Op2(Op2), SpecificMask(Mask) {}
560572
561573 template <typename MatchContext>
562574 bool match (const MatchContext &Ctx, SDValue N) {
563575 if (auto *I = dyn_cast<ShuffleVectorSDNode>(N)) {
564576 return Op1.match (Ctx, I->getOperand (0 )) &&
565- Op2.match (Ctx, I->getOperand (1 )) &&
566- std::equal (Mask.begin (), Mask.end (), I->getMask ().begin ());
577+ Op2.match (Ctx, I->getOperand (1 )) && I->getMask () == SpecificMask;
567578 }
568579 return false ;
569580 }
@@ -818,15 +829,17 @@ inline BinaryOpc_match<LHS, RHS> m_FRem(const LHS &L, const RHS &R) {
818829 return BinaryOpc_match<LHS, RHS>(ISD::FREM, L, R);
819830}
820831
821- template <typename LHS, typename RHS>
822- inline BinaryOpc_match<LHS, RHS> m_Shuffle (const LHS &v1, const RHS &v2) {
823- return BinaryOpc_match<LHS, RHS>(ISD::VECTOR_SHUFFLE, v1, v2);
832+ template <typename V1_t, typename V2_t>
833+ inline SDShuffle_match<V1_t, V2_t> m_Shuffle (const V1_t &v1, const V2_t &v2,
834+ const ArrayRef<int > &maskRef) {
835+ return SDShuffle_match<V1_t, V2_t>(v1, v2, maskRef);
824836}
825837
826838template <typename V1_t, typename V2_t>
827- inline SDShuffle_match<V1_t, V2_t> m_Shuffle (const V1_t &v1, const V2_t &v2,
828- const ArrayRef<int > mask) {
829- return SDShuffle_match<V1_t, V2_t>(v1, v2, mask);
839+ inline SDShuffle_maskMatch<V1_t, V2_t>
840+ m_ShuffleSpecificMask (const V1_t &v1, const V2_t &v2,
841+ const ArrayRef<int > mask) {
842+ return SDShuffle_maskMatch<V1_t, V2_t>(v1, v2, mask);
830843}
831844
832845// === Unary operations ===
0 commit comments