@@ -139,6 +139,7 @@ class VectorCombine {
139139 bool foldShuffleOfSelects (Instruction &I);
140140 bool foldShuffleOfCastops (Instruction &I);
141141 bool foldShuffleOfShuffles (Instruction &I);
142+ bool foldShufflesOfLengthChangingShuffles (Instruction &I);
142143 bool foldShuffleOfIntrinsics (Instruction &I);
143144 bool foldShuffleToIdentity (Instruction &I);
144145 bool foldShuffleFromReductions (Instruction &I);
@@ -2877,6 +2878,195 @@ bool VectorCombine::foldShuffleOfShuffles(Instruction &I) {
28772878 return true ;
28782879}
28792880
2881+ // / Try to convert a chain of length-preserving shuffles that are fed by
2882+ // / length-changing shuffles from the same source, e.g. a chain of length 3:
2883+ // /
2884+ // / "shuffle (shuffle (shuffle x, (shuffle y, undef)),
2885+ // / (shuffle y, undef)),
2886+ // (shuffle y, undef)"
2887+ // /
2888+ // / into a single shuffle fed by a length-changing shuffle:
2889+ // /
2890+ // / "shuffle x, (shuffle y, undef)"
2891+ // /
2892+ // / Such chains arise e.g. from folding extract/insert sequences.
2893+ bool VectorCombine::foldShufflesOfLengthChangingShuffles (Instruction &I) {
2894+ FixedVectorType *TrunkType = dyn_cast<FixedVectorType>(I.getType ());
2895+ if (!TrunkType)
2896+ return false ;
2897+
2898+ unsigned ChainLength = 0 ;
2899+ SmallVector<int > Mask;
2900+ SmallVector<int > YMask;
2901+ InstructionCost OldCost = 0 ;
2902+ InstructionCost NewCost = 0 ;
2903+ Value *Trunk = &I;
2904+ unsigned NumTrunkElts = TrunkType->getNumElements ();
2905+ Value *Y = nullptr ;
2906+
2907+ for (;;) {
2908+ // Match the current trunk against (commutations of) the pattern
2909+ // "shuffle trunk', (shuffle y, undef)"
2910+ ArrayRef<int > OuterMask;
2911+ Value *OuterV0, *OuterV1;
2912+ if (ChainLength != 0 && !Trunk->hasOneUse ())
2913+ break ;
2914+ if (!match (Trunk, m_Shuffle (m_Value (OuterV0), m_Value (OuterV1),
2915+ m_Mask (OuterMask))))
2916+ break ;
2917+ if (OuterV0->getType () != TrunkType) {
2918+ // This shuffle is not length-preserving, so it cannot be part of the
2919+ // chain.
2920+ break ;
2921+ }
2922+
2923+ ArrayRef<int > InnerMask0, InnerMask1;
2924+ Value *A0, *A1, *B0, *B1;
2925+ bool Match0 =
2926+ match (OuterV0, m_Shuffle (m_Value (A0), m_Value (B0), m_Mask (InnerMask0)));
2927+ bool Match1 =
2928+ match (OuterV1, m_Shuffle (m_Value (A1), m_Value (B1), m_Mask (InnerMask1)));
2929+ bool Match0Leaf = Match0 && A0->getType () != I.getType ();
2930+ bool Match1Leaf = Match1 && A1->getType () != I.getType ();
2931+ if (Match0Leaf == Match1Leaf) {
2932+ // Only handle the case of exactly one leaf in each step. The "two leaves"
2933+ // case is handled by foldShuffleOfShuffles.
2934+ break ;
2935+ }
2936+
2937+ SmallVector<int > CommutedOuterMask;
2938+ if (Match0Leaf) {
2939+ std::swap (OuterV0, OuterV1);
2940+ std::swap (InnerMask0, InnerMask1);
2941+ std::swap (A0, A1);
2942+ std::swap (B0, B1);
2943+ llvm::append_range (CommutedOuterMask, OuterMask);
2944+ for (int &M : CommutedOuterMask) {
2945+ if (M == PoisonMaskElem)
2946+ continue ;
2947+ if (M < (int )NumTrunkElts)
2948+ M += NumTrunkElts;
2949+ else
2950+ M -= NumTrunkElts;
2951+ }
2952+ OuterMask = CommutedOuterMask;
2953+ }
2954+ if (!OuterV1->hasOneUse ())
2955+ break ;
2956+
2957+ if (!isa<UndefValue>(A1)) {
2958+ if (!Y)
2959+ Y = A1;
2960+ else if (Y != A1)
2961+ break ;
2962+ }
2963+ if (!isa<UndefValue>(B1)) {
2964+ if (!Y)
2965+ Y = B1;
2966+ else if (Y != B1)
2967+ break ;
2968+ }
2969+
2970+ auto *YType = cast<FixedVectorType>(A1->getType ());
2971+ int NumLeafElts = YType->getNumElements ();
2972+ SmallVector<int > LocalYMask (InnerMask1);
2973+ for (int &M : LocalYMask) {
2974+ if (M >= NumLeafElts)
2975+ M -= NumLeafElts;
2976+ }
2977+
2978+ InstructionCost LocalOldCost =
2979+ TTI.getInstructionCost (cast<User>(Trunk), CostKind) +
2980+ TTI.getInstructionCost (cast<User>(OuterV1), CostKind);
2981+
2982+ // Handle the initial (start of chain) case.
2983+ if (!ChainLength) {
2984+ Mask.assign (OuterMask);
2985+ YMask.assign (LocalYMask);
2986+ OldCost = NewCost = LocalOldCost;
2987+ Trunk = OuterV0;
2988+ ChainLength++;
2989+ continue ;
2990+ }
2991+
2992+ // For the non-root case, first attempt to combine masks.
2993+ SmallVector<int > NewYMask (YMask);
2994+ bool Valid = true ;
2995+ for (auto [CombinedM, LeafM] : llvm::zip (NewYMask, LocalYMask)) {
2996+ if (LeafM == -1 || CombinedM == LeafM)
2997+ continue ;
2998+ if (CombinedM == -1 ) {
2999+ CombinedM = LeafM;
3000+ } else {
3001+ Valid = false ;
3002+ break ;
3003+ }
3004+ }
3005+ if (!Valid)
3006+ break ;
3007+
3008+ SmallVector<int > NewMask;
3009+ NewMask.reserve (NumTrunkElts);
3010+ for (int M : Mask) {
3011+ if (M < 0 || M >= static_cast <int >(NumTrunkElts))
3012+ NewMask.push_back (M);
3013+ else
3014+ NewMask.push_back (OuterMask[M]);
3015+ }
3016+
3017+ // Break the chain if adding this new step complicates the shuffles such
3018+ // that it would increase the new cost by more than the old cost of this
3019+ // step.
3020+ InstructionCost LocalNewCost =
3021+ TTI.getShuffleCost (TargetTransformInfo::SK_PermuteSingleSrc, TrunkType,
3022+ YType, NewYMask, CostKind) +
3023+ TTI.getShuffleCost (TargetTransformInfo::SK_PermuteTwoSrc, TrunkType,
3024+ TrunkType, NewMask, CostKind);
3025+
3026+ if (LocalNewCost >= NewCost && LocalOldCost < LocalNewCost - NewCost)
3027+ break ;
3028+
3029+ LLVM_DEBUG ({
3030+ if (ChainLength == 1 ) {
3031+ dbgs () << " Found chain of shuffles fed by length-changing shuffles: "
3032+ << I << ' \n ' ;
3033+ }
3034+ dbgs () << " next chain link: " << *Trunk << ' \n '
3035+ << " old cost: " << (OldCost + LocalOldCost)
3036+ << " new cost: " << LocalNewCost << ' \n ' ;
3037+ });
3038+
3039+ Mask = NewMask;
3040+ YMask = NewYMask;
3041+ OldCost += LocalOldCost;
3042+ NewCost = LocalNewCost;
3043+ Trunk = OuterV0;
3044+ ChainLength++;
3045+ }
3046+ if (ChainLength <= 1 )
3047+ return false ;
3048+
3049+ if (llvm::all_of (Mask, [&](int M) {
3050+ return M < 0 || M >= static_cast <int >(NumTrunkElts);
3051+ })) {
3052+ // Produce a canonical simplified form if all elements are sourced from Y.
3053+ for (int &M : Mask) {
3054+ if (M >= static_cast <int >(NumTrunkElts))
3055+ M = YMask[M - NumTrunkElts];
3056+ }
3057+ Value *Root =
3058+ Builder.CreateShuffleVector (Y, PoisonValue::get (Y->getType ()), Mask);
3059+ replaceValue (I, *Root);
3060+ return true ;
3061+ }
3062+
3063+ Value *Leaf =
3064+ Builder.CreateShuffleVector (Y, PoisonValue::get (Y->getType ()), YMask);
3065+ Value *Root = Builder.CreateShuffleVector (Trunk, Leaf, Mask);
3066+ replaceValue (I, *Root);
3067+ return true ;
3068+ }
3069+
28803070// / Try to convert
28813071// / "shuffle (intrinsic), (intrinsic)" into "intrinsic (shuffle), (shuffle)".
28823072bool VectorCombine::foldShuffleOfIntrinsics (Instruction &I) {
@@ -4718,6 +4908,8 @@ bool VectorCombine::run() {
47184908 return true ;
47194909 if (foldShuffleOfShuffles (I))
47204910 return true ;
4911+ if (foldShufflesOfLengthChangingShuffles (I))
4912+ return true ;
47214913 if (foldShuffleOfIntrinsics (I))
47224914 return true ;
47234915 if (foldSelectShuffle (I))
0 commit comments