@@ -139,6 +139,7 @@ class VectorCombine {
139139 bool foldShuffleOfSelects (Instruction &I);
140140 bool foldShuffleOfCastops (Instruction &I);
141141 bool foldShuffleOfShuffles (Instruction &I);
142+ bool foldShufflesOfLengthChangingShuffles (Instruction &I);
142143 bool foldShuffleOfIntrinsics (Instruction &I);
143144 bool foldShuffleToIdentity (Instruction &I);
144145 bool foldShuffleFromReductions (Instruction &I);
@@ -2877,6 +2878,171 @@ bool VectorCombine::foldShuffleOfShuffles(Instruction &I) {
28772878 return true ;
28782879}
28792880
2881+ // / Try to convert a chain of length-preserving shuffles that are fed by
2882+ // / length-changing shuffles from the same source, e.g. a chain of length 3:
2883+ // /
2884+ // / "shuffle (shuffle (shuffle x, (shuffle y, undef)),
2885+ // / (shuffle y, undef)),
2886+ // (shuffle y, undef)"
2887+ // /
2888+ // / into a single shuffle fed by a length-changing shuffle:
2889+ // /
2890+ // / "shuffle x, (shuffle y, undef)"
2891+ // /
2892+ // / Such chains arise e.g. from folding extract/insert sequences.
2893+ bool VectorCombine::foldShufflesOfLengthChangingShuffles (Instruction &I) {
2894+ unsigned ChainLength = 0 ;
2895+ SmallVector<int > Mask;
2896+ SmallVector<int > YMask;
2897+ InstructionCost OldCost = 0 ;
2898+ InstructionCost NewCost = 0 ;
2899+ FixedVectorType *TrunkType = cast<FixedVectorType>(I.getType ());
2900+ Value *Trunk = &I;
2901+ unsigned NumTrunkElts = TrunkType->getNumElements ();
2902+ FixedVectorType *YType = nullptr ;
2903+ Value *Y = nullptr ;
2904+
2905+ for (;;) {
2906+ // Match the current trunk against (commutations of) the pattern
2907+ // "shuffle trunk', (shuffle y, undef)"
2908+ ArrayRef<int > OuterMask;
2909+ Value *OuterV0, *OuterV1;
2910+ if (ChainLength != 0 && !Trunk->hasOneUse ())
2911+ break ;
2912+ if (!match (Trunk, m_Shuffle (m_Value (OuterV0), m_Value (OuterV1),
2913+ m_Mask (OuterMask))))
2914+ break ;
2915+ if (OuterV0->getType () != TrunkType) {
2916+ // This shuffle is not length-preserving, so it cannot be part of the
2917+ // chain.
2918+ break ;
2919+ }
2920+
2921+ ArrayRef<int > InnerMask0, InnerMask1;
2922+ Value *A0, *A1, *B0, *B1;
2923+ bool Match0 =
2924+ match (OuterV0, m_Shuffle (m_Value (A0), m_Value (B0), m_Mask (InnerMask0)));
2925+ bool Match1 =
2926+ match (OuterV1, m_Shuffle (m_Value (A1), m_Value (B1), m_Mask (InnerMask1)));
2927+ bool Match0Leaf = Match0 && A0->getType () != I.getType ();
2928+ bool Match1Leaf = Match1 && A1->getType () != I.getType ();
2929+ if (Match0Leaf == Match1Leaf) {
2930+ // Only handle the case of exactly one leaf in each step. The "two leaves"
2931+ // case is handled by foldShuffleOfShuffles.
2932+ break ;
2933+ }
2934+
2935+ SmallVector<int > CommutedOuterMask;
2936+ if (Match0Leaf) {
2937+ std::swap (OuterV0, OuterV1);
2938+ std::swap (InnerMask0, InnerMask1);
2939+ std::swap (A0, A1);
2940+ std::swap (B0, B1);
2941+ llvm::append_range (CommutedOuterMask, OuterMask);
2942+ for (int &M : CommutedOuterMask) {
2943+ if (M == PoisonMaskElem)
2944+ continue ;
2945+ if (M < (int )NumTrunkElts)
2946+ M += NumTrunkElts;
2947+ else
2948+ M -= NumTrunkElts;
2949+ }
2950+ OuterMask = CommutedOuterMask;
2951+ }
2952+ if (!OuterV1->hasOneUse ())
2953+ break ;
2954+
2955+ if (!isa<PoisonValue>(A1)) {
2956+ if (!Y)
2957+ Y = A1;
2958+ else if (Y != A1)
2959+ break ;
2960+ }
2961+ if (!isa<PoisonValue>(B1)) {
2962+ if (!Y)
2963+ Y = B1;
2964+ else if (Y != B1)
2965+ break ;
2966+ }
2967+
2968+ InstructionCost LocalOldCost =
2969+ TTI.getInstructionCost (cast<User>(Trunk), CostKind) +
2970+ TTI.getInstructionCost (cast<User>(OuterV1), CostKind);
2971+
2972+ // Handle the initial (start of chain) case.
2973+ if (!ChainLength) {
2974+ YType = cast<FixedVectorType>(A1->getType ());
2975+ Mask.assign (OuterMask);
2976+ YMask.assign (InnerMask1);
2977+ OldCost = NewCost = LocalOldCost;
2978+ Trunk = OuterV0;
2979+ ChainLength++;
2980+ continue ;
2981+ }
2982+
2983+ // For the non-root case, first attempt to combine masks.
2984+ SmallVector<int > NewYMask (YMask);
2985+ bool Valid = true ;
2986+ for (auto [CombinedM, LeafM] : llvm::zip (NewYMask, InnerMask1)) {
2987+ if (LeafM == -1 || CombinedM == LeafM)
2988+ continue ;
2989+ if (CombinedM == -1 ) {
2990+ CombinedM = LeafM;
2991+ } else {
2992+ Valid = false ;
2993+ break ;
2994+ }
2995+ }
2996+ if (!Valid)
2997+ break ;
2998+
2999+ SmallVector<int > NewMask;
3000+ NewMask.reserve (NumTrunkElts);
3001+ for (int M : Mask) {
3002+ if (M < 0 || M >= (int )NumTrunkElts)
3003+ NewMask.push_back (M);
3004+ else
3005+ NewMask.push_back (OuterMask[M]);
3006+ }
3007+
3008+ // Break the chain if adding this new step complicates the shuffles such
3009+ // that it would increase the new cost by more than the old cost of this
3010+ // step.
3011+ InstructionCost LocalNewCost = 0 ;
3012+ LocalNewCost += TTI.getShuffleCost (TargetTransformInfo::SK_PermuteSingleSrc,
3013+ TrunkType, YType, NewYMask, CostKind);
3014+ LocalNewCost += TTI.getShuffleCost (TargetTransformInfo::SK_PermuteTwoSrc,
3015+ TrunkType, TrunkType, NewMask, CostKind);
3016+
3017+ if (LocalNewCost >= NewCost && LocalOldCost < LocalNewCost - NewCost)
3018+ break ;
3019+
3020+ LLVM_DEBUG ({
3021+ if (ChainLength == 1 ) {
3022+ dbgs () << " Found chain of shuffles fed by length-changing shuffles: "
3023+ << I << ' \n ' ;
3024+ }
3025+ dbgs () << " next chain link: " << *Trunk << ' \n '
3026+ << " old cost: " << (OldCost + LocalOldCost)
3027+ << " new cost: " << LocalNewCost << ' \n ' ;
3028+ });
3029+
3030+ Mask = NewMask;
3031+ YMask = NewYMask;
3032+ OldCost += LocalOldCost;
3033+ NewCost = LocalNewCost;
3034+ Trunk = OuterV0;
3035+ ChainLength++;
3036+ }
3037+ if (ChainLength <= 1 )
3038+ return false ;
3039+
3040+ Value *Leaf = Builder.CreateShuffleVector (Y, PoisonValue::get (YType), YMask);
3041+ Value *Root = Builder.CreateShuffleVector (Trunk, Leaf, Mask);
3042+ replaceValue (I, *Root);
3043+ return true ;
3044+ }
3045+
28803046// / Try to convert
28813047// / "shuffle (intrinsic), (intrinsic)" into "intrinsic (shuffle), (shuffle)".
28823048bool VectorCombine::foldShuffleOfIntrinsics (Instruction &I) {
@@ -4718,6 +4884,8 @@ bool VectorCombine::run() {
47184884 return true ;
47194885 if (foldShuffleOfShuffles (I))
47204886 return true ;
4887+ if (foldShufflesOfLengthChangingShuffles (I))
4888+ return true ;
47214889 if (foldShuffleOfIntrinsics (I))
47224890 return true ;
47234891 if (foldSelectShuffle (I))
0 commit comments