@@ -140,6 +140,7 @@ class VectorCombine {
140140 bool foldShuffleOfCastops (Instruction &I);
141141 bool foldShuffleOfShuffles (Instruction &I);
142142 bool foldPermuteOfIntrinsic (Instruction &I);
143+ bool foldShufflesOfLengthChangingShuffles (Instruction &I);
143144 bool foldShuffleOfIntrinsics (Instruction &I);
144145 bool foldShuffleToIdentity (Instruction &I);
145146 bool foldShuffleFromReductions (Instruction &I);
@@ -2878,6 +2879,195 @@ bool VectorCombine::foldShuffleOfShuffles(Instruction &I) {
28782879 return true ;
28792880}
28802881
2882+ // / Try to convert a chain of length-preserving shuffles that are fed by
2883+ // / length-changing shuffles from the same source, e.g. a chain of length 3:
2884+ // /
2885+ // / "shuffle (shuffle (shuffle x, (shuffle y, undef)),
2886+ // / (shuffle y, undef)),
2887+ // (shuffle y, undef)"
2888+ // /
2889+ // / into a single shuffle fed by a length-changing shuffle:
2890+ // /
2891+ // / "shuffle x, (shuffle y, undef)"
2892+ // /
2893+ // / Such chains arise e.g. from folding extract/insert sequences.
2894+ bool VectorCombine::foldShufflesOfLengthChangingShuffles (Instruction &I) {
2895+ FixedVectorType *TrunkType = dyn_cast<FixedVectorType>(I.getType ());
2896+ if (!TrunkType)
2897+ return false ;
2898+
2899+ unsigned ChainLength = 0 ;
2900+ SmallVector<int > Mask;
2901+ SmallVector<int > YMask;
2902+ InstructionCost OldCost = 0 ;
2903+ InstructionCost NewCost = 0 ;
2904+ Value *Trunk = &I;
2905+ unsigned NumTrunkElts = TrunkType->getNumElements ();
2906+ Value *Y = nullptr ;
2907+
2908+ for (;;) {
2909+ // Match the current trunk against (commutations of) the pattern
2910+ // "shuffle trunk', (shuffle y, undef)"
2911+ ArrayRef<int > OuterMask;
2912+ Value *OuterV0, *OuterV1;
2913+ if (ChainLength != 0 && !Trunk->hasOneUse ())
2914+ break ;
2915+ if (!match (Trunk, m_Shuffle (m_Value (OuterV0), m_Value (OuterV1),
2916+ m_Mask (OuterMask))))
2917+ break ;
2918+ if (OuterV0->getType () != TrunkType) {
2919+ // This shuffle is not length-preserving, so it cannot be part of the
2920+ // chain.
2921+ break ;
2922+ }
2923+
2924+ ArrayRef<int > InnerMask0, InnerMask1;
2925+ Value *A0, *A1, *B0, *B1;
2926+ bool Match0 =
2927+ match (OuterV0, m_Shuffle (m_Value (A0), m_Value (B0), m_Mask (InnerMask0)));
2928+ bool Match1 =
2929+ match (OuterV1, m_Shuffle (m_Value (A1), m_Value (B1), m_Mask (InnerMask1)));
2930+ bool Match0Leaf = Match0 && A0->getType () != I.getType ();
2931+ bool Match1Leaf = Match1 && A1->getType () != I.getType ();
2932+ if (Match0Leaf == Match1Leaf) {
2933+ // Only handle the case of exactly one leaf in each step. The "two leaves"
2934+ // case is handled by foldShuffleOfShuffles.
2935+ break ;
2936+ }
2937+
2938+ SmallVector<int > CommutedOuterMask;
2939+ if (Match0Leaf) {
2940+ std::swap (OuterV0, OuterV1);
2941+ std::swap (InnerMask0, InnerMask1);
2942+ std::swap (A0, A1);
2943+ std::swap (B0, B1);
2944+ llvm::append_range (CommutedOuterMask, OuterMask);
2945+ for (int &M : CommutedOuterMask) {
2946+ if (M == PoisonMaskElem)
2947+ continue ;
2948+ if (M < (int )NumTrunkElts)
2949+ M += NumTrunkElts;
2950+ else
2951+ M -= NumTrunkElts;
2952+ }
2953+ OuterMask = CommutedOuterMask;
2954+ }
2955+ if (!OuterV1->hasOneUse ())
2956+ break ;
2957+
2958+ if (!isa<UndefValue>(A1)) {
2959+ if (!Y)
2960+ Y = A1;
2961+ else if (Y != A1)
2962+ break ;
2963+ }
2964+ if (!isa<UndefValue>(B1)) {
2965+ if (!Y)
2966+ Y = B1;
2967+ else if (Y != B1)
2968+ break ;
2969+ }
2970+
2971+ auto *YType = cast<FixedVectorType>(A1->getType ());
2972+ int NumLeafElts = YType->getNumElements ();
2973+ SmallVector<int > LocalYMask (InnerMask1);
2974+ for (int &M : LocalYMask) {
2975+ if (M >= NumLeafElts)
2976+ M -= NumLeafElts;
2977+ }
2978+
2979+ InstructionCost LocalOldCost =
2980+ TTI.getInstructionCost (cast<User>(Trunk), CostKind) +
2981+ TTI.getInstructionCost (cast<User>(OuterV1), CostKind);
2982+
2983+ // Handle the initial (start of chain) case.
2984+ if (!ChainLength) {
2985+ Mask.assign (OuterMask);
2986+ YMask.assign (LocalYMask);
2987+ OldCost = NewCost = LocalOldCost;
2988+ Trunk = OuterV0;
2989+ ChainLength++;
2990+ continue ;
2991+ }
2992+
2993+ // For the non-root case, first attempt to combine masks.
2994+ SmallVector<int > NewYMask (YMask);
2995+ bool Valid = true ;
2996+ for (auto [CombinedM, LeafM] : llvm::zip (NewYMask, LocalYMask)) {
2997+ if (LeafM == -1 || CombinedM == LeafM)
2998+ continue ;
2999+ if (CombinedM == -1 ) {
3000+ CombinedM = LeafM;
3001+ } else {
3002+ Valid = false ;
3003+ break ;
3004+ }
3005+ }
3006+ if (!Valid)
3007+ break ;
3008+
3009+ SmallVector<int > NewMask;
3010+ NewMask.reserve (NumTrunkElts);
3011+ for (int M : Mask) {
3012+ if (M < 0 || M >= static_cast <int >(NumTrunkElts))
3013+ NewMask.push_back (M);
3014+ else
3015+ NewMask.push_back (OuterMask[M]);
3016+ }
3017+
3018+ // Break the chain if adding this new step complicates the shuffles such
3019+ // that it would increase the new cost by more than the old cost of this
3020+ // step.
3021+ InstructionCost LocalNewCost =
3022+ TTI.getShuffleCost (TargetTransformInfo::SK_PermuteSingleSrc, TrunkType,
3023+ YType, NewYMask, CostKind) +
3024+ TTI.getShuffleCost (TargetTransformInfo::SK_PermuteTwoSrc, TrunkType,
3025+ TrunkType, NewMask, CostKind);
3026+
3027+ if (LocalNewCost >= NewCost && LocalOldCost < LocalNewCost - NewCost)
3028+ break ;
3029+
3030+ LLVM_DEBUG ({
3031+ if (ChainLength == 1 ) {
3032+ dbgs () << " Found chain of shuffles fed by length-changing shuffles: "
3033+ << I << ' \n ' ;
3034+ }
3035+ dbgs () << " next chain link: " << *Trunk << ' \n '
3036+ << " old cost: " << (OldCost + LocalOldCost)
3037+ << " new cost: " << LocalNewCost << ' \n ' ;
3038+ });
3039+
3040+ Mask = NewMask;
3041+ YMask = NewYMask;
3042+ OldCost += LocalOldCost;
3043+ NewCost = LocalNewCost;
3044+ Trunk = OuterV0;
3045+ ChainLength++;
3046+ }
3047+ if (ChainLength <= 1 )
3048+ return false ;
3049+
3050+ if (llvm::all_of (Mask, [&](int M) {
3051+ return M < 0 || M >= static_cast <int >(NumTrunkElts);
3052+ })) {
3053+ // Produce a canonical simplified form if all elements are sourced from Y.
3054+ for (int &M : Mask) {
3055+ if (M >= static_cast <int >(NumTrunkElts))
3056+ M = YMask[M - NumTrunkElts];
3057+ }
3058+ Value *Root =
3059+ Builder.CreateShuffleVector (Y, PoisonValue::get (Y->getType ()), Mask);
3060+ replaceValue (I, *Root);
3061+ return true ;
3062+ }
3063+
3064+ Value *Leaf =
3065+ Builder.CreateShuffleVector (Y, PoisonValue::get (Y->getType ()), YMask);
3066+ Value *Root = Builder.CreateShuffleVector (Trunk, Leaf, Mask);
3067+ replaceValue (I, *Root);
3068+ return true ;
3069+ }
3070+
28813071// / Try to convert
28823072// / "shuffle (intrinsic), (intrinsic)" into "intrinsic (shuffle), (shuffle)".
28833073bool VectorCombine::foldShuffleOfIntrinsics (Instruction &I) {
@@ -4799,6 +4989,8 @@ bool VectorCombine::run() {
47994989 return true ;
48004990 if (foldPermuteOfIntrinsic (I))
48014991 return true ;
4992+ if (foldShufflesOfLengthChangingShuffles (I))
4993+ return true ;
48024994 if (foldShuffleOfIntrinsics (I))
48034995 return true ;
48044996 if (foldSelectShuffle (I))
0 commit comments