@@ -529,11 +529,17 @@ isFixedVectorShuffle(ArrayRef<Value *> VL, SmallVectorImpl<int> &Mask) {
529529 const auto *It = find_if(VL, IsaPred<ExtractElementInst>);
530530 if (It == VL.end())
531531 return std::nullopt;
532- auto *EI0 = cast<ExtractElementInst>(*It);
533- if (isa<ScalableVectorType>(EI0->getVectorOperandType()))
534- return std::nullopt;
535532 unsigned Size =
536- cast<FixedVectorType>(EI0->getVectorOperandType())->getNumElements();
533+ std::accumulate(VL.begin(), VL.end(), 0u, [](unsigned S, Value *V) {
534+ auto *EI = dyn_cast<ExtractElementInst>(V);
535+ if (!EI)
536+ return S;
537+ auto *VTy = dyn_cast<FixedVectorType>(EI->getVectorOperandType());
538+ if (!VTy)
539+ return S;
540+ return std::max(S, VTy->getNumElements());
541+ });
542+
537543 Value *Vec1 = nullptr;
538544 Value *Vec2 = nullptr;
539545 bool HasNonUndefVec = any_of(VL, [](Value *V) {
@@ -563,8 +569,6 @@ isFixedVectorShuffle(ArrayRef<Value *> VL, SmallVectorImpl<int> &Mask) {
563569 if (isa<UndefValue>(Vec)) {
564570 Mask[I] = I;
565571 } else {
566- if (cast<FixedVectorType>(Vec->getType())->getNumElements() != Size)
567- return std::nullopt;
568572 if (isa<UndefValue>(EI->getIndexOperand()))
569573 continue;
570574 auto *Idx = dyn_cast<ConstantInt>(EI->getIndexOperand());
@@ -10657,36 +10661,20 @@ BoUpSLP::tryToGatherSingleRegisterExtractElements(
1065710661 VectorOpToIdx[EI->getVectorOperand()].push_back(I);
1065810662 }
1065910663 // Sort the vector operands by the maximum number of uses in extractelements.
10660- MapVector<unsigned, SmallVector<Value *>> VFToVector;
10661- for (const auto &Data : VectorOpToIdx)
10662- VFToVector[cast<FixedVectorType>(Data.first->getType())->getNumElements()]
10663- .push_back(Data.first);
10664- for (auto &Data : VFToVector) {
10665- stable_sort(Data.second, [&VectorOpToIdx](Value *V1, Value *V2) {
10666- return VectorOpToIdx.find(V1)->second.size() >
10667- VectorOpToIdx.find(V2)->second.size();
10668- });
10669- }
10670- // Find the best pair of the vectors with the same number of elements or a
10671- // single vector.
10664+ SmallVector<std::pair<Value *, SmallVector<int>>> Vectors =
10665+ VectorOpToIdx.takeVector();
10666+ stable_sort(Vectors, [](const auto &P1, const auto &P2) {
10667+ return P1.second.size() > P2.second.size();
10668+ });
10669+ // Find the best pair of the vectors or a single vector.
1067210670 const int UndefSz = UndefVectorExtracts.size();
1067310671 unsigned SingleMax = 0;
10674- Value *SingleVec = nullptr;
1067510672 unsigned PairMax = 0;
10676- std::pair<Value *, Value *> PairVec(nullptr, nullptr);
10677- for (auto &Data : VFToVector) {
10678- Value *V1 = Data.second.front();
10679- if (SingleMax < VectorOpToIdx[V1].size() + UndefSz) {
10680- SingleMax = VectorOpToIdx[V1].size() + UndefSz;
10681- SingleVec = V1;
10682- }
10683- Value *V2 = nullptr;
10684- if (Data.second.size() > 1)
10685- V2 = *std::next(Data.second.begin());
10686- if (V2 && PairMax < VectorOpToIdx[V1].size() + VectorOpToIdx[V2].size() +
10687- UndefSz) {
10688- PairMax = VectorOpToIdx[V1].size() + VectorOpToIdx[V2].size() + UndefSz;
10689- PairVec = std::make_pair(V1, V2);
10673+ if (!Vectors.empty()) {
10674+ SingleMax = Vectors.front().second.size() + UndefSz;
10675+ if (Vectors.size() > 1) {
10676+ auto *ItNext = std::next(Vectors.begin());
10677+ PairMax = SingleMax + ItNext->second.size();
1069010678 }
1069110679 }
1069210680 if (SingleMax == 0 && PairMax == 0 && UndefSz == 0)
@@ -10697,11 +10685,11 @@ BoUpSLP::tryToGatherSingleRegisterExtractElements(
1069710685 SmallVector<Value *> GatheredExtracts(
1069810686 VL.size(), PoisonValue::get(VL.front()->getType()));
1069910687 if (SingleMax >= PairMax && SingleMax) {
10700- for (int Idx : VectorOpToIdx[SingleVec] )
10688+ for (int Idx : Vectors.front().second )
1070110689 std::swap(GatheredExtracts[Idx], VL[Idx]);
10702- } else {
10703- for (Value *V : {PairVec.first, PairVec.second })
10704- for (int Idx : VectorOpToIdx[V] )
10690+ } else if (!Vectors.empty()) {
10691+ for (unsigned Idx : {0, 1 })
10692+ for (int Idx : Vectors[Idx].second )
1070510693 std::swap(GatheredExtracts[Idx], VL[Idx]);
1070610694 }
1070710695 // Add extracts from undefs too.
@@ -11770,25 +11758,29 @@ class BoUpSLP::ShuffleInstructionBuilder final : public BaseShuffleAnalysis {
1177011758 MutableArrayRef<int> SubMask = Mask.slice(Part * SliceSize, Limit);
1177111759 constexpr int MaxBases = 2;
1177211760 SmallVector<Value *, MaxBases> Bases(MaxBases);
11773- #ifndef NDEBUG
11774- int PrevSize = 0;
11775- #endif // NDEBUG
11776- for (const auto [I, V]: enumerate(VL)) {
11777- if (SubMask[I] == PoisonMaskElem)
11761+ auto VLMask = zip(VL, SubMask);
11762+ const unsigned VF = std::accumulate(
11763+ VLMask.begin(), VLMask.end(), 0U, [&](unsigned S, const auto &D) {
11764+ if (std::get<1>(D) == PoisonMaskElem)
11765+ return S;
11766+ Value *VecOp =
11767+ cast<ExtractElementInst>(std::get<0>(D))->getVectorOperand();
11768+ if (const TreeEntry *TE = R.getTreeEntry(VecOp))
11769+ VecOp = TE->VectorizedValue;
11770+ assert(VecOp && "Expected vectorized value.");
11771+ const unsigned Size =
11772+ cast<FixedVectorType>(VecOp->getType())->getNumElements();
11773+ return std::max(S, Size);
11774+ });
11775+ for (const auto [V, I] : VLMask) {
11776+ if (I == PoisonMaskElem)
1177811777 continue;
1177911778 Value *VecOp = cast<ExtractElementInst>(V)->getVectorOperand();
1178011779 if (const TreeEntry *TE = R.getTreeEntry(VecOp))
1178111780 VecOp = TE->VectorizedValue;
1178211781 assert(VecOp && "Expected vectorized value.");
11783- const int Size =
11784- cast<FixedVectorType>(VecOp->getType())->getNumElements();
11785- #ifndef NDEBUG
11786- assert((PrevSize == Size || PrevSize == 0) &&
11787- "Expected vectors of the same size.");
11788- PrevSize = Size;
11789- #endif // NDEBUG
1179011782 VecOp = castToScalarTyElem(VecOp);
11791- Bases[SubMask[I] < Size ? 0 : 1 ] = VecOp;
11783+ Bases[I / VF ] = VecOp;
1179211784 }
1179311785 if (!Bases.front())
1179411786 continue;
@@ -11814,16 +11806,17 @@ class BoUpSLP::ShuffleInstructionBuilder final : public BaseShuffleAnalysis {
1181411806 "Expected first part or all previous parts masked.");
1181511807 copy(SubMask, std::next(VecMask.begin(), Part * SliceSize));
1181611808 } else {
11817- unsigned VF = cast<FixedVectorType>(Vec->getType())->getNumElements();
11809+ unsigned NewVF =
11810+ cast<FixedVectorType>(Vec->getType())->getNumElements();
1181811811 if (Vec->getType() != SubVec->getType()) {
1181911812 unsigned SubVecVF =
1182011813 cast<FixedVectorType>(SubVec->getType())->getNumElements();
11821- VF = std::max(VF , SubVecVF);
11814+ NewVF = std::max(NewVF , SubVecVF);
1182211815 }
1182311816 // Adjust SubMask.
1182411817 for (int &Idx : SubMask)
1182511818 if (Idx != PoisonMaskElem)
11826- Idx += VF ;
11819+ Idx += NewVF ;
1182711820 copy(SubMask, std::next(VecMask.begin(), Part * SliceSize));
1182811821 Vec = createShuffle(Vec, SubVec, VecMask);
1182911822 TransformToIdentity(VecMask);
0 commit comments