diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index f73ad1b15891a..2532edc5d8699 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -1314,6 +1314,22 @@ static bool hasFullVectorsOrPowerOf2(const TargetTransformInfo &TTI, Type *Ty,
          Sz % NumParts == 0;
 }
 
+/// Returns the number of parts the type \p VecTy will be split into at the
+/// codegen phase. If the type is going to be scalarized or does not use whole
+/// registers, returns 1.
+static unsigned
+getNumberOfParts(const TargetTransformInfo &TTI, VectorType *VecTy,
+                 const unsigned Limit = std::numeric_limits<unsigned>::max()) {
+  unsigned NumParts = TTI.getNumberOfParts(VecTy);
+  if (NumParts == 0 || NumParts >= Limit)
+    return 1;
+  unsigned Sz = getNumElements(VecTy);
+  if (NumParts >= Sz || Sz % NumParts != 0 ||
+      !hasFullVectorsOrPowerOf2(TTI, VecTy->getElementType(), Sz / NumParts))
+    return 1;
+  return NumParts;
+}
+
 namespace slpvectorizer {
 
 /// Bottom Up SLP Vectorizer.
@@ -4618,12 +4634,7 @@ BoUpSLP::findReusedOrderedScalars(const BoUpSLP::TreeEntry &TE) {
   if (!isValidElementType(ScalarTy))
     return std::nullopt;
   auto *VecTy = getWidenedType(ScalarTy, NumScalars);
-  int NumParts = TTI->getNumberOfParts(VecTy);
-  if (NumParts == 0 || NumParts >= NumScalars ||
-      VecTy->getNumElements() % NumParts != 0 ||
-      !hasFullVectorsOrPowerOf2(*TTI, VecTy->getElementType(),
-                                VecTy->getNumElements() / NumParts))
-    NumParts = 1;
+  unsigned NumParts = ::getNumberOfParts(*TTI, VecTy, NumScalars);
   SmallVector<int> ExtractMask;
   SmallVector<int> Mask;
   SmallVector<SmallVector<const TreeEntry *>> Entries;
@@ -5574,8 +5585,8 @@ BoUpSLP::getReorderingData(const TreeEntry &TE, bool TopToBottom) {
       }
     }
     if (Sz == 2 && TE.getVectorFactor() == 4 &&
-        TTI->getNumberOfParts(getWidenedType(TE.Scalars.front()->getType(),
-                                             2 * TE.getVectorFactor())) == 1)
+        ::getNumberOfParts(*TTI, getWidenedType(TE.Scalars.front()->getType(),
+                                                2 * TE.getVectorFactor())) == 1)
       return std::nullopt;
     if (!ShuffleVectorInst::isOneUseSingleSourceMask(TE.ReuseShuffleIndices,
                                                      Sz)) {
@@ -9846,13 +9857,13 @@ void BoUpSLP::transformNodes() {
           // Do not try to vectorize small splats (less than vector register and
          // only with the single non-undef element).
           bool IsSplat = isSplat(Slice);
-          if (Slices.empty() || !IsSplat ||
-              (VF <= 2 && 2 * std::clamp(TTI->getNumberOfParts(getWidenedType(
-                                             Slice.front()->getType(), VF)),
-                                         1U, VF - 1) !=
-                              std::clamp(TTI->getNumberOfParts(getWidenedType(
-                                             Slice.front()->getType(), 2 * VF)),
-                                         1U, 2 * VF)) ||
+          bool IsTwoRegisterSplat = true;
+          if (IsSplat && VF == 2) {
+            unsigned NumRegs2VF = ::getNumberOfParts(
+                *TTI, getWidenedType(Slice.front()->getType(), 2 * VF));
+            IsTwoRegisterSplat = NumRegs2VF == 2;
+          }
+          if (Slices.empty() || !IsSplat || !IsTwoRegisterSplat ||
               count(Slice, Slice.front()) ==
                   static_cast<long>(isa<UndefValue>(Slice.front()) ? VF - 1
                                                                    : 1)) {
@@ -10793,12 +10804,7 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
     }
     assert(!CommonMask.empty() && "Expected non-empty common mask.");
     auto *MaskVecTy = getWidenedType(ScalarTy, Mask.size());
-    unsigned NumParts = TTI.getNumberOfParts(MaskVecTy);
-    if (NumParts == 0 || NumParts >= Mask.size() ||
-        MaskVecTy->getNumElements() % NumParts != 0 ||
-        !hasFullVectorsOrPowerOf2(TTI, MaskVecTy->getElementType(),
-                                  MaskVecTy->getNumElements() / NumParts))
-      NumParts = 1;
+    unsigned NumParts = ::getNumberOfParts(TTI, MaskVecTy, Mask.size());
     unsigned SliceSize = getPartNumElems(Mask.size(), NumParts);
     const auto *It =
         find_if(Mask, [](int Idx) { return Idx != PoisonMaskElem; });
@@ -10813,12 +10819,7 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
     }
     assert(!CommonMask.empty() && "Expected non-empty common mask.");
     auto *MaskVecTy = getWidenedType(ScalarTy, Mask.size());
-    unsigned NumParts = TTI.getNumberOfParts(MaskVecTy);
-    if (NumParts == 0 || NumParts >= Mask.size() ||
-        MaskVecTy->getNumElements() % NumParts != 0 ||
-        !hasFullVectorsOrPowerOf2(TTI, MaskVecTy->getElementType(),
-                                  MaskVecTy->getNumElements() / NumParts))
-      NumParts = 1;
+    unsigned NumParts = ::getNumberOfParts(TTI, MaskVecTy, Mask.size());
     unsigned SliceSize = getPartNumElems(Mask.size(), NumParts);
     const auto *It =
         find_if(Mask, [](int Idx) { return Idx != PoisonMaskElem; });
@@ -11351,7 +11352,7 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
     unsigned const NumElts = SrcVecTy->getNumElements();
     unsigned const NumScalars = VL.size();
 
-    unsigned NumOfParts = TTI->getNumberOfParts(SrcVecTy);
+    unsigned NumOfParts = ::getNumberOfParts(*TTI, SrcVecTy);
 
     SmallVector<int> InsertMask(NumElts, PoisonMaskElem);
     unsigned OffsetBeg = *getElementIndex(VL.front());
@@ -14862,12 +14863,7 @@ ResTy BoUpSLP::processBuildVector(const TreeEntry *E, Type *ScalarTy,
   SmallVector<SmallVector<const TreeEntry *>> Entries;
   Type *OrigScalarTy = GatheredScalars.front()->getType();
   auto *VecTy = getWidenedType(ScalarTy, GatheredScalars.size());
-  unsigned NumParts = TTI->getNumberOfParts(VecTy);
-  if (NumParts == 0 || NumParts >= GatheredScalars.size() ||
-      VecTy->getNumElements() % NumParts != 0 ||
-      !hasFullVectorsOrPowerOf2(*TTI, VecTy->getElementType(),
-                                VecTy->getNumElements() / NumParts))
-    NumParts = 1;
+  unsigned NumParts = ::getNumberOfParts(*TTI, VecTy, GatheredScalars.size());
   if (!all_of(GatheredScalars, IsaPred<UndefValue>)) {
     // Check for gathered extracts.
     bool Resized = false;
@@ -14899,12 +14895,8 @@ ResTy BoUpSLP::processBuildVector(const TreeEntry *E, Type *ScalarTy,
           Resized = true;
           GatheredScalars.append(VF - GatheredScalars.size(),
                                  PoisonValue::get(OrigScalarTy));
-          NumParts = TTI->getNumberOfParts(getWidenedType(OrigScalarTy, VF));
-          if (NumParts == 0 || NumParts >= GatheredScalars.size() ||
-              VecTy->getNumElements() % NumParts != 0 ||
-              !hasFullVectorsOrPowerOf2(*TTI, VecTy->getElementType(),
-                                        VecTy->getNumElements() / NumParts))
-            NumParts = 1;
+          NumParts =
+              ::getNumberOfParts(*TTI, getWidenedType(OrigScalarTy, VF), VF);
         }
       }
     }
@@ -17049,10 +17041,10 @@ void BoUpSLP::optimizeGatherSequence() {
     // Check if the last undefs actually change the final number of used vector
    // registers.
     return SM1.size() - LastUndefsCnt > 1 &&
-           TTI->getNumberOfParts(SI1->getType()) ==
-               TTI->getNumberOfParts(
-                   getWidenedType(SI1->getType()->getElementType(),
-                                  SM1.size() - LastUndefsCnt));
+           ::getNumberOfParts(*TTI, SI1->getType()) ==
+               ::getNumberOfParts(
+                   *TTI, getWidenedType(SI1->getType()->getElementType(),
+                                        SM1.size() - LastUndefsCnt));
   };
   // Perform O(N^2) search over the gather/shuffle sequences and merge identical
   // instructions. TODO: We can further optimize this scan if we split the
@@ -17829,9 +17821,12 @@ bool BoUpSLP::collectValuesToDemote(
       const unsigned VF = E.Scalars.size();
       Type *OrigScalarTy = E.Scalars.front()->getType();
       if (UniqueBases.size() <= 2 ||
-          TTI->getNumberOfParts(getWidenedType(OrigScalarTy, VF)) ==
-              TTI->getNumberOfParts(getWidenedType(
-                  IntegerType::get(OrigScalarTy->getContext(), BitWidth), VF)))
+          ::getNumberOfParts(*TTI, getWidenedType(OrigScalarTy, VF)) ==
+              ::getNumberOfParts(
+                  *TTI,
+                  getWidenedType(
+                      IntegerType::get(OrigScalarTy->getContext(), BitWidth),
+                      VF)))
         ToDemote.push_back(E.Idx);
     }
     return Res;
@@ -18241,8 +18236,8 @@ void BoUpSLP::computeMinimumValueSizes() {
                [&](Value *V) { return AnalyzedMinBWVals.contains(V); }))
       return 0u;
 
-    unsigned NumParts = TTI->getNumberOfParts(
-        getWidenedType(TreeRootIT, VF * ScalarTyNumElements));
+    unsigned NumParts = ::getNumberOfParts(
+        *TTI, getWidenedType(TreeRootIT, VF * ScalarTyNumElements));
 
     // The maximum bit width required to represent all the values that can be
     // demoted without loss of precision. It would be safe to truncate the roots
@@ -18302,8 +18297,10 @@ void BoUpSLP::computeMinimumValueSizes() {
     // use - ignore it.
     if (NumParts > 1 &&
         NumParts ==
-            TTI->getNumberOfParts(getWidenedType(
-                IntegerType::get(F->getContext(), bit_ceil(MaxBitWidth)), VF)))
+            ::getNumberOfParts(
+                *TTI, getWidenedType(IntegerType::get(F->getContext(),
+                                                      bit_ceil(MaxBitWidth)),
+                                     VF)))
       return 0u;
 
     unsigned Opcode = E.getOpcode();
@@ -20086,14 +20083,14 @@ class HorizontalReduction {
       ReduxWidth =
           getFloorFullVectorNumberOfElements(TTI, ScalarTy, ReduxWidth);
       VectorType *Tp = getWidenedType(ScalarTy, ReduxWidth);
-      NumParts = TTI.getNumberOfParts(Tp);
+      NumParts = ::getNumberOfParts(TTI, Tp);
       NumRegs =
           TTI.getNumberOfRegisters(TTI.getRegisterClassForType(true, Tp));
       while (NumParts > NumRegs) {
         assert(ReduxWidth > 0 && "ReduxWidth is unexpectedly 0.");
         ReduxWidth = bit_floor(ReduxWidth - 1);
         VectorType *Tp = getWidenedType(ScalarTy, ReduxWidth);
-        NumParts = TTI.getNumberOfParts(Tp);
+        NumParts = ::getNumberOfParts(TTI, Tp);
         NumRegs =
             TTI.getNumberOfRegisters(TTI.getRegisterClassForType(true, Tp));
       }
diff --git a/llvm/test/Transforms/SLPVectorizer/RISCV/partial-vec-invalid-cost.ll b/llvm/test/Transforms/SLPVectorizer/RISCV/partial-vec-invalid-cost.ll
index 6388cc2dedc73..085d7a64fc9ac 100644
--- a/llvm/test/Transforms/SLPVectorizer/RISCV/partial-vec-invalid-cost.ll
+++ b/llvm/test/Transforms/SLPVectorizer/RISCV/partial-vec-invalid-cost.ll
@@ -7,9 +7,17 @@ define void @partial_vec_invalid_cost() #0 {
 ; CHECK-LABEL: define void @partial_vec_invalid_cost(
 ; CHECK-SAME: ) #[[ATTR0:[0-9]+]] {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[TMP0:%.*]] = call i32 @llvm.vector.reduce.or.v4i32(<4 x i32> zeroinitializer)
+; CHECK-NEXT:    [[LSHR_1:%.*]] = lshr i96 0, 0
+; CHECK-NEXT:    [[LSHR_2:%.*]] = lshr i96 0, 0
+; CHECK-NEXT:    [[TRUNC_I96_1:%.*]] = trunc i96 [[LSHR_1]] to i32
+; CHECK-NEXT:    [[TRUNC_I96_2:%.*]] = trunc i96 [[LSHR_2]] to i32
+; CHECK-NEXT:    [[TRUNC_I96_3:%.*]] = trunc i96 0 to i32
+; CHECK-NEXT:    [[TRUNC_I96_4:%.*]] = trunc i96 0 to i32
 ; CHECK-NEXT:    [[TMP1:%.*]] = call i32 @llvm.vector.reduce.or.v4i32(<4 x i32> zeroinitializer)
-; CHECK-NEXT:    [[OP_RDX3:%.*]] = or i32 [[TMP0]], [[TMP1]]
+; CHECK-NEXT:    [[OP_RDX:%.*]] = or i32 [[TMP1]], [[TRUNC_I96_1]]
+; CHECK-NEXT:    [[OP_RDX1:%.*]] = or i32 [[TRUNC_I96_2]], [[TRUNC_I96_3]]
+; CHECK-NEXT:    [[OP_RDX2:%.*]] = or i32 [[OP_RDX]], [[OP_RDX1]]
+; CHECK-NEXT:    [[OP_RDX3:%.*]] = or i32 [[OP_RDX2]], [[TRUNC_I96_4]]
 ; CHECK-NEXT:    [[STORE_THIS:%.*]] = zext i32 [[OP_RDX3]] to i96
 ; CHECK-NEXT:    store i96 [[STORE_THIS]], ptr null, align 16
 ; CHECK-NEXT:    ret void
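
Reviewer note (not part of the patch): a minimal standalone sketch of the clamping rule the new ::getNumberOfParts wrapper applies. It models the raw TTI.getNumberOfParts answer as a plain RawTTIParts argument, and the stand-in hasFullVectorsOrPowerOf2Model only checks for a power-of-two slice size, whereas the real LLVM helper also accepts slices that fill whole registers (queried from TTI). The names getNumberOfPartsModel and hasFullVectorsOrPowerOf2Model are hypothetical and exist only for this illustration.

#include <cassert>
#include <limits>

// Simplified stand-in for hasFullVectorsOrPowerOf2: accept power-of-two
// slice sizes only (the real helper also consults TTI for whole-register
// multiples).
static bool hasFullVectorsOrPowerOf2Model(unsigned SliceSize) {
  return SliceSize > 0 && (SliceSize & (SliceSize - 1)) == 0;
}

// Mirrors the shape of the new SLPVectorizer helper: keep the target's raw
// part count only when it is non-zero, below the caller's limit, and splits
// the element count into equal register-shaped slices; otherwise fall back
// to a single part.
static unsigned
getNumberOfPartsModel(unsigned RawTTIParts, unsigned NumElements,
                      unsigned Limit = std::numeric_limits<unsigned>::max()) {
  if (RawTTIParts == 0 || RawTTIParts >= Limit)
    return 1;
  if (RawTTIParts >= NumElements || NumElements % RawTTIParts != 0 ||
      !hasFullVectorsOrPowerOf2Model(NumElements / RawTTIParts))
    return 1;
  return RawTTIParts;
}

int main() {
  assert(getNumberOfPartsModel(2, 8) == 2);  // 8 elements in 2 parts of 4
  assert(getNumberOfPartsModel(3, 8) == 1);  // uneven split collapses to 1
  assert(getNumberOfPartsModel(4, 4) == 1);  // full scalarization collapses to 1
  assert(getNumberOfPartsModel(2, 8, /*Limit=*/2) == 1); // at/over the limit
  return 0;
}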