Skip to content

Commit aaee5f6

Browse files
committed
VectorCombine: Fold chains of shuffles fed by length-changing shuffles
Such chains can arise from folding insert/extract chains. commit-id:a960175d
1 parent aa6362b commit aaee5f6

File tree

2 files changed

+176
-33
lines changed

2 files changed

+176
-33
lines changed

llvm/lib/Transforms/Vectorize/VectorCombine.cpp

Lines changed: 168 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -139,6 +139,7 @@ class VectorCombine {
139139
bool foldShuffleOfSelects(Instruction &I);
140140
bool foldShuffleOfCastops(Instruction &I);
141141
bool foldShuffleOfShuffles(Instruction &I);
142+
bool foldShufflesOfLengthChangingShuffles(Instruction &I);
142143
bool foldShuffleOfIntrinsics(Instruction &I);
143144
bool foldShuffleToIdentity(Instruction &I);
144145
bool foldShuffleFromReductions(Instruction &I);
@@ -2877,6 +2878,171 @@ bool VectorCombine::foldShuffleOfShuffles(Instruction &I) {
28772878
return true;
28782879
}
28792880

2881+
/// Try to convert a chain of length-preserving shuffles that are fed by
2882+
/// length-changing shuffles from the same source, e.g. a chain of length 3:
2883+
///
2884+
/// "shuffle (shuffle (shuffle x, (shuffle y, undef)),
2885+
/// (shuffle y, undef)),
2886+
// (shuffle y, undef)"
2887+
///
2888+
/// into a single shuffle fed by a length-changing shuffle:
2889+
///
2890+
/// "shuffle x, (shuffle y, undef)"
2891+
///
2892+
/// Such chains arise e.g. from folding extract/insert sequences.
2893+
bool VectorCombine::foldShufflesOfLengthChangingShuffles(Instruction &I) {
2894+
unsigned ChainLength = 0;
2895+
SmallVector<int> Mask;
2896+
SmallVector<int> YMask;
2897+
InstructionCost OldCost = 0;
2898+
InstructionCost NewCost = 0;
2899+
FixedVectorType *TrunkType = cast<FixedVectorType>(I.getType());
2900+
Value *Trunk = &I;
2901+
unsigned NumTrunkElts = TrunkType->getNumElements();
2902+
FixedVectorType *YType = nullptr;
2903+
Value *Y = nullptr;
2904+
2905+
for (;;) {
2906+
// Match the current trunk against (commutations of) the pattern
2907+
// "shuffle trunk', (shuffle y, undef)"
2908+
ArrayRef<int> OuterMask;
2909+
Value *OuterV0, *OuterV1;
2910+
if (ChainLength != 0 && !Trunk->hasOneUse())
2911+
break;
2912+
if (!match(Trunk, m_Shuffle(m_Value(OuterV0), m_Value(OuterV1),
2913+
m_Mask(OuterMask))))
2914+
break;
2915+
if (OuterV0->getType() != TrunkType) {
2916+
// This shuffle is not length-preserving, so it cannot be part of the
2917+
// chain.
2918+
break;
2919+
}
2920+
2921+
ArrayRef<int> InnerMask0, InnerMask1;
2922+
Value *A0, *A1, *B0, *B1;
2923+
bool Match0 =
2924+
match(OuterV0, m_Shuffle(m_Value(A0), m_Value(B0), m_Mask(InnerMask0)));
2925+
bool Match1 =
2926+
match(OuterV1, m_Shuffle(m_Value(A1), m_Value(B1), m_Mask(InnerMask1)));
2927+
bool Match0Leaf = Match0 && A0->getType() != I.getType();
2928+
bool Match1Leaf = Match1 && A1->getType() != I.getType();
2929+
if (Match0Leaf == Match1Leaf) {
2930+
// Only handle the case of exactly one leaf in each step. The "two leaves"
2931+
// case is handled by foldShuffleOfShuffles.
2932+
break;
2933+
}
2934+
2935+
SmallVector<int> CommutedOuterMask;
2936+
if (Match0Leaf) {
2937+
std::swap(OuterV0, OuterV1);
2938+
std::swap(InnerMask0, InnerMask1);
2939+
std::swap(A0, A1);
2940+
std::swap(B0, B1);
2941+
llvm::append_range(CommutedOuterMask, OuterMask);
2942+
for (int &M : CommutedOuterMask) {
2943+
if (M == PoisonMaskElem)
2944+
continue;
2945+
if (M < (int)NumTrunkElts)
2946+
M += NumTrunkElts;
2947+
else
2948+
M -= NumTrunkElts;
2949+
}
2950+
OuterMask = CommutedOuterMask;
2951+
}
2952+
if (!OuterV1->hasOneUse())
2953+
break;
2954+
2955+
if (!isa<PoisonValue>(A1)) {
2956+
if (!Y)
2957+
Y = A1;
2958+
else if (Y != A1)
2959+
break;
2960+
}
2961+
if (!isa<PoisonValue>(B1)) {
2962+
if (!Y)
2963+
Y = B1;
2964+
else if (Y != B1)
2965+
break;
2966+
}
2967+
2968+
InstructionCost LocalOldCost =
2969+
TTI.getInstructionCost(cast<User>(Trunk), CostKind) +
2970+
TTI.getInstructionCost(cast<User>(OuterV1), CostKind);
2971+
2972+
// Handle the initial (start of chain) case.
2973+
if (!ChainLength) {
2974+
YType = cast<FixedVectorType>(A1->getType());
2975+
Mask.assign(OuterMask);
2976+
YMask.assign(InnerMask1);
2977+
OldCost = NewCost = LocalOldCost;
2978+
Trunk = OuterV0;
2979+
ChainLength++;
2980+
continue;
2981+
}
2982+
2983+
// For the non-root case, first attempt to combine masks.
2984+
SmallVector<int> NewYMask(YMask);
2985+
bool Valid = true;
2986+
for (auto [CombinedM, LeafM] : llvm::zip(NewYMask, InnerMask1)) {
2987+
if (LeafM == -1 || CombinedM == LeafM)
2988+
continue;
2989+
if (CombinedM == -1) {
2990+
CombinedM = LeafM;
2991+
} else {
2992+
Valid = false;
2993+
break;
2994+
}
2995+
}
2996+
if (!Valid)
2997+
break;
2998+
2999+
SmallVector<int> NewMask;
3000+
NewMask.reserve(NumTrunkElts);
3001+
for (int M : Mask) {
3002+
if (M < 0 || M >= (int)NumTrunkElts)
3003+
NewMask.push_back(M);
3004+
else
3005+
NewMask.push_back(OuterMask[M]);
3006+
}
3007+
3008+
// Break the chain if adding this new step complicates the shuffles such
3009+
// that it would increase the new cost by more than the old cost of this
3010+
// step.
3011+
InstructionCost LocalNewCost = 0;
3012+
LocalNewCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc,
3013+
TrunkType, YType, NewYMask, CostKind);
3014+
LocalNewCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc,
3015+
TrunkType, TrunkType, NewMask, CostKind);
3016+
3017+
if (LocalNewCost >= NewCost && LocalOldCost < LocalNewCost - NewCost)
3018+
break;
3019+
3020+
LLVM_DEBUG({
3021+
if (ChainLength == 1) {
3022+
dbgs() << "Found chain of shuffles fed by length-changing shuffles: "
3023+
<< I << '\n';
3024+
}
3025+
dbgs() << " next chain link: " << *Trunk << '\n'
3026+
<< " old cost: " << (OldCost + LocalOldCost)
3027+
<< " new cost: " << LocalNewCost << '\n';
3028+
});
3029+
3030+
Mask = NewMask;
3031+
YMask = NewYMask;
3032+
OldCost += LocalOldCost;
3033+
NewCost = LocalNewCost;
3034+
Trunk = OuterV0;
3035+
ChainLength++;
3036+
}
3037+
if (ChainLength <= 1)
3038+
return false;
3039+
3040+
Value *Leaf = Builder.CreateShuffleVector(Y, PoisonValue::get(YType), YMask);
3041+
Value *Root = Builder.CreateShuffleVector(Trunk, Leaf, Mask);
3042+
replaceValue(I, *Root);
3043+
return true;
3044+
}
3045+
28803046
/// Try to convert
28813047
/// "shuffle (intrinsic), (intrinsic)" into "intrinsic (shuffle), (shuffle)".
28823048
bool VectorCombine::foldShuffleOfIntrinsics(Instruction &I) {
@@ -4718,6 +4884,8 @@ bool VectorCombine::run() {
47184884
return true;
47194885
if (foldShuffleOfShuffles(I))
47204886
return true;
4887+
if (foldShufflesOfLengthChangingShuffles(I))
4888+
return true;
47214889
if (foldShuffleOfIntrinsics(I))
47224890
return true;
47234891
if (foldSelectShuffle(I))

0 commit comments

Comments
 (0)