diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index 7198e134a2d26..54973a7749f3f 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -2535,7 +2535,19 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
 
   unsigned getNumberOfParts(Type *Tp) {
     std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Tp);
-    return LT.first.isValid() ? *LT.first.getValue() : 0;
+    if (!LT.first.isValid())
+      return 0;
+    // Try to find the actual number of parts for non-power-of-2 elements as
+    // ceil(num-of-elements/num-of-subtype-elements).
+    if (auto *FTp = dyn_cast<FixedVectorType>(Tp);
+        Tp && LT.second.isFixedLengthVector() &&
+        !has_single_bit(FTp->getNumElements())) {
+      if (auto *SubTp = dyn_cast_if_present<FixedVectorType>(
+              EVT(LT.second).getTypeForEVT(Tp->getContext()));
+          SubTp && SubTp->getElementType() == FTp->getElementType())
+        return divideCeil(FTp->getNumElements(), SubTp->getNumElements());
+    }
+    return *LT.first.getValue();
   }
 
   InstructionCost getAddressComputationCost(Type *Ty, ScalarEvolution *,
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 3695a8082531c..56c829337a1e0 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -260,6 +260,20 @@ static FixedVectorType *getWidenedType(Type *ScalarTy, unsigned VF) {
                               VF * getNumElements(ScalarTy));
 }
 
+/// Returns the number of elements of the given type \p Ty, not less than \p
+/// Sz, that forms a type which splits by \p TTI into whole vector types
+/// during legalization.
+static unsigned getFullVectorNumberOfElements(const TargetTransformInfo &TTI,
+                                              Type *Ty, unsigned Sz) {
+  if (!isValidElementType(Ty))
+    return bit_ceil(Sz);
+  // Find the number of elements which forms full vectors.
+  const unsigned NumParts = TTI.getNumberOfParts(getWidenedType(Ty, Sz));
+  if (NumParts == 0 || NumParts >= Sz)
+    return bit_ceil(Sz);
+  return bit_ceil(divideCeil(Sz, NumParts)) * NumParts;
+}
+
 static void transformScalarShuffleIndiciesToVector(unsigned VecTyNumElements,
                                                    SmallVectorImpl<int> &Mask) {
   // The ShuffleBuilder implementation use shufflevector to splat an "element".
@@ -394,7 +408,7 @@ static bool isVectorLikeInstWithConstOps(Value *V) {
 /// total number of elements \p Size and number of registers (parts) \p
 /// NumParts.
 static unsigned getPartNumElems(unsigned Size, unsigned NumParts) {
-  return PowerOf2Ceil(divideCeil(Size, NumParts));
+  return std::min<unsigned>(Size, bit_ceil(divideCeil(Size, NumParts)));
 }
 
 /// Returns correct remaining number of elements, considering total amount \p
@@ -1222,6 +1236,22 @@ static bool doesNotNeedToSchedule(ArrayRef<Value *> VL) {
          (all_of(VL, isUsedOutsideBlock) || all_of(VL, areAllOperandsNonInsts));
 }
 
+/// Returns true if the widened type of \p Ty elements with size \p Sz
+/// represents a full vector type, i.e. adding an extra element results in
+/// extra parts upon type legalization.
+static bool hasFullVectorsOrPowerOf2(const TargetTransformInfo &TTI, Type *Ty,
+                                     unsigned Sz) {
+  if (Sz <= 1)
+    return false;
+  if (!isValidElementType(Ty) && !isa<FixedVectorType>(Ty))
+    return false;
+  if (has_single_bit(Sz))
+    return true;
+  const unsigned NumParts = TTI.getNumberOfParts(getWidenedType(Ty, Sz));
+  return NumParts > 0 && NumParts < Sz && has_single_bit(Sz / NumParts) &&
+         Sz % NumParts == 0;
+}
+
 namespace slpvectorizer {
 
 /// Bottom Up SLP Vectorizer.
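The two helpers above carry the core arithmetic of this change: instead of rounding a non-power-of-2 element count up to the next power of 2, the count is padded only as far as the legalized parts require, so a 6-element group on a target that splits it into three 2-element registers needs no padding at all. Below is a minimal standalone sketch of that computation, assuming plain C++20; the local `divideCeil` stands in for `llvm::divideCeil`, and the driver values are made up for illustration.

```cpp
#include <bit>
#include <cstdio>

// Stand-in for llvm::divideCeil: ceil(Num / Den) in integer arithmetic.
static unsigned divideCeil(unsigned Num, unsigned Den) {
  return (Num + Den - 1) / Den;
}

// Mirrors getFullVectorNumberOfElements: pad Sz so that each of the
// NumParts legalized registers holds a power-of-2 number of elements.
static unsigned fullVectorNumberOfElements(unsigned Sz, unsigned NumParts) {
  if (NumParts == 0 || NumParts >= Sz)
    return std::bit_ceil(Sz); // scalarized or unknown: old power-of-2 padding
  return std::bit_ceil(divideCeil(Sz, NumParts)) * NumParts;
}

int main() {
  std::printf("%u\n", fullVectorNumberOfElements(6, 3)); // 6: three whole 2-element parts
  std::printf("%u\n", fullVectorNumberOfElements(6, 4)); // 8: pad to 2 elements per part
  std::printf("%u\n", fullVectorNumberOfElements(6, 6)); // 8: falls back to bit_ceil
}
```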
@@ -3311,6 +3341,15 @@ class BoUpSLP {
 
     /// Return true if this is a non-power-of-2 node.
     bool isNonPowOf2Vec() const {
       bool IsNonPowerOf2 = !has_single_bit(Scalars.size());
+      return IsNonPowerOf2;
+    }
+
+    /// Return true if this is a node which tries to vectorize a number of
+    /// elements forming whole vectors.
+    bool
+    hasNonWholeRegisterOrNonPowerOf2Vec(const TargetTransformInfo &TTI) const {
+      bool IsNonPowerOf2 = !hasFullVectorsOrPowerOf2(
+          TTI, getValueType(Scalars.front()), Scalars.size());
       assert((!IsNonPowerOf2 || ReuseShuffleIndices.empty()) &&
              "Reshuffling not supported with non-power-of-2 vectors yet.");
       return IsNonPowerOf2;
     }
@@ -3430,8 +3469,10 @@ class BoUpSLP {
     Last->State = EntryState;
     // FIXME: Remove once support for ReuseShuffleIndices has been implemented
     // for non-power-of-two vectors.
-    assert((has_single_bit(VL.size()) || ReuseShuffleIndices.empty()) &&
-           "Reshuffling scalars not yet supported for nodes with padding");
+    assert(
+        (hasFullVectorsOrPowerOf2(*TTI, getValueType(VL.front()), VL.size()) ||
+         ReuseShuffleIndices.empty()) &&
+        "Reshuffling scalars not yet supported for nodes with padding");
     Last->ReuseShuffleIndices.append(ReuseShuffleIndices.begin(),
                                      ReuseShuffleIndices.end());
     if (ReorderIndices.empty()) {
@@ -5269,7 +5310,7 @@ BoUpSLP::getReorderingData(const TreeEntry &TE, bool TopToBottom) {
   // node.
   if (!TE.ReuseShuffleIndices.empty()) {
     // FIXME: Support ReuseShuffleIndices for non-power-of-two vectors.
-    assert(!TE.isNonPowOf2Vec() &&
+    assert(!TE.hasNonWholeRegisterOrNonPowerOf2Vec(*TTI) &&
            "Reshuffling scalars not yet supported for nodes with padding");
 
     if (isSplat(TE.Scalars))
@@ -5509,7 +5550,7 @@ BoUpSLP::getReorderingData(const TreeEntry &TE, bool TopToBottom) {
     }
     // FIXME: Remove the non-power-of-two check once findReusedOrderedScalars
     // has been auditted for correctness with non-power-of-two vectors.
-    if (!TE.isNonPowOf2Vec())
+    if (!TE.hasNonWholeRegisterOrNonPowerOf2Vec(*TTI))
       if (std::optional<OrdersType> CurrentOrder = findReusedOrderedScalars(TE))
         return CurrentOrder;
   }
@@ -5662,8 +5703,8 @@ void BoUpSLP::reorderTopToBottom() {
   });
 
   // Reorder the graph nodes according to their vectorization factor.
-  for (unsigned VF = VectorizableTree.front()->getVectorFactor(); VF > 1;
-       VF = bit_ceil(VF) / 2) {
+  for (unsigned VF = VectorizableTree.front()->getVectorFactor();
+       !VFToOrderedEntries.empty() && VF > 1; VF -= 2 - (VF & 1U)) {
     auto It = VFToOrderedEntries.find(VF);
     if (It == VFToOrderedEntries.end())
       continue;
@@ -5671,6 +5712,9 @@ void BoUpSLP::reorderTopToBottom() {
     // used order and reorder scalar elements in the nodes according to this
     // mostly used order.
     ArrayRef<TreeEntry *> OrderedEntries = It->second.getArrayRef();
+    // Delete VF entry upon exit.
+    auto Cleanup = make_scope_exit([&]() { VFToOrderedEntries.erase(It); });
+
     // All operands are reordered and used only in this node - propagate the
     // most used order to the user node.
     MapVector<OrdersType, unsigned,
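Note the new loop bound in `reorderTopToBottom`: rather than halving a power-of-2 VF each round, the factor now steps down through every even value (an odd starting VF first drops by 1), so orders collected for non-power-of-2 factors are visited as well. A tiny sketch of the resulting iteration order, with 7 as a made-up starting factor:

```cpp
#include <cstdio>

int main() {
  // Same step expression as the loop above: odd VF drops by 1, even VF by 2.
  for (unsigned VF = 7; VF > 1; VF -= 2 - (VF & 1U))
    std::printf("%u ", VF); // prints: 7 6 4 2
  std::printf("\n");
}
```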
@@ … @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
       UniqueValues.emplace_back(V);
     }
     size_t NumUniqueScalarValues = UniqueValues.size();
-    if (NumUniqueScalarValues == VL.size()) {
+    bool IsFullVectors = hasFullVectorsOrPowerOf2(
+        *TTI, UniqueValues.front()->getType(), NumUniqueScalarValues);
+    if (NumUniqueScalarValues == VL.size() &&
+        (VectorizeNonPowerOf2 || IsFullVectors)) {
       ReuseShuffleIndices.clear();
     } else {
       // FIXME: Reshuffing scalars is not supported yet for non-power-of-2 ops.
-      if ((UserTreeIdx.UserTE && UserTreeIdx.UserTE->isNonPowOf2Vec()) ||
-          !llvm::has_single_bit(VL.size())) {
+      if ((UserTreeIdx.UserTE &&
+           UserTreeIdx.UserTE->hasNonWholeRegisterOrNonPowerOf2Vec(*TTI)) ||
+          !has_single_bit(VL.size())) {
         LLVM_DEBUG(dbgs() << "SLP: Reshuffling scalars not yet supported "
                              "for nodes with padding.\n");
         newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx);
         return false;
       }
       LLVM_DEBUG(dbgs() << "SLP: Shuffle for reused scalars.\n");
-      if (NumUniqueScalarValues <= 1 ||
-          (UniquePositions.size() == 1 &&
-           all_of(UniqueValues,
-                  [](Value *V) {
-                    return isa<UndefValue>(V) ||
-                           !isConstant(V);
-                  })) ||
-          !llvm::has_single_bit(NumUniqueScalarValues)) {
+      if (NumUniqueScalarValues <= 1 || !IsFullVectors ||
+          (UniquePositions.size() == 1 && all_of(UniqueValues, [](Value *V) {
+             return isa<UndefValue>(V) || !isConstant(V);
+           }))) {
         if (DoNotFail && UniquePositions.size() > 1 &&
             NumUniqueScalarValues > 1 && S.MainOp->isSafeToRemove() &&
             all_of(UniqueValues, [=](Value *V) {
@@ -7555,7 +7600,9 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
                    areAllUsersVectorized(cast<Instruction>(V),
                                          UserIgnoreList);
           })) {
-        unsigned PWSz = PowerOf2Ceil(UniqueValues.size());
+        // Find the number of elements which forms full vectors.
+        unsigned PWSz = getFullVectorNumberOfElements(
+            *TTI, UniqueValues.front()->getType(), UniqueValues.size());
         if (PWSz == VL.size()) {
           ReuseShuffleIndices.clear();
         } else {
@@ -9793,9 +9840,6 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
       return nullptr;
     Value *VecBase = nullptr;
     ArrayRef<Value *> VL = E->Scalars;
-    // If the resulting type is scalarized, do not adjust the cost.
-    if (NumParts == VL.size())
-      return nullptr;
     // Check if it can be considered reused if same extractelements were
     // vectorized already.
     bool PrevNodeFound = any_of(
@@ -10449,7 +10493,7 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
         InsertMask[Idx] = I + 1;
       }
       unsigned VecScalarsSz = PowerOf2Ceil(NumElts);
-      if (NumOfParts > 0)
+      if (NumOfParts > 0 && NumOfParts < NumElts)
         VecScalarsSz = PowerOf2Ceil((NumElts + NumOfParts - 1) / NumOfParts);
       unsigned VecSz = (1 + OffsetEnd / VecScalarsSz - OffsetBeg / VecScalarsSz) *
                        VecScalarsSz;
@@ -17778,7 +17822,7 @@ bool SLPVectorizerPass::tryToVectorizeList(ArrayRef<Value *> VL, BoUpSLP &R,
   for (unsigned I = NextInst; I < MaxInst; ++I) {
     unsigned ActualVF = std::min(MaxInst - I, VF);
 
-    if (!has_single_bit(ActualVF))
+    if (!hasFullVectorsOrPowerOf2(*TTI, ScalarTy, ActualVF))
       continue;
 
     if (MaxVFOnly && ActualVF < MaxVF)
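The `tryToVectorizeList` change above is what lets whole non-power-of-2 groups through: an `ActualVF` is now acceptable if it either is a power of 2 or legalizes into whole power-of-2 register parts. Below is a standalone sketch of that predicate, with the part count passed in directly rather than queried from TTI (a hypothetical simplification; in the pass the count comes from `TTI.getNumberOfParts`), using made-up driver values:

```cpp
#include <bit>
#include <cstdio>

// Mirrors hasFullVectorsOrPowerOf2, with NumParts supplied by the caller
// instead of TTI.getNumberOfParts.
static bool fullVectorsOrPowerOf2(unsigned Sz, unsigned NumParts) {
  if (Sz <= 1)
    return false;
  if (std::has_single_bit(Sz))
    return true;
  // A non-power-of-2 size is fine if it splits evenly into parts that each
  // hold a power-of-2 number of elements.
  return NumParts > 0 && NumParts < Sz && std::has_single_bit(Sz / NumParts) &&
         Sz % NumParts == 0;
}

int main() {
  std::printf("%d\n", fullVectorsOrPowerOf2(6, 3)); // 1: three whole 2-element parts
  std::printf("%d\n", fullVectorsOrPowerOf2(6, 4)); // 0: 6 does not split evenly
  std::printf("%d\n", fullVectorsOrPowerOf2(8, 2)); // 1: power of 2 anyway
}
```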
diff --git a/llvm/test/Transforms/SLPVectorizer/reduction-whole-regs-loads.ll b/llvm/test/Transforms/SLPVectorizer/reduction-whole-regs-loads.ll
index 281b5f99540ea..4074b8654362e 100644
--- a/llvm/test/Transforms/SLPVectorizer/reduction-whole-regs-loads.ll
+++ b/llvm/test/Transforms/SLPVectorizer/reduction-whole-regs-loads.ll
@@ -1,21 +1,29 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
-; RUN: opt < %s -passes=slp-vectorizer -S -mtriple=riscv64-unknown-linux -mattr=+v -slp-threshold=-100 | FileCheck %s
+; RUN: opt < %s -passes=slp-vectorizer -S -mtriple=riscv64-unknown-linux -mattr=+v -slp-threshold=-100 | FileCheck %s --check-prefix=RISCV
 ; RUN: opt < %s -passes=slp-vectorizer -S -mtriple=x86_64-unknown-linux -slp-threshold=-100 | FileCheck %s
 ; RUN: opt < %s -passes=slp-vectorizer -S -mtriple=aarch64-unknown-linux -slp-threshold=-100 | FileCheck %s
 ; REQUIRES: aarch64-registered-target, x86-registered-target, riscv-registered-target
 
 define i64 @test(ptr %p) {
+; RISCV-LABEL: @test(
+; RISCV-NEXT:  entry:
+; RISCV-NEXT:    [[ARRAYIDX_4:%.*]] = getelementptr inbounds i64, ptr [[P:%.*]], i64 4
+; RISCV-NEXT:    [[TMP0:%.*]] = load <4 x i64>, ptr [[P]], align 4
+; RISCV-NEXT:    [[TMP1:%.*]] = load <2 x i64>, ptr [[ARRAYIDX_4]], align 4
+; RISCV-NEXT:    [[TMP2:%.*]] = shufflevector <4 x i64> [[TMP0]], <4 x i64> poison, <8 x i32>
+; RISCV-NEXT:    [[TMP3:%.*]] = call <8 x i64> @llvm.vector.insert.v8i64.v4i64(<8 x i64> [[TMP2]], <4 x i64> [[TMP0]], i64 0)
+; RISCV-NEXT:    [[TMP4:%.*]] = call <8 x i64> @llvm.vector.insert.v8i64.v2i64(<8 x i64> [[TMP3]], <2 x i64> [[TMP1]], i64 4)
+; RISCV-NEXT:    [[TMP5:%.*]] = mul <8 x i64> [[TMP4]],
+; RISCV-NEXT:    [[TMP6:%.*]] = call i64 @llvm.vector.reduce.add.v8i64(<8 x i64> [[TMP5]])
+; RISCV-NEXT:    ret i64 [[TMP6]]
+;
 ; CHECK-LABEL: @test(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[ARRAYIDX_4:%.*]] = getelementptr inbounds i64, ptr [[P:%.*]], i64 4
-; CHECK-NEXT:    [[TMP0:%.*]] = load <4 x i64>, ptr [[P]], align 4
-; CHECK-NEXT:    [[TMP1:%.*]] = load <2 x i64>, ptr [[ARRAYIDX_4]], align 4
-; CHECK-NEXT:    [[TMP2:%.*]] = shufflevector <4 x i64> [[TMP0]], <4 x i64> poison, <8 x i32>
-; CHECK-NEXT:    [[TMP3:%.*]] = call <8 x i64> @llvm.vector.insert.v8i64.v4i64(<8 x i64> [[TMP2]], <4 x i64> [[TMP0]], i64 0)
-; CHECK-NEXT:    [[TMP4:%.*]] = call <8 x i64> @llvm.vector.insert.v8i64.v2i64(<8 x i64> [[TMP3]], <2 x i64> [[TMP1]], i64 4)
-; CHECK-NEXT:    [[TMP5:%.*]] = mul <8 x i64> [[TMP4]],
-; CHECK-NEXT:    [[TMP6:%.*]] = call i64 @llvm.vector.reduce.add.v8i64(<8 x i64> [[TMP5]])
-; CHECK-NEXT:    ret i64 [[TMP6]]
+; CHECK-NEXT:    [[TMP0:%.*]] = load <6 x i64>, ptr [[P:%.*]], align 4
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <6 x i64> [[TMP0]], <6 x i64> poison, <8 x i32>
+; CHECK-NEXT:    [[TMP2:%.*]] = mul <8 x i64> [[TMP1]],
+; CHECK-NEXT:    [[TMP3:%.*]] = call i64 @llvm.vector.reduce.add.v8i64(<8 x i64> [[TMP2]])
+; CHECK-NEXT:    ret i64 [[TMP3]]
 ;
 entry:
   %arrayidx.1 = getelementptr inbounds i64, ptr %p, i64 1