@@ -7351,6 +7351,32 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
73517351 V2 = getAllOnesValue(
73527352 *R.DL,
73537353 FixedVectorType::get(E2->Scalars.front()->getType(), CommonVF));
7354+ } else if (!V1 && V2) {
7355+ // Shuffle vector and tree node.
7356+ unsigned VF = cast<FixedVectorType>(V2->getType())->getNumElements();
7357+ const TreeEntry *E1 = P1.get<const TreeEntry *>();
7358+ CommonVF = std::max(VF, E1->getVectorFactor());
7359+ assert(all_of(Mask,
7360+ [=](int Idx) {
7361+ return Idx < 2 * static_cast<int>(CommonVF);
7362+ }) &&
7363+ "All elements in mask must be less than 2 * CommonVF.");
7364+ if (E1->Scalars.size() == VF && VF != CommonVF) {
7365+ SmallVector<int> E1Mask = E1->getCommonMask();
7366+ assert(!E1Mask.empty() && "Expected non-empty common mask.");
7367+ for (int &Idx : CommonMask) {
7368+ if (Idx == PoisonMaskElem)
7369+ continue;
7370+ if (Idx >= static_cast<int>(CommonVF))
7371+ Idx = E1Mask[Idx - CommonVF] + VF;
7372+ }
7373+ CommonVF = VF;
7374+ }
7375+ V1 = Constant::getNullValue(
7376+ FixedVectorType::get(E1->Scalars.front()->getType(), CommonVF));
7377+ V2 = getAllOnesValue(
7378+ *R.DL,
7379+ FixedVectorType::get(E1->Scalars.front()->getType(), CommonVF));
73547380 } else {
73557381 assert(V1 && V2 && "Expected both vectors.");
73567382 unsigned VF = cast<FixedVectorType>(V1->getType())->getNumElements();
@@ -7387,7 +7413,8 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
73877413 R(R), CheckedExtracts(CheckedExtracts) {}
73887414 Value *adjustExtracts(const TreeEntry *E, MutableArrayRef<int> Mask,
73897415 ArrayRef<std::optional<TTI::ShuffleKind>> ShuffleKinds,
7390- unsigned NumParts) {
7416+ unsigned NumParts, bool &UseVecBaseAsInput) {
7417+ UseVecBaseAsInput = false;
73917418 if (Mask.empty())
73927419 return nullptr;
73937420 Value *VecBase = nullptr;
@@ -7410,6 +7437,7 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
74107437 Data.value() == VL[Data.index()]);
74117438 });
74127439 });
7440+ SmallPtrSet<Value *, 4> UniqueBases;
74137441 unsigned SliceSize = VL.size() / NumParts;
74147442 for (unsigned Part = 0; Part < NumParts; ++Part) {
74157443 ArrayRef<int> SubMask = Mask.slice(Part * SliceSize, SliceSize);
@@ -7424,13 +7452,14 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
74247452 // vectorized tree.
74257453 // Also, avoid adjusting the cost for extractelements with multiple uses
74267454 // in different graph entries.
7455+ auto *EE = cast<ExtractElementInst>(V);
7456+ VecBase = EE->getVectorOperand();
7457+ UniqueBases.insert(VecBase);
74277458 const TreeEntry *VE = R.getTreeEntry(V);
74287459 if (!CheckedExtracts.insert(V).second ||
74297460 !R.areAllUsersVectorized(cast<Instruction>(V), &VectorizedVals) ||
74307461 (VE && VE != E))
74317462 continue;
7432- auto *EE = cast<ExtractElementInst>(V);
7433- VecBase = EE->getVectorOperand();
74347463 std::optional<unsigned> EEIdx = getExtractIndex(EE);
74357464 if (!EEIdx)
74367465 continue;
@@ -7469,6 +7498,11 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
74697498 CommonMask.assign(Mask.begin(), Mask.end());
74707499 transformMaskAfterShuffle(CommonMask, CommonMask);
74717500 SameNodesEstimated = false;
7501+ if (NumParts != 1 && UniqueBases.size() != 1) {
7502+ UseVecBaseAsInput = true;
7503+ VecBase = Constant::getNullValue(
7504+ FixedVectorType::get(VL.front()->getType(), CommonMask.size()));
7505+ }
74727506 return VecBase;
74737507 }
74747508 void add(const TreeEntry &E1, const TreeEntry &E2, ArrayRef<int> Mask) {
@@ -7518,19 +7552,70 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
75187552 if (!SameNodesEstimated && InVectors.size() == 1)
75197553 InVectors.emplace_back(&E1);
75207554 }
7555+ /// Adds 2 input vectors and the mask for their shuffling.
7556+ void add(Value *V1, Value *V2, ArrayRef<int> Mask) {
7557+ // May come only for shuffling of 2 vectors with extractelements, already
7558+ // handled in adjustExtracts.
7559+ assert(InVectors.size() == 1 &&
7560+ all_of(enumerate(CommonMask),
7561+ [&](auto P) {
7562+ if (P.value() == PoisonMaskElem)
7563+ return Mask[P.index()] == PoisonMaskElem;
7564+ auto *EI =
7565+ cast<ExtractElementInst>(InVectors.front()
7566+ .get<const TreeEntry *>()
7567+ ->Scalars[P.index()]);
7568+ return EI->getVectorOperand() == V1 ||
7569+ EI->getVectorOperand() == V2;
7570+ }) &&
7571+ "Expected extractelement vectors.");
7572+ }
75217573 /// Adds another one input vector and the mask for the shuffling.
7522- void add(Value *V1, ArrayRef<int> Mask) {
7574+ void add(Value *V1, ArrayRef<int> Mask, bool ForExtracts = false ) {
75237575 if (InVectors.empty()) {
7524- assert(CommonMask.empty() && "Expected empty input mask/vectors.");
7576+ assert(CommonMask.empty() && !ForExtracts &&
7577+ "Expected empty input mask/vectors.");
75257578 CommonMask.assign(Mask.begin(), Mask.end());
75267579 InVectors.assign(1, V1);
75277580 return;
75287581 }
7529- assert(InVectors.size() == 1 && InVectors.front().is<const TreeEntry *>() &&
7530- !CommonMask.empty() && "Expected only single entry from extracts.");
7582+ if (ForExtracts) {
7583+ // No need to add vectors here, already handled them in adjustExtracts.
7584+ assert(InVectors.size() == 1 &&
7585+ InVectors.front().is<const TreeEntry *>() && !CommonMask.empty() &&
7586+ all_of(enumerate(CommonMask),
7587+ [&](auto P) {
7588+ Value *Scalar = InVectors.front()
7589+ .get<const TreeEntry *>()
7590+ ->Scalars[P.index()];
7591+ if (P.value() == PoisonMaskElem)
7592+ return P.value() == Mask[P.index()] ||
7593+ isa<UndefValue>(Scalar);
7594+ if (isa<Constant>(V1))
7595+ return true;
7596+ auto *EI = cast<ExtractElementInst>(Scalar);
7597+ return EI->getVectorOperand() == V1;
7598+ }) &&
7599+ "Expected only tree entry for extractelement vectors.");
7600+ return;
7601+ }
7602+ assert(!InVectors.empty() && !CommonMask.empty() &&
7603+ "Expected only tree entries from extracts/reused buildvectors.");
7604+ unsigned VF = cast<FixedVectorType>(V1->getType())->getNumElements();
7605+ if (InVectors.size() == 2) {
7606+ Cost += createShuffle(InVectors.front(), InVectors.back(), CommonMask);
7607+ transformMaskAfterShuffle(CommonMask, CommonMask);
7608+ VF = std::max<unsigned>(VF, CommonMask.size());
7609+ } else if (const auto *InTE =
7610+ InVectors.front().dyn_cast<const TreeEntry *>()) {
7611+ VF = std::max(VF, InTE->getVectorFactor());
7612+ } else {
7613+ VF = std::max(
7614+ VF, cast<FixedVectorType>(InVectors.front().get<Value *>()->getType())
7615+ ->getNumElements());
7616+ }
75317617 InVectors.push_back(V1);
7532- unsigned VF = CommonMask.size();
7533- for (unsigned Idx = 0; Idx < VF; ++Idx)
7618+ for (unsigned Idx = 0, Sz = CommonMask.size(); Idx < Sz; ++Idx)
75347619 if (Mask[Idx] != PoisonMaskElem && CommonMask[Idx] == PoisonMaskElem)
75357620 CommonMask[Idx] = Mask[Idx] + VF;
75367621 }
@@ -7666,6 +7751,8 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
76667751 reorderScalars(GatheredScalars, ReorderMask);
76677752 SmallVector<int> Mask;
76687753 SmallVector<int> ExtractMask;
7754+ Value *ExtractVecBase = nullptr;
7755+ bool UseVecBaseAsInput = false;
76697756 SmallVector<std::optional<TargetTransformInfo::ShuffleKind>> GatherShuffles;
76707757 SmallVector<SmallVector<const TreeEntry *>> Entries;
76717758 SmallVector<std::optional<TTI::ShuffleKind>> ExtractShuffles;
@@ -7679,7 +7766,7 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
76797766 tryToGatherExtractElements(GatheredScalars, ExtractMask, NumParts);
76807767 if (!ExtractShuffles.empty()) {
76817768 if (Value *VecBase = Estimator.adjustExtracts(
7682- E, ExtractMask, ExtractShuffles, NumParts)) {
7769+ E, ExtractMask, ExtractShuffles, NumParts, UseVecBaseAsInput )) {
76837770 if (auto *VecBaseTy = dyn_cast<FixedVectorType>(VecBase->getType()))
76847771 if (VF == VecBaseTy->getNumElements() &&
76857772 GatheredScalars.size() != VF) {
@@ -7774,6 +7861,48 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
77747861 ScalarTy, GatheredScalars.size())));
77757862 });
77767863 }
7864+ if (!ExtractShuffles.empty()) {
7865+ Value *Vec1 = nullptr;
7866+ // Gather of extractelements can be represented as just a shuffle of
7867+ // a single/two vectors the scalars are extracted from.
7868+ // Find input vectors.
7869+ Value *Vec2 = nullptr;
7870+ for (unsigned I = 0, Sz = ExtractMask.size(); I < Sz; ++I) {
7871+ if (!Mask.empty() && Mask[I] != PoisonMaskElem)
7872+ ExtractMask[I] = PoisonMaskElem;
7873+ }
7874+ if (UseVecBaseAsInput) {
7875+ Vec1 = ExtractVecBase;
7876+ } else {
7877+ for (unsigned I = 0, Sz = ExtractMask.size(); I < Sz; ++I) {
7878+ if (ExtractMask[I] == PoisonMaskElem)
7879+ continue;
7880+ if (isa<UndefValue>(E->Scalars[I]))
7881+ continue;
7882+ auto *EI = cast<ExtractElementInst>(E->Scalars[I]);
7883+ Value *VecOp = EI->getVectorOperand();
7884+ if (const auto *TE = getTreeEntry(VecOp))
7885+ if (TE->VectorizedValue)
7886+ VecOp = TE->VectorizedValue;
7887+ if (!Vec1) {
7888+ Vec1 = VecOp;
7889+ } else if (Vec1 != EI->getVectorOperand()) {
7890+ assert((!Vec2 || Vec2 == EI->getVectorOperand()) &&
7891+ "Expected only 1 or 2 vectors shuffle.");
7892+ Vec2 = VecOp;
7893+ }
7894+ }
7895+ }
7896+ if (Vec2) {
7897+ Estimator.add(Vec1, Vec2, ExtractMask);
7898+ } else if (Vec1) {
7899+ Estimator.add(Vec1, ExtractMask, /*ForExtracts=*/true);
7900+ } else {
7901+ Estimator.add(PoisonValue::get(FixedVectorType::get(
7902+ ScalarTy, GatheredScalars.size())),
7903+ ExtractMask, /*ForExtracts=*/true);
7904+ }
7905+ }
77777906 if (!all_of(GatheredScalars, PoisonValue::classof)) {
77787907 auto Gathers = ArrayRef(GatheredScalars).take_front(VL.size());
77797908 bool SameGathers = VL.equals(Gathers);
@@ -10367,7 +10496,7 @@ class BoUpSLP::ShuffleInstructionBuilder final : public BaseShuffleAnalysis {
1036710496 InVectors.push_back(V1);
1036810497 }
1036910498 /// Adds another one input vector and the mask for the shuffling.
10370- void add(Value *V1, ArrayRef<int> Mask) {
10499+ void add(Value *V1, ArrayRef<int> Mask, bool = false ) {
1037110500 if (InVectors.empty()) {
1037210501 if (!isa<FixedVectorType>(V1->getType())) {
1037310502 V1 = createShuffle(V1, nullptr, CommonMask);
@@ -10906,13 +11035,13 @@ ResTy BoUpSLP::processBuildVector(const TreeEntry *E, Args &...Params) {
1090611035 IsUsedInExpr &= FindReusedSplat(
1090711036 ExtractMask,
1090811037 cast<FixedVectorType>(Vec1->getType())->getNumElements());
10909- ShuffleBuilder.add(Vec1, ExtractMask);
11038+ ShuffleBuilder.add(Vec1, ExtractMask, /*ForExtracts=*/true );
1091011039 IsNonPoisoned &= isGuaranteedNotToBePoison(Vec1);
1091111040 } else {
1091211041 IsUsedInExpr = false;
1091311042 ShuffleBuilder.add(PoisonValue::get(FixedVectorType::get(
1091411043 ScalarTy, GatheredScalars.size())),
10915- ExtractMask);
11044+ ExtractMask, /*ForExtracts=*/true );
1091611045 }
1091711046 }
1091811047 if (!GatherShuffles.empty()) {
0 commit comments