Skip to content

Commit 7b03d19

Browse files
committed
VectorCombine: Fold chains of shuffles fed by length-changing shuffles
Such chains can arise from folding insert/extract chains. commit-id:a960175d
1 parent 5c15f57 commit 7b03d19

File tree

2 files changed

+179
-33
lines changed

2 files changed

+179
-33
lines changed

llvm/lib/Transforms/Vectorize/VectorCombine.cpp

Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@ class VectorCombine {
139139
bool foldShuffleOfSelects(Instruction &I);
140140
bool foldShuffleOfCastops(Instruction &I);
141141
bool foldShuffleOfShuffles(Instruction &I);
142+
bool foldShufflesOfLengthChangingShuffles(Instruction &I);
142143
bool foldShuffleOfIntrinsics(Instruction &I);
143144
bool foldShuffleToIdentity(Instruction &I);
144145
bool foldShuffleFromReductions(Instruction &I);
@@ -2877,6 +2878,174 @@ bool VectorCombine::foldShuffleOfShuffles(Instruction &I) {
28772878
return true;
28782879
}
28792880

2881+
/// Try to convert a chain of length-preserving shuffles that are fed by
2882+
/// length-changing shuffles from the same source, e.g. a chain of length 3:
2883+
///
2884+
/// "shuffle (shuffle (shuffle x, (shuffle y, undef)),
2885+
/// (shuffle y, undef)),
2886+
// (shuffle y, undef)"
2887+
///
2888+
/// into a single shuffle fed by a length-changing shuffle:
2889+
///
2890+
/// "shuffle x, (shuffle y, undef)"
2891+
///
2892+
/// Such chains arise e.g. from folding extract/insert sequences.
2893+
bool VectorCombine::foldShufflesOfLengthChangingShuffles(Instruction &I) {
2894+
FixedVectorType *TrunkType = dyn_cast<FixedVectorType>(I.getType());
2895+
if (!TrunkType)
2896+
return false;
2897+
2898+
unsigned ChainLength = 0;
2899+
SmallVector<int> Mask;
2900+
SmallVector<int> YMask;
2901+
InstructionCost OldCost = 0;
2902+
InstructionCost NewCost = 0;
2903+
Value *Trunk = &I;
2904+
unsigned NumTrunkElts = TrunkType->getNumElements();
2905+
FixedVectorType *YType = nullptr;
2906+
Value *Y = nullptr;
2907+
2908+
for (;;) {
2909+
// Match the current trunk against (commutations of) the pattern
2910+
// "shuffle trunk', (shuffle y, undef)"
2911+
ArrayRef<int> OuterMask;
2912+
Value *OuterV0, *OuterV1;
2913+
if (ChainLength != 0 && !Trunk->hasOneUse())
2914+
break;
2915+
if (!match(Trunk, m_Shuffle(m_Value(OuterV0), m_Value(OuterV1),
2916+
m_Mask(OuterMask))))
2917+
break;
2918+
if (OuterV0->getType() != TrunkType) {
2919+
// This shuffle is not length-preserving, so it cannot be part of the
2920+
// chain.
2921+
break;
2922+
}
2923+
2924+
ArrayRef<int> InnerMask0, InnerMask1;
2925+
Value *A0, *A1, *B0, *B1;
2926+
bool Match0 =
2927+
match(OuterV0, m_Shuffle(m_Value(A0), m_Value(B0), m_Mask(InnerMask0)));
2928+
bool Match1 =
2929+
match(OuterV1, m_Shuffle(m_Value(A1), m_Value(B1), m_Mask(InnerMask1)));
2930+
bool Match0Leaf = Match0 && A0->getType() != I.getType();
2931+
bool Match1Leaf = Match1 && A1->getType() != I.getType();
2932+
if (Match0Leaf == Match1Leaf) {
2933+
// Only handle the case of exactly one leaf in each step. The "two leaves"
2934+
// case is handled by foldShuffleOfShuffles.
2935+
break;
2936+
}
2937+
2938+
SmallVector<int> CommutedOuterMask;
2939+
if (Match0Leaf) {
2940+
std::swap(OuterV0, OuterV1);
2941+
std::swap(InnerMask0, InnerMask1);
2942+
std::swap(A0, A1);
2943+
std::swap(B0, B1);
2944+
llvm::append_range(CommutedOuterMask, OuterMask);
2945+
for (int &M : CommutedOuterMask) {
2946+
if (M == PoisonMaskElem)
2947+
continue;
2948+
if (M < (int)NumTrunkElts)
2949+
M += NumTrunkElts;
2950+
else
2951+
M -= NumTrunkElts;
2952+
}
2953+
OuterMask = CommutedOuterMask;
2954+
}
2955+
if (!OuterV1->hasOneUse())
2956+
break;
2957+
2958+
if (!isa<UndefValue>(A1)) {
2959+
if (!Y)
2960+
Y = A1;
2961+
else if (Y != A1)
2962+
break;
2963+
}
2964+
if (!isa<UndefValue>(B1)) {
2965+
if (!Y)
2966+
Y = B1;
2967+
else if (Y != B1)
2968+
break;
2969+
}
2970+
2971+
InstructionCost LocalOldCost =
2972+
TTI.getInstructionCost(cast<User>(Trunk), CostKind) +
2973+
TTI.getInstructionCost(cast<User>(OuterV1), CostKind);
2974+
2975+
// Handle the initial (start of chain) case.
2976+
if (!ChainLength) {
2977+
YType = cast<FixedVectorType>(A1->getType());
2978+
Mask.assign(OuterMask);
2979+
YMask.assign(InnerMask1);
2980+
OldCost = NewCost = LocalOldCost;
2981+
Trunk = OuterV0;
2982+
ChainLength++;
2983+
continue;
2984+
}
2985+
2986+
// For the non-root case, first attempt to combine masks.
2987+
SmallVector<int> NewYMask(YMask);
2988+
bool Valid = true;
2989+
for (auto [CombinedM, LeafM] : llvm::zip(NewYMask, InnerMask1)) {
2990+
if (LeafM == -1 || CombinedM == LeafM)
2991+
continue;
2992+
if (CombinedM == -1) {
2993+
CombinedM = LeafM;
2994+
} else {
2995+
Valid = false;
2996+
break;
2997+
}
2998+
}
2999+
if (!Valid)
3000+
break;
3001+
3002+
SmallVector<int> NewMask;
3003+
NewMask.reserve(NumTrunkElts);
3004+
for (int M : Mask) {
3005+
if (M < 0 || M >= static_cast<int>(NumTrunkElts))
3006+
NewMask.push_back(M);
3007+
else
3008+
NewMask.push_back(OuterMask[M]);
3009+
}
3010+
3011+
// Break the chain if adding this new step complicates the shuffles such
3012+
// that it would increase the new cost by more than the old cost of this
3013+
// step.
3014+
InstructionCost LocalNewCost =
3015+
TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, TrunkType,
3016+
YType, NewYMask, CostKind) +
3017+
TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, TrunkType,
3018+
TrunkType, NewMask, CostKind);
3019+
3020+
if (LocalNewCost >= NewCost && LocalOldCost < LocalNewCost - NewCost)
3021+
break;
3022+
3023+
LLVM_DEBUG({
3024+
if (ChainLength == 1) {
3025+
dbgs() << "Found chain of shuffles fed by length-changing shuffles: "
3026+
<< I << '\n';
3027+
}
3028+
dbgs() << " next chain link: " << *Trunk << '\n'
3029+
<< " old cost: " << (OldCost + LocalOldCost)
3030+
<< " new cost: " << LocalNewCost << '\n';
3031+
});
3032+
3033+
Mask = NewMask;
3034+
YMask = NewYMask;
3035+
OldCost += LocalOldCost;
3036+
NewCost = LocalNewCost;
3037+
Trunk = OuterV0;
3038+
ChainLength++;
3039+
}
3040+
if (ChainLength <= 1)
3041+
return false;
3042+
3043+
Value *Leaf = Builder.CreateShuffleVector(Y, PoisonValue::get(YType), YMask);
3044+
Value *Root = Builder.CreateShuffleVector(Trunk, Leaf, Mask);
3045+
replaceValue(I, *Root);
3046+
return true;
3047+
}
3048+
28803049
/// Try to convert
28813050
/// "shuffle (intrinsic), (intrinsic)" into "intrinsic (shuffle), (shuffle)".
28823051
bool VectorCombine::foldShuffleOfIntrinsics(Instruction &I) {
@@ -4718,6 +4887,8 @@ bool VectorCombine::run() {
47184887
return true;
47194888
if (foldShuffleOfShuffles(I))
47204889
return true;
4890+
if (foldShufflesOfLengthChangingShuffles(I))
4891+
return true;
47214892
if (foldShuffleOfIntrinsics(I))
47224893
return true;
47234894
if (foldSelectShuffle(I))

0 commit comments

Comments
 (0)