@@ -12613,11 +12613,13 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
1261312613 }
1261412614 InstructionCost createFreeze(InstructionCost Cost) { return Cost; }
1261512615 /// Finalize emission of the shuffles.
12616- InstructionCost
12617- finalize(ArrayRef<int> ExtMask,
12618- ArrayRef<std::pair<const TreeEntry *, unsigned>> SubVectors,
12619- ArrayRef<int> SubVectorsMask, unsigned VF = 0,
12620- function_ref<void(Value *&, SmallVectorImpl<int> &)> Action = {}) {
12616+ InstructionCost finalize(
12617+ ArrayRef<int> ExtMask,
12618+ ArrayRef<std::pair<const TreeEntry *, unsigned>> SubVectors,
12619+ ArrayRef<int> SubVectorsMask, unsigned VF = 0,
12620+ function_ref<void(Value *&, SmallVectorImpl<int> &,
12621+ function_ref<Value *(Value *, Value *, ArrayRef<int>)>)>
12622+ Action = {}) {
1262112623 IsFinalized = true;
1262212624 if (Action) {
1262312625 const PointerUnion<Value *, const TreeEntry *> &Vec = InVectors.front();
@@ -12629,7 +12631,10 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
1262912631 assert(VF > 0 &&
1263012632 "Expected vector length for the final value before action.");
1263112633 Value *V = cast<Value *>(Vec);
12632- Action(V, CommonMask);
12634+ Action(V, CommonMask, [this](Value *V1, Value *V2, ArrayRef<int> Mask) {
12635+ Cost += createShuffle(V1, V2, Mask);
12636+ return V1;
12637+ });
1263312638 InVectors.front() = V;
1263412639 }
1263512640 if (!SubVectors.empty()) {
@@ -16593,11 +16598,13 @@ class BoUpSLP::ShuffleInstructionBuilder final : public BaseShuffleAnalysis {
1659316598 /// Finalize emission of the shuffles.
1659416599 /// \param Action the action (if any) to be performed before the final
1659516600 /// application of the \p ExtMask mask.
16596- Value *
16597- finalize(ArrayRef<int> ExtMask,
16598- ArrayRef<std::pair<const TreeEntry *, unsigned>> SubVectors,
16599- ArrayRef<int> SubVectorsMask, unsigned VF = 0,
16600- function_ref<void(Value *&, SmallVectorImpl<int> &)> Action = {}) {
16601+ Value *finalize(
16602+ ArrayRef<int> ExtMask,
16603+ ArrayRef<std::pair<const TreeEntry *, unsigned>> SubVectors,
16604+ ArrayRef<int> SubVectorsMask, unsigned VF = 0,
16605+ function_ref<void(Value *&, SmallVectorImpl<int> &,
16606+ function_ref<Value *(Value *, Value *, ArrayRef<int>)>)>
16607+ Action = {}) {
1660116608 IsFinalized = true;
1660216609 if (Action) {
1660316610 Value *Vec = InVectors.front();
@@ -16616,7 +16623,9 @@ class BoUpSLP::ShuffleInstructionBuilder final : public BaseShuffleAnalysis {
1661616623 std::iota(ResizeMask.begin(), std::next(ResizeMask.begin(), VecVF), 0);
1661716624 Vec = createShuffle(Vec, nullptr, ResizeMask);
1661816625 }
16619- Action(Vec, CommonMask);
16626+ Action(Vec, CommonMask, [this](Value *V1, Value *V2, ArrayRef<int> Mask) {
16627+ return createShuffle(V1, V2, Mask);
16628+ });
1662016629 InVectors.front() = Vec;
1662116630 }
1662216631 if (!SubVectors.empty()) {
@@ -17278,9 +17287,67 @@ ResTy BoUpSLP::processBuildVector(const TreeEntry *E, Type *ScalarTy,
1727817287 else
1727917288 Res = ShuffleBuilder.finalize(
1728017289 E->ReuseShuffleIndices, SubVectors, SubVectorsMask, E->Scalars.size(),
17281- [&](Value *&Vec, SmallVectorImpl<int> &Mask) {
17282- TryPackScalars(NonConstants, Mask, /*IsRootPoison=*/false);
17283- Vec = ShuffleBuilder.gather(NonConstants, Mask.size(), Vec);
17290+ [&](Value *&Vec, SmallVectorImpl<int> &Mask, auto CreateShuffle) { // Action run by finalize(): folds NonConstants into Vec and updates Mask; CreateShuffle is the shuffle callback handed in by the caller of this action.
17291+ bool IsSplat = isSplat(NonConstants);
17292+ SmallVector<int> BVMask(Mask.size(), PoisonMaskElem); // Scratch mask, same length as Mask, initially all poison.
17293+ TryPackScalars(NonConstants, BVMask, /*IsRootPoison=*/false); // Presumably records in BVMask the lanes taken by the packed scalars (helper defined outside this view - confirm).
17294+ auto CheckIfSplatIsProfitable = [&]() { // Returns true when the splat form is estimated to cost no more than the direct insert form.
17295+ // Estimate the cost of splatting + shuffle and compare with
17296+ // insert + shuffle.
17297+ constexpr TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
17298+ Value *V = *find_if_not(NonConstants, IsaPred<UndefValue>); // First NonConstants element that is not an UndefValue. NOTE(review): dereferenced without an end() check - presumably a non-undef element always exists here; confirm.
17299+ if (isa<ExtractElementInst>(V) || isVectorized(V)) // Never choose the splat form for extractelement sources or scalars reported as vectorized.
17300+ return false;
17301+ InstructionCost SplatCost = TTI->getVectorInstrCost( // Splat side, part 1: insert V at index 0 of a poison vector.
17302+ Instruction::InsertElement, VecTy, CostKind, /*Index=*/0,
17303+ PoisonValue::get(VecTy), V);
17304+ SmallVector<int> NewMask(Mask.begin(), Mask.end());
17305+ for (auto [Idx, I] : enumerate(BVMask))
17306+ if (I != PoisonMaskElem)
17307+ NewMask[Idx] = Mask.size(); // Lanes set in BVMask all read index Mask.size(), i.e. the first lane past the first source in a two-source mask.
17308+ SplatCost += ::getShuffleCost(*TTI, TTI::SK_PermuteTwoSrc, VecTy, // Splat side, part 2: the two-source permute described by NewMask.
17309+ NewMask, CostKind);
17310+ InstructionCost BVCost = TTI->getVectorInstrCost( // Insert side, part 1: a single insertelement of V into Vec.
17311+ Instruction::InsertElement, VecTy, CostKind,
17312+ *find_if(Mask, [](int I) { return I != PoisonMaskElem; }), // NOTE(review): the index passed is the first non-poison VALUE stored in Mask (not its position) and is dereferenced without an end() check - confirm this is intended.
17313+ Vec, V);
17314+ // Shuffle required?
17315+ if (count(BVMask, PoisonMaskElem) <
17316+ static_cast<int>(BVMask.size() - 1)) { // True when BVMask has more than one non-poison lane.
17317+ SmallVector<int> NewMask(Mask.begin(), Mask.end());
17318+ for (auto [Idx, I] : enumerate(BVMask))
17319+ if (I != PoisonMaskElem)
17320+ NewMask[Idx] = I; // Overlay the BVMask lanes onto a copy of Mask.
17321+ BVCost += ::getShuffleCost(*TTI, TTI::SK_PermuteSingleSrc, // Insert side, part 2: the single-source permute described by the overlaid mask.
17322+ VecTy, NewMask, CostKind);
17323+ }
17324+ return SplatCost <= BVCost; // Ties go to the splat form.
17325+ };
17326+ if (!IsSplat || Mask.size() <= 2 || !CheckIfSplatIsProfitable()) { // Default path: not a splat, at most two lanes, or the splat form was estimated as more expensive.
17327+ for (auto [Idx, I] : enumerate(BVMask))
17328+ if (I != PoisonMaskElem)
17329+ Mask[Idx] = I; // Copy every lane set in BVMask into the caller-visible Mask.
17330+ Vec = ShuffleBuilder.gather(NonConstants, Mask.size(), Vec); // Vec is passed as the third argument - presumably the vector the scalars are inserted into; confirm against gather().
17331+ } else { // Splat path: build the repeated scalar once, broadcast it, then blend it with Vec.
17332+ Value *V = *find_if_not(NonConstants, IsaPred<UndefValue>); // Same scalar selection as in CheckIfSplatIsProfitable.
17333+ SmallVector<Value *> Values(NonConstants.size(),
17334+ PoisonValue::get(ScalarTy));
17335+ Values[0] = V; // All elements are poison except element 0, which holds V.
17336+ Value *BV = ShuffleBuilder.gather(Values, BVMask.size());
17337+ SmallVector<int> SplatMask(BVMask.size(), PoisonMaskElem);
17338+ transform(BVMask, SplatMask.begin(), [](int I) { // SplatMask reads lane 0 wherever BVMask is set and stays poison elsewhere.
17339+ return I == PoisonMaskElem ? PoisonMaskElem : 0;
17340+ });
17341+ if (!ShuffleVectorInst::isIdentityMask(SplatMask, VF)) // Skip the broadcast shuffle when SplatMask is already an identity mask for VF (VF is defined outside this view).
17342+ BV = CreateShuffle(BV, nullptr, SplatMask);
17343+ for (auto [Idx, I] : enumerate(BVMask))
17344+ if (I != PoisonMaskElem)
17345+ Mask[Idx] = BVMask.size() + Idx; // Lanes set in BVMask now select lane Idx of the second shuffle operand (BV).
17346+ Vec = CreateShuffle(Vec, BV, Mask); // Blend the incoming Vec with the broadcast vector.
17347+ for (auto [Idx, I] : enumerate(Mask))
17348+ if (I != PoisonMaskElem)
17349+ Mask[Idx] = Idx; // After the blend, each non-poison Mask entry is rewritten to its own position (identity on the used lanes).
17350+ }
1728417351 });
1728517352 } else if (!allConstant(GatheredScalars)) {
1728617353 // Gather unique scalars and all constants.
0 commit comments