Skip to content

Commit 1c241c1

Browse files
committed
VectorCombine: Fold chains of shuffles fed by length-changing shuffles
Such chains can arise from folding insert/extract chains. commit-id:a960175d
1 parent 2466adf commit 1c241c1

File tree

3 files changed

+200
-40
lines changed

3 files changed

+200
-40
lines changed

llvm/lib/Transforms/Vectorize/VectorCombine.cpp

Lines changed: 192 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@ class VectorCombine {
139139
bool foldShuffleOfSelects(Instruction &I);
140140
bool foldShuffleOfCastops(Instruction &I);
141141
bool foldShuffleOfShuffles(Instruction &I);
142+
bool foldShufflesOfLengthChangingShuffles(Instruction &I);
142143
bool foldShuffleOfIntrinsics(Instruction &I);
143144
bool foldShuffleToIdentity(Instruction &I);
144145
bool foldShuffleFromReductions(Instruction &I);
@@ -2877,6 +2878,195 @@ bool VectorCombine::foldShuffleOfShuffles(Instruction &I) {
28772878
return true;
28782879
}
28792880

2881+
/// Try to convert a chain of length-preserving shuffles that are fed by
2882+
/// length-changing shuffles from the same source, e.g. a chain of length 3:
2883+
///
2884+
/// "shuffle (shuffle (shuffle x, (shuffle y, undef)),
2885+
/// (shuffle y, undef)),
2886+
// (shuffle y, undef)"
2887+
///
2888+
/// into a single shuffle fed by a length-changing shuffle:
2889+
///
2890+
/// "shuffle x, (shuffle y, undef)"
2891+
///
2892+
/// Such chains arise e.g. from folding extract/insert sequences.
2893+
bool VectorCombine::foldShufflesOfLengthChangingShuffles(Instruction &I) {
2894+
FixedVectorType *TrunkType = dyn_cast<FixedVectorType>(I.getType());
2895+
if (!TrunkType)
2896+
return false;
2897+
2898+
unsigned ChainLength = 0;
2899+
SmallVector<int> Mask;
2900+
SmallVector<int> YMask;
2901+
InstructionCost OldCost = 0;
2902+
InstructionCost NewCost = 0;
2903+
Value *Trunk = &I;
2904+
unsigned NumTrunkElts = TrunkType->getNumElements();
2905+
Value *Y = nullptr;
2906+
2907+
for (;;) {
2908+
// Match the current trunk against (commutations of) the pattern
2909+
// "shuffle trunk', (shuffle y, undef)"
2910+
ArrayRef<int> OuterMask;
2911+
Value *OuterV0, *OuterV1;
2912+
if (ChainLength != 0 && !Trunk->hasOneUse())
2913+
break;
2914+
if (!match(Trunk, m_Shuffle(m_Value(OuterV0), m_Value(OuterV1),
2915+
m_Mask(OuterMask))))
2916+
break;
2917+
if (OuterV0->getType() != TrunkType) {
2918+
// This shuffle is not length-preserving, so it cannot be part of the
2919+
// chain.
2920+
break;
2921+
}
2922+
2923+
ArrayRef<int> InnerMask0, InnerMask1;
2924+
Value *A0, *A1, *B0, *B1;
2925+
bool Match0 =
2926+
match(OuterV0, m_Shuffle(m_Value(A0), m_Value(B0), m_Mask(InnerMask0)));
2927+
bool Match1 =
2928+
match(OuterV1, m_Shuffle(m_Value(A1), m_Value(B1), m_Mask(InnerMask1)));
2929+
bool Match0Leaf = Match0 && A0->getType() != I.getType();
2930+
bool Match1Leaf = Match1 && A1->getType() != I.getType();
2931+
if (Match0Leaf == Match1Leaf) {
2932+
// Only handle the case of exactly one leaf in each step. The "two leaves"
2933+
// case is handled by foldShuffleOfShuffles.
2934+
break;
2935+
}
2936+
2937+
SmallVector<int> CommutedOuterMask;
2938+
if (Match0Leaf) {
2939+
std::swap(OuterV0, OuterV1);
2940+
std::swap(InnerMask0, InnerMask1);
2941+
std::swap(A0, A1);
2942+
std::swap(B0, B1);
2943+
llvm::append_range(CommutedOuterMask, OuterMask);
2944+
for (int &M : CommutedOuterMask) {
2945+
if (M == PoisonMaskElem)
2946+
continue;
2947+
if (M < (int)NumTrunkElts)
2948+
M += NumTrunkElts;
2949+
else
2950+
M -= NumTrunkElts;
2951+
}
2952+
OuterMask = CommutedOuterMask;
2953+
}
2954+
if (!OuterV1->hasOneUse())
2955+
break;
2956+
2957+
if (!isa<UndefValue>(A1)) {
2958+
if (!Y)
2959+
Y = A1;
2960+
else if (Y != A1)
2961+
break;
2962+
}
2963+
if (!isa<UndefValue>(B1)) {
2964+
if (!Y)
2965+
Y = B1;
2966+
else if (Y != B1)
2967+
break;
2968+
}
2969+
2970+
auto *YType = cast<FixedVectorType>(A1->getType());
2971+
int NumLeafElts = YType->getNumElements();
2972+
SmallVector<int> LocalYMask(InnerMask1);
2973+
for (int &M : LocalYMask) {
2974+
if (M >= NumLeafElts)
2975+
M -= NumLeafElts;
2976+
}
2977+
2978+
InstructionCost LocalOldCost =
2979+
TTI.getInstructionCost(cast<User>(Trunk), CostKind) +
2980+
TTI.getInstructionCost(cast<User>(OuterV1), CostKind);
2981+
2982+
// Handle the initial (start of chain) case.
2983+
if (!ChainLength) {
2984+
Mask.assign(OuterMask);
2985+
YMask.assign(LocalYMask);
2986+
OldCost = NewCost = LocalOldCost;
2987+
Trunk = OuterV0;
2988+
ChainLength++;
2989+
continue;
2990+
}
2991+
2992+
// For the non-root case, first attempt to combine masks.
2993+
SmallVector<int> NewYMask(YMask);
2994+
bool Valid = true;
2995+
for (auto [CombinedM, LeafM] : llvm::zip(NewYMask, LocalYMask)) {
2996+
if (LeafM == -1 || CombinedM == LeafM)
2997+
continue;
2998+
if (CombinedM == -1) {
2999+
CombinedM = LeafM;
3000+
} else {
3001+
Valid = false;
3002+
break;
3003+
}
3004+
}
3005+
if (!Valid)
3006+
break;
3007+
3008+
SmallVector<int> NewMask;
3009+
NewMask.reserve(NumTrunkElts);
3010+
for (int M : Mask) {
3011+
if (M < 0 || M >= static_cast<int>(NumTrunkElts))
3012+
NewMask.push_back(M);
3013+
else
3014+
NewMask.push_back(OuterMask[M]);
3015+
}
3016+
3017+
// Break the chain if adding this new step complicates the shuffles such
3018+
// that it would increase the new cost by more than the old cost of this
3019+
// step.
3020+
InstructionCost LocalNewCost =
3021+
TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, TrunkType,
3022+
YType, NewYMask, CostKind) +
3023+
TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, TrunkType,
3024+
TrunkType, NewMask, CostKind);
3025+
3026+
if (LocalNewCost >= NewCost && LocalOldCost < LocalNewCost - NewCost)
3027+
break;
3028+
3029+
LLVM_DEBUG({
3030+
if (ChainLength == 1) {
3031+
dbgs() << "Found chain of shuffles fed by length-changing shuffles: "
3032+
<< I << '\n';
3033+
}
3034+
dbgs() << " next chain link: " << *Trunk << '\n'
3035+
<< " old cost: " << (OldCost + LocalOldCost)
3036+
<< " new cost: " << LocalNewCost << '\n';
3037+
});
3038+
3039+
Mask = NewMask;
3040+
YMask = NewYMask;
3041+
OldCost += LocalOldCost;
3042+
NewCost = LocalNewCost;
3043+
Trunk = OuterV0;
3044+
ChainLength++;
3045+
}
3046+
if (ChainLength <= 1)
3047+
return false;
3048+
3049+
if (llvm::all_of(Mask, [&](int M) {
3050+
return M < 0 || M >= static_cast<int>(NumTrunkElts);
3051+
})) {
3052+
// Produce a canonical simplified form if all elements are sourced from Y.
3053+
for (int &M : Mask) {
3054+
if (M >= static_cast<int>(NumTrunkElts))
3055+
M = YMask[M - NumTrunkElts];
3056+
}
3057+
Value *Root =
3058+
Builder.CreateShuffleVector(Y, PoisonValue::get(Y->getType()), Mask);
3059+
replaceValue(I, *Root);
3060+
return true;
3061+
}
3062+
3063+
Value *Leaf =
3064+
Builder.CreateShuffleVector(Y, PoisonValue::get(Y->getType()), YMask);
3065+
Value *Root = Builder.CreateShuffleVector(Trunk, Leaf, Mask);
3066+
replaceValue(I, *Root);
3067+
return true;
3068+
}
3069+
28803070
/// Try to convert
28813071
/// "shuffle (intrinsic), (intrinsic)" into "intrinsic (shuffle), (shuffle)".
28823072
bool VectorCombine::foldShuffleOfIntrinsics(Instruction &I) {
@@ -4718,6 +4908,8 @@ bool VectorCombine::run() {
47184908
return true;
47194909
if (foldShuffleOfShuffles(I))
47204910
return true;
4911+
if (foldShufflesOfLengthChangingShuffles(I))
4912+
return true;
47214913
if (foldShuffleOfIntrinsics(I))
47224914
return true;
47234915
if (foldSelectShuffle(I))

0 commit comments

Comments
 (0)