diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 553df1c08f3ae..94de520a2715f 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -278,6 +278,22 @@ static unsigned getFullVectorNumberOfElements(const TargetTransformInfo &TTI,
   return bit_ceil(divideCeil(Sz, NumParts)) * NumParts;
 }
 
+/// Returns the number of elements of the given type \p Ty, not greater than
+/// \p Sz, that forms a type which is split by \p TTI into whole vector types
+/// during legalization.
+static unsigned
+getFloorFullVectorNumberOfElements(const TargetTransformInfo &TTI, Type *Ty,
+                                   unsigned Sz) {
+  if (!isValidElementType(Ty))
+    return bit_floor(Sz);
+  // Find the number of elements that forms full vectors.
+  unsigned NumParts = TTI.getNumberOfParts(getWidenedType(Ty, Sz));
+  if (NumParts == 0 || NumParts >= Sz)
+    return bit_floor(Sz);
+  unsigned RegVF = bit_ceil(divideCeil(Sz, NumParts));
+  return (Sz / RegVF) * RegVF;
+}
+
 static void transformScalarShuffleIndiciesToVector(unsigned VecTyNumElements,
                                                    SmallVectorImpl<int> &Mask) {
   // The ShuffleBuilder implementation use shufflevector to splat an "element".
@@ -7716,7 +7732,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
     }
     size_t NumUniqueScalarValues = UniqueValues.size();
     bool IsFullVectors = hasFullVectorsOrPowerOf2(
-        *TTI, UniqueValues.front()->getType(), NumUniqueScalarValues);
+        *TTI, getValueType(UniqueValues.front()), NumUniqueScalarValues);
     if (NumUniqueScalarValues == VL.size() &&
         (VectorizeNonPowerOf2 || IsFullVectors)) {
       ReuseShuffleIndices.clear();
@@ -17466,7 +17482,11 @@ SLPVectorizerPass::vectorizeStoreChain(ArrayRef<Value *> Chain, BoUpSLP &R,
   const unsigned Sz = R.getVectorElementSize(Chain[0]);
   unsigned VF = Chain.size();
 
-  if (!has_single_bit(Sz) || !has_single_bit(VF) || VF < 2 || VF < MinVF) {
+  if (!has_single_bit(Sz) ||
+      !hasFullVectorsOrPowerOf2(
+          *TTI, cast<StoreInst>(Chain.front())->getValueOperand()->getType(),
+          VF) ||
+      VF < 2 || VF < MinVF) {
     // Check if vectorizing with a non-power-of-2 VF should be considered. At
     // the moment, only consider cases where VF + 1 is a power-of-2, i.e. almost
     // all vector lanes are used.
@@ -17484,10 +17504,12 @@ SLPVectorizerPass::vectorizeStoreChain(ArrayRef<Value *> Chain, BoUpSLP &R,
     InstructionsState S = getSameOpcode(ValOps.getArrayRef(), *TLI);
     if (all_of(ValOps, IsaPred<Instruction>) && ValOps.size() > 1) {
       DenseSet<Value *> Stores(Chain.begin(), Chain.end());
-      bool IsPowerOf2 =
-          has_single_bit(ValOps.size()) ||
+      bool IsAllowedSize =
+          hasFullVectorsOrPowerOf2(*TTI, ValOps.front()->getType(),
+                                   ValOps.size()) ||
           (VectorizeNonPowerOf2 && has_single_bit(ValOps.size() + 1));
-      if ((!IsPowerOf2 && S.getOpcode() && S.getOpcode() != Instruction::Load &&
+      if ((!IsAllowedSize && S.getOpcode() &&
+           S.getOpcode() != Instruction::Load &&
           (!S.MainOp->isSafeToRemove() ||
            any_of(ValOps.getArrayRef(),
                   [&](Value *V) {
@@ -17498,7 +17520,7 @@ SLPVectorizerPass::vectorizeStoreChain(ArrayRef<Value *> Chain, BoUpSLP &R,
                             }));
                   }))) ||
          (ValOps.size() > Chain.size() / 2 && !S.getOpcode())) {
-        Size = (!IsPowerOf2 && S.getOpcode()) ? 1 : 2;
+        Size = (!IsAllowedSize && S.getOpcode()) ? 1 : 2;
         return false;
       }
     }
@@ -17626,15 +17648,11 @@ bool SLPVectorizerPass::vectorizeStores(
       unsigned MaxVF =
           std::min(R.getMaximumVF(EltSize, Instruction::Store), MaxElts);
-      unsigned MaxRegVF = MaxVF;
       auto *Store = cast<StoreInst>(Operands[0]);
       Type *StoreTy = Store->getValueOperand()->getType();
       Type *ValueTy = StoreTy;
       if (auto *Trunc = dyn_cast<TruncInst>(Store->getValueOperand()))
         ValueTy = Trunc->getSrcTy();
-      if (ValueTy == StoreTy &&
-          R.getVectorElementSize(Store->getValueOperand()) <= EltSize)
-        MaxVF = std::min(MaxVF, bit_floor(Operands.size()));
       unsigned MinVF = std::max<unsigned>(
           2, PowerOf2Ceil(TTI->getStoreMinimumVF(
                  R.getMinVF(DL->getTypeStoreSizeInBits(StoreTy)), StoreTy,
@@ -17652,10 +17670,21 @@ bool SLPVectorizerPass::vectorizeStores(
         // First try vectorizing with a non-power-of-2 VF. At the moment, only
         // consider cases where VF + 1 is a power-of-2, i.e. almost all vector
         // lanes are used.
-        unsigned CandVF =
-            std::clamp<unsigned>(Operands.size(), MaxVF, MaxRegVF);
-        if (has_single_bit(CandVF + 1))
+        unsigned CandVF = std::clamp<unsigned>(Operands.size(), MinVF, MaxVF);
+        if (has_single_bit(CandVF + 1)) {
           NonPowerOf2VF = CandVF;
+          assert(NonPowerOf2VF != MaxVF &&
+                 "Non-power-of-2 VF should not be equal to MaxVF");
+        }
+      }
+
+      unsigned MaxRegVF = MaxVF;
+      MaxVF = std::min<unsigned>(MaxVF, bit_floor(Operands.size()));
+      if (MaxVF < MinVF) {
+        LLVM_DEBUG(dbgs() << "SLP: Vectorization infeasible as MaxVF (" << MaxVF
+                          << ") < "
+                          << "MinVF (" << MinVF << ")\n");
+        continue;
       }
 
       unsigned Sz = 1 + Log2_32(MaxVF) - Log2_32(MinVF);
@@ -17810,7 +17839,7 @@ bool SLPVectorizerPass::vectorizeStores(
                               std::bind(IsNotVectorized, Size >= MaxRegVF,
                                         std::placeholders::_1)));
         }
-        if (!AnyProfitableGraph && Size >= MaxRegVF)
+        if (!AnyProfitableGraph && Size >= MaxRegVF && has_single_bit(Size))
           break;
       }
       // All values vectorized - exit.
@@ -17823,7 +17852,7 @@ bool SLPVectorizerPass::vectorizeStores(
           (Repeat > 1 && (RepeatChanged || !AnyProfitableGraph)))
         break;
       constexpr unsigned StoresLimit = 64;
-      const unsigned MaxTotalNum = bit_floor(std::min<unsigned>(
+      const unsigned MaxTotalNum = std::min<unsigned>(
           Operands.size(),
           static_cast<unsigned>(
               End -
@@ -17831,8 +17860,13 @@ bool SLPVectorizerPass::vectorizeStores(
                   RangeSizes.begin(),
                   find_if(RangeSizes, std::bind(IsNotVectorized, true,
                                                 std::placeholders::_1))) +
-              1)));
-      unsigned VF = PowerOf2Ceil(CandidateVFs.front()) * 2;
+              1));
+      unsigned VF = bit_ceil(CandidateVFs.front()) * 2;
+      unsigned Limit =
+          getFloorFullVectorNumberOfElements(*TTI, StoreTy, MaxTotalNum);
+      CandidateVFs.clear();
+      if (bit_floor(Limit) == VF)
+        CandidateVFs.push_back(Limit);
       if (VF > MaxTotalNum || VF >= StoresLimit)
         break;
       for_each(RangeSizes, [&](std::pair<unsigned, unsigned> &P) {
@@ -17841,7 +17875,6 @@ bool SLPVectorizerPass::vectorizeStores(
       });
       // Last attempt to vectorize max number of elements, if all previous
       // attempts were unsuccessful because of the cost issues.
-      CandidateVFs.clear();
       CandidateVFs.push_back(VF);
     }
   }
diff --git a/llvm/test/Transforms/SLPVectorizer/X86/long-full-reg-stores.ll b/llvm/test/Transforms/SLPVectorizer/X86/long-full-reg-stores.ll
index b5f993f986c7c..aff66dd7c10ea 100644
--- a/llvm/test/Transforms/SLPVectorizer/X86/long-full-reg-stores.ll
+++ b/llvm/test/Transforms/SLPVectorizer/X86/long-full-reg-stores.ll
@@ -4,30 +4,16 @@
 define void @test(ptr noalias %0, ptr noalias %1) {
 ; CHECK-LABEL: define void @test(
 ; CHECK-SAME: ptr noalias [[TMP0:%.*]], ptr noalias [[TMP1:%.*]]) {
-; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr i8, ptr [[TMP1]], i64 24
-; CHECK-NEXT:    [[TMP4:%.*]] = getelementptr i8, ptr [[TMP1]], i64 48
 ; CHECK-NEXT:    [[TMP5:%.*]] = getelementptr i8, ptr [[TMP1]], i64 8
-; CHECK-NEXT:    [[TMP6:%.*]] = getelementptr i8, ptr [[TMP1]], i64 16
-; CHECK-NEXT:    [[TMP7:%.*]] = getelementptr i8, ptr [[TMP0]], i64 24
-; CHECK-NEXT:    [[TMP8:%.*]] = load double, ptr [[TMP7]], align 8
-; CHECK-NEXT:    store double [[TMP8]], ptr [[TMP5]], align 8
 ; CHECK-NEXT:    [[TMP9:%.*]] = getelementptr i8, ptr [[TMP0]], i64 48
-; CHECK-NEXT:    [[TMP10:%.*]] = load double, ptr [[TMP9]], align 16
-; CHECK-NEXT:    store double [[TMP10]], ptr [[TMP6]], align 16
 ; CHECK-NEXT:    [[TMP11:%.*]] = getelementptr i8, ptr [[TMP0]], i64 8
-; CHECK-NEXT:    [[TMP12:%.*]] = load double, ptr [[TMP11]], align 8
-; CHECK-NEXT:    store double [[TMP12]], ptr [[TMP3]], align 8
-; CHECK-NEXT:    [[TMP13:%.*]] = getelementptr i8, ptr [[TMP0]], i64 32
-; CHECK-NEXT:    [[TMP14:%.*]] = load double, ptr [[TMP13]], align 16
-; CHECK-NEXT:    [[TMP15:%.*]] = getelementptr i8, ptr [[TMP1]], i64 32
-; CHECK-NEXT:    store double [[TMP14]], ptr [[TMP15]], align 16
-; CHECK-NEXT:    [[TMP16:%.*]] = getelementptr i8, ptr [[TMP0]], i64 56
-; CHECK-NEXT:    [[TMP17:%.*]] = load double, ptr [[TMP16]], align 8
-; CHECK-NEXT:    [[TMP18:%.*]] = getelementptr i8, ptr [[TMP1]], i64 40
-; CHECK-NEXT:    store double [[TMP17]], ptr [[TMP18]], align 8
-; CHECK-NEXT:    [[TMP19:%.*]] = getelementptr i8, ptr [[TMP0]], i64 16
-; CHECK-NEXT:    [[TMP20:%.*]] = load double, ptr [[TMP19]], align 16
-; CHECK-NEXT:    store double [[TMP20]], ptr [[TMP4]], align 16
+; CHECK-NEXT:    [[TMP6:%.*]] = load <2 x double>, ptr [[TMP9]], align 16
+; CHECK-NEXT:    [[TMP7:%.*]] = load <4 x double>, ptr [[TMP11]], align 8
+; CHECK-NEXT:    [[TMP8:%.*]] = shufflevector <4 x double> [[TMP7]], <4 x double> poison, <6 x i32>
+; CHECK-NEXT:    [[TMP12:%.*]] = shufflevector <2 x double> [[TMP6]], <2 x double> poison, <6 x i32>
+; CHECK-NEXT:    [[TMP10:%.*]] = shufflevector <2 x double> [[TMP6]], <2 x double> poison, <4 x i32>
+; CHECK-NEXT:    [[TMP13:%.*]] = shufflevector <4 x double> [[TMP7]], <4 x double> [[TMP10]], <6 x i32>
+; CHECK-NEXT:    store <6 x double> [[TMP13]], ptr [[TMP5]], align 8
 ; CHECK-NEXT:    [[TMP21:%.*]] = getelementptr i8, ptr [[TMP0]], i64 40
 ; CHECK-NEXT:    [[TMP22:%.*]] = load double, ptr [[TMP21]], align 8
 ; CHECK-NEXT:    [[TMP23:%.*]] = getelementptr i8, ptr [[TMP1]], i64 56