@@ -139,6 +139,7 @@ class VectorCombine {
139139 bool foldShuffleOfSelects (Instruction &I);
140140 bool foldShuffleOfCastops (Instruction &I);
141141 bool foldShuffleOfShuffles (Instruction &I);
142+ bool foldShufflesOfLengthChangingShuffles (Instruction &I);
142143 bool foldShuffleOfIntrinsics (Instruction &I);
143144 bool foldShuffleToIdentity (Instruction &I);
144145 bool foldShuffleFromReductions (Instruction &I);
@@ -2877,6 +2878,174 @@ bool VectorCombine::foldShuffleOfShuffles(Instruction &I) {
28772878 return true ;
28782879}
28792880
2881+ // / Try to convert a chain of length-preserving shuffles that are fed by
2882+ // / length-changing shuffles from the same source, e.g. a chain of length 3:
2883+ // /
2884+ // / "shuffle (shuffle (shuffle x, (shuffle y, undef)),
2885+ // / (shuffle y, undef)),
2886+ // (shuffle y, undef)"
2887+ // /
2888+ // / into a single shuffle fed by a length-changing shuffle:
2889+ // /
2890+ // / "shuffle x, (shuffle y, undef)"
2891+ // /
2892+ // / Such chains arise e.g. from folding extract/insert sequences.
2893+ bool VectorCombine::foldShufflesOfLengthChangingShuffles (Instruction &I) {
2894+ FixedVectorType *TrunkType = dyn_cast<FixedVectorType>(I.getType ());
2895+ if (!TrunkType)
2896+ return false ;
2897+
2898+ unsigned ChainLength = 0 ;
2899+ SmallVector<int > Mask;
2900+ SmallVector<int > YMask;
2901+ InstructionCost OldCost = 0 ;
2902+ InstructionCost NewCost = 0 ;
2903+ Value *Trunk = &I;
2904+ unsigned NumTrunkElts = TrunkType->getNumElements ();
2905+ FixedVectorType *YType = nullptr ;
2906+ Value *Y = nullptr ;
2907+
2908+ for (;;) {
2909+ // Match the current trunk against (commutations of) the pattern
2910+ // "shuffle trunk', (shuffle y, undef)"
2911+ ArrayRef<int > OuterMask;
2912+ Value *OuterV0, *OuterV1;
2913+ if (ChainLength != 0 && !Trunk->hasOneUse ())
2914+ break ;
2915+ if (!match (Trunk, m_Shuffle (m_Value (OuterV0), m_Value (OuterV1),
2916+ m_Mask (OuterMask))))
2917+ break ;
2918+ if (OuterV0->getType () != TrunkType) {
2919+ // This shuffle is not length-preserving, so it cannot be part of the
2920+ // chain.
2921+ break ;
2922+ }
2923+
2924+ ArrayRef<int > InnerMask0, InnerMask1;
2925+ Value *A0, *A1, *B0, *B1;
2926+ bool Match0 =
2927+ match (OuterV0, m_Shuffle (m_Value (A0), m_Value (B0), m_Mask (InnerMask0)));
2928+ bool Match1 =
2929+ match (OuterV1, m_Shuffle (m_Value (A1), m_Value (B1), m_Mask (InnerMask1)));
2930+ bool Match0Leaf = Match0 && A0->getType () != I.getType ();
2931+ bool Match1Leaf = Match1 && A1->getType () != I.getType ();
2932+ if (Match0Leaf == Match1Leaf) {
2933+ // Only handle the case of exactly one leaf in each step. The "two leaves"
2934+ // case is handled by foldShuffleOfShuffles.
2935+ break ;
2936+ }
2937+
2938+ SmallVector<int > CommutedOuterMask;
2939+ if (Match0Leaf) {
2940+ std::swap (OuterV0, OuterV1);
2941+ std::swap (InnerMask0, InnerMask1);
2942+ std::swap (A0, A1);
2943+ std::swap (B0, B1);
2944+ llvm::append_range (CommutedOuterMask, OuterMask);
2945+ for (int &M : CommutedOuterMask) {
2946+ if (M == PoisonMaskElem)
2947+ continue ;
2948+ if (M < (int )NumTrunkElts)
2949+ M += NumTrunkElts;
2950+ else
2951+ M -= NumTrunkElts;
2952+ }
2953+ OuterMask = CommutedOuterMask;
2954+ }
2955+ if (!OuterV1->hasOneUse ())
2956+ break ;
2957+
2958+ if (!isa<UndefValue>(A1)) {
2959+ if (!Y)
2960+ Y = A1;
2961+ else if (Y != A1)
2962+ break ;
2963+ }
2964+ if (!isa<UndefValue>(B1)) {
2965+ if (!Y)
2966+ Y = B1;
2967+ else if (Y != B1)
2968+ break ;
2969+ }
2970+
2971+ InstructionCost LocalOldCost =
2972+ TTI.getInstructionCost (cast<User>(Trunk), CostKind) +
2973+ TTI.getInstructionCost (cast<User>(OuterV1), CostKind);
2974+
2975+ // Handle the initial (start of chain) case.
2976+ if (!ChainLength) {
2977+ YType = cast<FixedVectorType>(A1->getType ());
2978+ Mask.assign (OuterMask);
2979+ YMask.assign (InnerMask1);
2980+ OldCost = NewCost = LocalOldCost;
2981+ Trunk = OuterV0;
2982+ ChainLength++;
2983+ continue ;
2984+ }
2985+
2986+ // For the non-root case, first attempt to combine masks.
2987+ SmallVector<int > NewYMask (YMask);
2988+ bool Valid = true ;
2989+ for (auto [CombinedM, LeafM] : llvm::zip (NewYMask, InnerMask1)) {
2990+ if (LeafM == -1 || CombinedM == LeafM)
2991+ continue ;
2992+ if (CombinedM == -1 ) {
2993+ CombinedM = LeafM;
2994+ } else {
2995+ Valid = false ;
2996+ break ;
2997+ }
2998+ }
2999+ if (!Valid)
3000+ break ;
3001+
3002+ SmallVector<int > NewMask;
3003+ NewMask.reserve (NumTrunkElts);
3004+ for (int M : Mask) {
3005+ if (M < 0 || M >= static_cast <int >(NumTrunkElts))
3006+ NewMask.push_back (M);
3007+ else
3008+ NewMask.push_back (OuterMask[M]);
3009+ }
3010+
3011+ // Break the chain if adding this new step complicates the shuffles such
3012+ // that it would increase the new cost by more than the old cost of this
3013+ // step.
3014+ InstructionCost LocalNewCost =
3015+ TTI.getShuffleCost (TargetTransformInfo::SK_PermuteSingleSrc, TrunkType,
3016+ YType, NewYMask, CostKind) +
3017+ TTI.getShuffleCost (TargetTransformInfo::SK_PermuteTwoSrc, TrunkType,
3018+ TrunkType, NewMask, CostKind);
3019+
3020+ if (LocalNewCost >= NewCost && LocalOldCost < LocalNewCost - NewCost)
3021+ break ;
3022+
3023+ LLVM_DEBUG ({
3024+ if (ChainLength == 1 ) {
3025+ dbgs () << " Found chain of shuffles fed by length-changing shuffles: "
3026+ << I << ' \n ' ;
3027+ }
3028+ dbgs () << " next chain link: " << *Trunk << ' \n '
3029+ << " old cost: " << (OldCost + LocalOldCost)
3030+ << " new cost: " << LocalNewCost << ' \n ' ;
3031+ });
3032+
3033+ Mask = NewMask;
3034+ YMask = NewYMask;
3035+ OldCost += LocalOldCost;
3036+ NewCost = LocalNewCost;
3037+ Trunk = OuterV0;
3038+ ChainLength++;
3039+ }
3040+ if (ChainLength <= 1 )
3041+ return false ;
3042+
3043+ Value *Leaf = Builder.CreateShuffleVector (Y, PoisonValue::get (YType), YMask);
3044+ Value *Root = Builder.CreateShuffleVector (Trunk, Leaf, Mask);
3045+ replaceValue (I, *Root);
3046+ return true ;
3047+ }
3048+
28803049// / Try to convert
28813050// / "shuffle (intrinsic), (intrinsic)" into "intrinsic (shuffle), (shuffle)".
28823051bool VectorCombine::foldShuffleOfIntrinsics (Instruction &I) {
@@ -4718,6 +4887,8 @@ bool VectorCombine::run() {
47184887 return true ;
47194888 if (foldShuffleOfShuffles (I))
47204889 return true ;
4890+ if (foldShufflesOfLengthChangingShuffles (I))
4891+ return true ;
47214892 if (foldShuffleOfIntrinsics (I))
47224893 return true ;
47234894 if (foldSelectShuffle (I))
0 commit comments