Skip to content

Commit 316715e

Browse files
committed
VectorCombine: Fold chains of shuffles fed by length-changing shuffles
Such chains can arise from folding insert/extract chains. commit-id:a960175d
1 parent 1e44d4b commit 316715e

File tree

3 files changed

+200
-40
lines changed

3 files changed

+200
-40
lines changed

llvm/lib/Transforms/Vectorize/VectorCombine.cpp

Lines changed: 192 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ class VectorCombine {
140140
bool foldShuffleOfCastops(Instruction &I);
141141
bool foldShuffleOfShuffles(Instruction &I);
142142
bool foldPermuteOfIntrinsic(Instruction &I);
143+
bool foldShufflesOfLengthChangingShuffles(Instruction &I);
143144
bool foldShuffleOfIntrinsics(Instruction &I);
144145
bool foldShuffleToIdentity(Instruction &I);
145146
bool foldShuffleFromReductions(Instruction &I);
@@ -2878,6 +2879,195 @@ bool VectorCombine::foldShuffleOfShuffles(Instruction &I) {
28782879
return true;
28792880
}
28802881

2882+
/// Try to convert a chain of length-preserving shuffles that are fed by
2883+
/// length-changing shuffles from the same source, e.g. a chain of length 3:
2884+
///
2885+
/// "shuffle (shuffle (shuffle x, (shuffle y, undef)),
2886+
/// (shuffle y, undef)),
2887+
// (shuffle y, undef)"
2888+
///
2889+
/// into a single shuffle fed by a length-changing shuffle:
2890+
///
2891+
/// "shuffle x, (shuffle y, undef)"
2892+
///
2893+
/// Such chains arise e.g. from folding extract/insert sequences.
2894+
bool VectorCombine::foldShufflesOfLengthChangingShuffles(Instruction &I) {
2895+
FixedVectorType *TrunkType = dyn_cast<FixedVectorType>(I.getType());
2896+
if (!TrunkType)
2897+
return false;
2898+
2899+
unsigned ChainLength = 0;
2900+
SmallVector<int> Mask;
2901+
SmallVector<int> YMask;
2902+
InstructionCost OldCost = 0;
2903+
InstructionCost NewCost = 0;
2904+
Value *Trunk = &I;
2905+
unsigned NumTrunkElts = TrunkType->getNumElements();
2906+
Value *Y = nullptr;
2907+
2908+
for (;;) {
2909+
// Match the current trunk against (commutations of) the pattern
2910+
// "shuffle trunk', (shuffle y, undef)"
2911+
ArrayRef<int> OuterMask;
2912+
Value *OuterV0, *OuterV1;
2913+
if (ChainLength != 0 && !Trunk->hasOneUse())
2914+
break;
2915+
if (!match(Trunk, m_Shuffle(m_Value(OuterV0), m_Value(OuterV1),
2916+
m_Mask(OuterMask))))
2917+
break;
2918+
if (OuterV0->getType() != TrunkType) {
2919+
// This shuffle is not length-preserving, so it cannot be part of the
2920+
// chain.
2921+
break;
2922+
}
2923+
2924+
ArrayRef<int> InnerMask0, InnerMask1;
2925+
Value *A0, *A1, *B0, *B1;
2926+
bool Match0 =
2927+
match(OuterV0, m_Shuffle(m_Value(A0), m_Value(B0), m_Mask(InnerMask0)));
2928+
bool Match1 =
2929+
match(OuterV1, m_Shuffle(m_Value(A1), m_Value(B1), m_Mask(InnerMask1)));
2930+
bool Match0Leaf = Match0 && A0->getType() != I.getType();
2931+
bool Match1Leaf = Match1 && A1->getType() != I.getType();
2932+
if (Match0Leaf == Match1Leaf) {
2933+
// Only handle the case of exactly one leaf in each step. The "two leaves"
2934+
// case is handled by foldShuffleOfShuffles.
2935+
break;
2936+
}
2937+
2938+
SmallVector<int> CommutedOuterMask;
2939+
if (Match0Leaf) {
2940+
std::swap(OuterV0, OuterV1);
2941+
std::swap(InnerMask0, InnerMask1);
2942+
std::swap(A0, A1);
2943+
std::swap(B0, B1);
2944+
llvm::append_range(CommutedOuterMask, OuterMask);
2945+
for (int &M : CommutedOuterMask) {
2946+
if (M == PoisonMaskElem)
2947+
continue;
2948+
if (M < (int)NumTrunkElts)
2949+
M += NumTrunkElts;
2950+
else
2951+
M -= NumTrunkElts;
2952+
}
2953+
OuterMask = CommutedOuterMask;
2954+
}
2955+
if (!OuterV1->hasOneUse())
2956+
break;
2957+
2958+
if (!isa<UndefValue>(A1)) {
2959+
if (!Y)
2960+
Y = A1;
2961+
else if (Y != A1)
2962+
break;
2963+
}
2964+
if (!isa<UndefValue>(B1)) {
2965+
if (!Y)
2966+
Y = B1;
2967+
else if (Y != B1)
2968+
break;
2969+
}
2970+
2971+
auto *YType = cast<FixedVectorType>(A1->getType());
2972+
int NumLeafElts = YType->getNumElements();
2973+
SmallVector<int> LocalYMask(InnerMask1);
2974+
for (int &M : LocalYMask) {
2975+
if (M >= NumLeafElts)
2976+
M -= NumLeafElts;
2977+
}
2978+
2979+
InstructionCost LocalOldCost =
2980+
TTI.getInstructionCost(cast<User>(Trunk), CostKind) +
2981+
TTI.getInstructionCost(cast<User>(OuterV1), CostKind);
2982+
2983+
// Handle the initial (start of chain) case.
2984+
if (!ChainLength) {
2985+
Mask.assign(OuterMask);
2986+
YMask.assign(LocalYMask);
2987+
OldCost = NewCost = LocalOldCost;
2988+
Trunk = OuterV0;
2989+
ChainLength++;
2990+
continue;
2991+
}
2992+
2993+
// For the non-root case, first attempt to combine masks.
2994+
SmallVector<int> NewYMask(YMask);
2995+
bool Valid = true;
2996+
for (auto [CombinedM, LeafM] : llvm::zip(NewYMask, LocalYMask)) {
2997+
if (LeafM == -1 || CombinedM == LeafM)
2998+
continue;
2999+
if (CombinedM == -1) {
3000+
CombinedM = LeafM;
3001+
} else {
3002+
Valid = false;
3003+
break;
3004+
}
3005+
}
3006+
if (!Valid)
3007+
break;
3008+
3009+
SmallVector<int> NewMask;
3010+
NewMask.reserve(NumTrunkElts);
3011+
for (int M : Mask) {
3012+
if (M < 0 || M >= static_cast<int>(NumTrunkElts))
3013+
NewMask.push_back(M);
3014+
else
3015+
NewMask.push_back(OuterMask[M]);
3016+
}
3017+
3018+
// Break the chain if adding this new step complicates the shuffles such
3019+
// that it would increase the new cost by more than the old cost of this
3020+
// step.
3021+
InstructionCost LocalNewCost =
3022+
TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, TrunkType,
3023+
YType, NewYMask, CostKind) +
3024+
TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, TrunkType,
3025+
TrunkType, NewMask, CostKind);
3026+
3027+
if (LocalNewCost >= NewCost && LocalOldCost < LocalNewCost - NewCost)
3028+
break;
3029+
3030+
LLVM_DEBUG({
3031+
if (ChainLength == 1) {
3032+
dbgs() << "Found chain of shuffles fed by length-changing shuffles: "
3033+
<< I << '\n';
3034+
}
3035+
dbgs() << " next chain link: " << *Trunk << '\n'
3036+
<< " old cost: " << (OldCost + LocalOldCost)
3037+
<< " new cost: " << LocalNewCost << '\n';
3038+
});
3039+
3040+
Mask = NewMask;
3041+
YMask = NewYMask;
3042+
OldCost += LocalOldCost;
3043+
NewCost = LocalNewCost;
3044+
Trunk = OuterV0;
3045+
ChainLength++;
3046+
}
3047+
if (ChainLength <= 1)
3048+
return false;
3049+
3050+
if (llvm::all_of(Mask, [&](int M) {
3051+
return M < 0 || M >= static_cast<int>(NumTrunkElts);
3052+
})) {
3053+
// Produce a canonical simplified form if all elements are sourced from Y.
3054+
for (int &M : Mask) {
3055+
if (M >= static_cast<int>(NumTrunkElts))
3056+
M = YMask[M - NumTrunkElts];
3057+
}
3058+
Value *Root =
3059+
Builder.CreateShuffleVector(Y, PoisonValue::get(Y->getType()), Mask);
3060+
replaceValue(I, *Root);
3061+
return true;
3062+
}
3063+
3064+
Value *Leaf =
3065+
Builder.CreateShuffleVector(Y, PoisonValue::get(Y->getType()), YMask);
3066+
Value *Root = Builder.CreateShuffleVector(Trunk, Leaf, Mask);
3067+
replaceValue(I, *Root);
3068+
return true;
3069+
}
3070+
28813071
/// Try to convert
28823072
/// "shuffle (intrinsic), (intrinsic)" into "intrinsic (shuffle), (shuffle)".
28833073
bool VectorCombine::foldShuffleOfIntrinsics(Instruction &I) {
@@ -4799,6 +4989,8 @@ bool VectorCombine::run() {
47994989
return true;
48004990
if (foldPermuteOfIntrinsic(I))
48014991
return true;
4992+
if (foldShufflesOfLengthChangingShuffles(I))
4993+
return true;
48024994
if (foldShuffleOfIntrinsics(I))
48034995
return true;
48044996
if (foldSelectShuffle(I))

0 commit comments

Comments
 (0)