@@ -12613,11 +12613,13 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
1261312613 }
1261412614 InstructionCost createFreeze(InstructionCost Cost) { return Cost; }
1261512615 /// Finalize emission of the shuffles.
12616- InstructionCost
12617- finalize(ArrayRef<int> ExtMask,
12618- ArrayRef<std::pair<const TreeEntry *, unsigned>> SubVectors,
12619- ArrayRef<int> SubVectorsMask, unsigned VF = 0,
12620- function_ref<void(Value *&, SmallVectorImpl<int> &)> Action = {}) {
12616+ InstructionCost finalize(
12617+ ArrayRef<int> ExtMask,
12618+ ArrayRef<std::pair<const TreeEntry *, unsigned>> SubVectors,
12619+ ArrayRef<int> SubVectorsMask, unsigned VF = 0,
12620+ function_ref<void(Value *&, SmallVectorImpl<int> &,
12621+ function_ref<Value *(Value *, Value *, ArrayRef<int>)>)>
12622+ Action = {}) {
1262112623 IsFinalized = true;
1262212624 if (Action) {
1262312625 const PointerUnion<Value *, const TreeEntry *> &Vec = InVectors.front();
@@ -12629,7 +12631,10 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
1262912631 assert(VF > 0 &&
1263012632 "Expected vector length for the final value before action.");
1263112633 Value *V = cast<Value *>(Vec);
12632- Action(V, CommonMask);
12634+ Action(V, CommonMask, [this](Value *V1, Value *V2, ArrayRef<int> Mask) {
12635+ Cost += createShuffle(V1, V2, Mask);
12636+ return V1;
12637+ });
1263312638 InVectors.front() = V;
1263412639 }
1263512640 if (!SubVectors.empty()) {
@@ -16592,11 +16597,13 @@ class BoUpSLP::ShuffleInstructionBuilder final : public BaseShuffleAnalysis {
1659216597 /// Finalize emission of the shuffles.
1659316598 /// \param Action the action (if any) to be performed before final applying of
1659416599 /// the \p ExtMask mask.
16595- Value *
16596- finalize(ArrayRef<int> ExtMask,
16597- ArrayRef<std::pair<const TreeEntry *, unsigned>> SubVectors,
16598- ArrayRef<int> SubVectorsMask, unsigned VF = 0,
16599- function_ref<void(Value *&, SmallVectorImpl<int> &)> Action = {}) {
16600+ Value *finalize(
16601+ ArrayRef<int> ExtMask,
16602+ ArrayRef<std::pair<const TreeEntry *, unsigned>> SubVectors,
16603+ ArrayRef<int> SubVectorsMask, unsigned VF = 0,
16604+ function_ref<void(Value *&, SmallVectorImpl<int> &,
16605+ function_ref<Value *(Value *, Value *, ArrayRef<int>)>)>
16606+ Action = {}) {
1660016607 IsFinalized = true;
1660116608 if (Action) {
1660216609 Value *Vec = InVectors.front();
@@ -16615,7 +16622,9 @@ class BoUpSLP::ShuffleInstructionBuilder final : public BaseShuffleAnalysis {
1661516622 std::iota(ResizeMask.begin(), std::next(ResizeMask.begin(), VecVF), 0);
1661616623 Vec = createShuffle(Vec, nullptr, ResizeMask);
1661716624 }
16618- Action(Vec, CommonMask);
16625+ Action(Vec, CommonMask, [this](Value *V1, Value *V2, ArrayRef<int> Mask) {
16626+ return createShuffle(V1, V2, Mask);
16627+ });
1661916628 InVectors.front() = Vec;
1662016629 }
1662116630 if (!SubVectors.empty()) {
@@ -17277,9 +17286,66 @@ ResTy BoUpSLP::processBuildVector(const TreeEntry *E, Type *ScalarTy,
1727717286 else
1727817287 Res = ShuffleBuilder.finalize(
1727917288 E->ReuseShuffleIndices, SubVectors, SubVectorsMask, E->Scalars.size(),
17280- [&](Value *&Vec, SmallVectorImpl<int> &Mask) {
17281- TryPackScalars(NonConstants, Mask, /*IsRootPoison=*/false);
17282- Vec = ShuffleBuilder.gather(NonConstants, Mask.size(), Vec);
17289+ [&](Value *&Vec, SmallVectorImpl<int> &Mask, auto CreateShuffle) {
17290+ bool IsSplat = isSplat(NonConstants);
17291+ SmallVector<int> BVMask(Mask.size(), PoisonMaskElem);
17292+ TryPackScalars(NonConstants, BVMask, /*IsRootPoison=*/false);
17293+ auto CheckIfSplatIsProfitable = [&]() {
17294+ // Estimate the cost of splatting + shuffle and compare with
17295+ // insert + shuffle.
17296+ constexpr TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
17297+ Value *V = *find_if_not(NonConstants, IsaPred<UndefValue>);
17298+ if (isa<ExtractElementInst>(V) || isVectorized(V))
17299+ return false;
17300+ InstructionCost SplatCost = TTI->getVectorInstrCost(
17301+ Instruction::InsertElement, VecTy, CostKind, /*Index=*/0,
17302+ PoisonValue::get(VecTy), V);
17303+ SmallVector<int> NewMask(Mask.begin(), Mask.end());
17304+ for (auto [Idx, I] : enumerate(BVMask))
17305+ if (I != PoisonMaskElem)
17306+ NewMask[Idx] = Mask.size();
17307+ SplatCost += ::getShuffleCost(*TTI, TTI::SK_PermuteTwoSrc, VecTy,
17308+ NewMask, CostKind);
17309+ InstructionCost BVCost = TTI->getVectorInstrCost(
17310+ Instruction::InsertElement, VecTy, CostKind,
17311+ *find_if(Mask, [](int I) { return I != PoisonMaskElem; }),
17312+ Vec, V);
17313+ // Shuffle required?
17314+ if (count(BVMask, PoisonMaskElem) <
17315+ static_cast<int>(BVMask.size() - 1)) {
17316+ SmallVector<int> NewMask(Mask.begin(), Mask.end());
17317+ for (auto [Idx, I] : enumerate(BVMask))
17318+ if (I != PoisonMaskElem)
17319+ NewMask[Idx] = I;
17320+ BVCost += ::getShuffleCost(*TTI, TTI::SK_PermuteSingleSrc,
17321+ VecTy, NewMask, CostKind);
17322+ }
17323+ return SplatCost <= BVCost;
17324+ };
17325+ if (!IsSplat || Mask.size() <= 2 || !CheckIfSplatIsProfitable()) {
17326+ for (auto [Idx, I] : enumerate(BVMask))
17327+ if (I != PoisonMaskElem)
17328+ Mask[Idx] = I;
17329+ Vec = ShuffleBuilder.gather(NonConstants, Mask.size(), Vec);
17330+ } else {
17331+ Value *V = *find_if_not(NonConstants, IsaPred<UndefValue>);
17332+ SmallVector<Value *> Values(NonConstants.size(), PoisonValue::get(ScalarTy));
17333+ Values[0] = V;
17334+ Value *BV = ShuffleBuilder.gather(Values, BVMask.size());
17335+ SmallVector<int> SplatMask(BVMask.size(), PoisonMaskElem);
17336+ transform(BVMask, SplatMask.begin(), [](int I) {
17337+ return I == PoisonMaskElem ? PoisonMaskElem : 0;
17338+ });
17339+ if (!ShuffleVectorInst::isIdentityMask(SplatMask, VF))
17340+ BV = CreateShuffle(BV, nullptr, SplatMask);
17341+ for (auto [Idx, I] : enumerate(BVMask))
17342+ if (I != PoisonMaskElem)
17343+ Mask[Idx] = BVMask.size() + Idx;
17344+ Vec = CreateShuffle(Vec, BV, Mask);
17345+ for (auto [Idx, I] : enumerate(Mask))
17346+ if (I != PoisonMaskElem)
17347+ Mask[Idx] = Idx;
17348+ }
1728317349 });
1728417350 } else if (!allConstant(GatheredScalars)) {
1728517351 // Gather unique scalars and all constants.
0 commit comments