@@ -548,49 +548,39 @@ struct BinaryOpc_match {
548548};
549549
550550// / Matching while capturing mask
551- template <typename T0, typename T1> struct SDShuffle_match {
551+ template <typename T0, typename T1, typename T2 > struct SDShuffle_match {
552552 T0 Op1;
553553 T1 Op2;
554+ T2 Mask;
554555
555- ArrayRef<int > &CapturedMask;
556-
557- // capturing mask
558- SDShuffle_match (const T0 &Op1, const T1 &Op2, ArrayRef<int > &MaskRef)
559- : Op1(Op1), Op2(Op2), CapturedMask(MaskRef) {}
556+ SDShuffle_match (const T0 &Op1, const T1 &Op2, const T2 &Mask)
557+ : Op1(Op1), Op2(Op2), Mask(Mask) {}
560558
561559 template <typename MatchContext>
562560 bool match (const MatchContext &Ctx, SDValue N) {
563561 if (auto *I = dyn_cast<ShuffleVectorSDNode>(N)) {
564- if (Op1.match (Ctx, I->getOperand (0 )) &&
565- Op2.match (Ctx, I->getOperand (1 ))) {
566- CapturedMask = I->getMask ();
567- return true ;
568- }
562+ return Op1.match (Ctx, I->getOperand (0 )) &&
563+ Op2.match (Ctx, I->getOperand (1 )) && Mask.match (I->getMask ());
569564 }
570565 return false ;
571566 }
572567};
573568
574- // / Matching against a specific match
575- template <typename T0, typename T1> struct SDShuffle_maskMatch {
576- T0 Op1;
577- T1 Op2;
578- ArrayRef<int > SpecificMask;
579-
580- SDShuffle_maskMatch (const T0 &Op1, const T1 &Op2, ArrayRef<int > Mask)
581- : Op1(Op1), Op2(Op2), SpecificMask(Mask) {}
582-
583- template <typename MatchContext>
584- bool match (const MatchContext &Ctx, SDValue N) {
585- if (auto *I = dyn_cast<ShuffleVectorSDNode>(N)) {
586- return Op1.match (Ctx, I->getOperand (0 )) &&
587- Op2.match (Ctx, I->getOperand (1 )) &&
588- std::equal (SpecificMask.begin (), SpecificMask.end (),
589- I->getMask ().begin (), I->getMask ().end ());
590- }
591- return false ;
569+ struct m_Mask {
570+ ArrayRef<int > &MaskRef;
571+ m_Mask (ArrayRef<int > &MaskRef) : MaskRef(MaskRef) {}
572+ bool match (ArrayRef<int > Mask) {
573+ MaskRef = Mask;
574+ return true ;
592575 }
593576};
577+
578+ struct m_SpecificMask {
579+ ArrayRef<int > &MaskRef;
580+ m_SpecificMask (ArrayRef<int > &MaskRef) : MaskRef(MaskRef) {}
581+ bool match (ArrayRef<int > Mask) { return MaskRef == Mask; }
582+ };
583+
594584template <typename LHS_P, typename RHS_P, typename Pred_t,
595585 bool Commutable = false , bool ExcludeChain = false >
596586struct MaxMin_match {
@@ -842,15 +832,14 @@ inline BinaryOpc_match<LHS, RHS> m_FRem(const LHS &L, const RHS &R) {
842832}
843833
844834template <typename V1_t, typename V2_t>
845- inline SDShuffle_match<V1_t, V2_t> m_Shuffle (const V1_t &v1, const V2_t &v2,
846- ArrayRef<int > &mask) {
847- return SDShuffle_match<V1_t, V2_t>(v1, v2, mask);
835+ inline BinaryOpc_match<V1_t, V2_t> m_Shuffle (const V1_t &v1, const V2_t &v2) {
836+ return BinaryOpc_match<V1_t, V2_t>(ISD::VECTOR_SHUFFLE, v1, v2);
848837}
849838
850- template <typename V1_t, typename V2_t>
851- inline SDShuffle_maskMatch <V1_t, V2_t>
852- m_ShuffleSpecificMask (const V1_t &v1, const V2_t &v2, ArrayRef< int > mask) {
853- return SDShuffle_maskMatch <V1_t, V2_t>(v1, v2, mask);
839+ template <typename V1_t, typename V2_t, typename Mask_t >
840+ inline SDShuffle_match <V1_t, V2_t, Mask_t >
841+ m_Shuffle (const V1_t &v1, const V2_t &v2, const Mask_t & mask) {
842+ return SDShuffle_match <V1_t, V2_t, Mask_t >(v1, v2, mask);
854843}
855844
856845template <typename LHS, typename RHS>
0 commit comments