@@ -2242,8 +2242,29 @@ class BoUpSLP {
   /// may not be necessary.
   bool isLoadCombineCandidate(ArrayRef<Value *> Stores) const;
   bool isStridedLoad(ArrayRef<Value *> PointerOps, Type *ScalarTy,
-                     Align Alignment, const int64_t Diff, Value *Ptr0,
-                     Value *PtrN, StridedPtrInfo &SPtrInfo) const;
+                     Align Alignment, int64_t Diff, size_t VecSz) const;
+  /// Given a set of pointers, check if they can be rearranged as follows
+  /// (%s is a constant):
+  /// %b + 0 * %s + 0
+  /// %b + 0 * %s + 1
+  /// %b + 0 * %s + 2
+  /// ...
+  /// %b + 0 * %s + (w - 1)
+  ///
+  /// %b + 1 * %s + 0
+  /// %b + 1 * %s + 1
+  /// %b + 1 * %s + 2
+  /// ...
+  /// %b + 1 * %s + (w - 1)
+  /// ...
+  ///
+  /// If the pointers can be rearranged in the above pattern, the memory can
+  /// be accessed with strided loads of width `w` and stride `%s`.
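+  /// For example, the offsets 0, 1, 10, 11, 20, 21 from a common base %b
+  /// match this pattern with w = 2 and %s = 10.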
+  bool analyzeConstantStrideCandidate(ArrayRef<Value *> PointerOps,
+                                      Type *ElemTy, Align CommonAlignment,
+                                      SmallVectorImpl<unsigned> &SortedIndices,
+                                      int64_t Diff, Value *Ptr0, Value *PtrN,
+                                      StridedPtrInfo &SPtrInfo) const;
 
   /// Return true if an array of scalar loads can be replaced with a strided
   /// load (with run-time stride).
@@ -6844,12 +6865,7 @@ isMaskedLoadCompress(ArrayRef<Value *> VL, ArrayRef<Value *> PointerOps,
 /// current graph (for masked gathers extra extractelement instructions
 /// might be required).
 bool BoUpSLP::isStridedLoad(ArrayRef<Value *> PointerOps, Type *ScalarTy,
-                            Align Alignment, const int64_t Diff, Value *Ptr0,
-                            Value *PtrN, StridedPtrInfo &SPtrInfo) const {
-  const size_t Sz = PointerOps.size();
-  if (Diff % (Sz - 1) != 0)
-    return false;
-
+                            Align Alignment, int64_t Diff,
+                            size_t VecSz) const {
   // Try to generate strided load node.
   auto IsAnyPointerUsedOutGraph = any_of(PointerOps, [&](Value *V) {
     return isa<Instruction>(V) && any_of(V->users(), [&](User *U) {
@@ -6858,41 +6874,109 @@ bool BoUpSLP::isStridedLoad(ArrayRef<Value *> PointerOps, Type *ScalarTy,
   });
 
   const uint64_t AbsoluteDiff = std::abs(Diff);
-  auto *VecTy = getWidenedType(ScalarTy, Sz);
+  auto *VecTy = getWidenedType(ScalarTy, VecSz);
   if (IsAnyPointerUsedOutGraph ||
-      (AbsoluteDiff > Sz &&
-       (Sz > MinProfitableStridedLoads ||
-        (AbsoluteDiff <= MaxProfitableLoadStride * Sz &&
-         AbsoluteDiff % Sz == 0 && has_single_bit(AbsoluteDiff / Sz)))) ||
-      Diff == -(static_cast<int64_t>(Sz) - 1)) {
-    int64_t Stride = Diff / static_cast<int64_t>(Sz - 1);
-    if (Diff != Stride * static_cast<int64_t>(Sz - 1))
+      (AbsoluteDiff > VecSz &&
+       (VecSz > MinProfitableStridedLoads ||
+        (AbsoluteDiff <= MaxProfitableLoadStride * VecSz &&
+         AbsoluteDiff % VecSz == 0 && has_single_bit(AbsoluteDiff / VecSz)))) ||
+      Diff == -(static_cast<int64_t>(VecSz) - 1)) {
+    int64_t Stride = Diff / static_cast<int64_t>(VecSz - 1);
+    if (Diff != Stride * static_cast<int64_t>(VecSz - 1))
       return false;
     if (!TTI->isLegalStridedLoadStore(VecTy, Alignment))
       return false;
+  }
+  return true;
+}
+
+bool BoUpSLP::analyzeConstantStrideCandidate(
+    ArrayRef<Value *> PointerOps, Type *ElemTy, Align CommonAlignment,
+    SmallVectorImpl<unsigned> &SortedIndices, int64_t Diff, Value *Ptr0,
+    Value *PtrN, StridedPtrInfo &SPtrInfo) const {
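+  // Compute each pointer's offset from Ptr0, in units of ElemTy, visiting
+  // the pointers in sorted order.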
+  const unsigned Sz = PointerOps.size();
+  SmallVector<int64_t> SortedOffsetsFromBase(Sz);
+  for (unsigned I : seq<unsigned>(Sz)) {
+    Value *Ptr =
+        SortedIndices.empty() ? PointerOps[I] : PointerOps[SortedIndices[I]];
+    SortedOffsetsFromBase[I] =
+        *getPointersDiff(ElemTy, Ptr0, ElemTy, Ptr, *DL, *SE);
+  }
+  assert(SortedOffsetsFromBase.size() > 1 &&
+         "Trying to generate strided load for fewer than 2 loads");
+
+  // Find where the first group ends.
+  int64_t StrideWithinGroup =
+      SortedOffsetsFromBase[1] - SortedOffsetsFromBase[0];
+  unsigned GroupSize = 1;
+  for (; GroupSize != SortedOffsetsFromBase.size(); ++GroupSize) {
+    if (SortedOffsetsFromBase[GroupSize] -
+            SortedOffsetsFromBase[GroupSize - 1] !=
+        StrideWithinGroup)
+      break;
+  }
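+  // Initially assume a single group: a plain strided load of Sz elements of
+  // ElemTy with the stride observed within the group.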
+  unsigned VecSz = Sz;
+  Type *ScalarTy = ElemTy;
+  int64_t StrideIntVal = StrideWithinGroup;
+  FixedVectorType *StridedLoadTy = getWidenedType(ScalarTy, VecSz);
+
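+  // If the first group does not cover all pointers, the accesses may still
+  // form equally sized contiguous groups. Try to coalesce each group into a
+  // single wider integer element (e.g. four contiguous i32 loads become one
+  // i128 element) and load one element per group with a strided load.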
+  if (Sz != GroupSize) {
+    if (Sz % GroupSize != 0)
+      return false;
+    VecSz = Sz / GroupSize;
+
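+    // Elements can only be coalesced into a single wider element if they
+    // are contiguous within their group.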
+    if (StrideWithinGroup != 1)
+      return false;
+    ScalarTy = Type::getIntNTy(SE->getContext(),
+                               DL->getTypeSizeInBits(ElemTy).getFixedValue() *
+                                   GroupSize);
+    StridedLoadTy = getWidenedType(ScalarTy, VecSz);
+    if (!TTI->isTypeLegal(StridedLoadTy) ||
+        !TTI->isLegalStridedLoadStore(StridedLoadTy, CommonAlignment))
+      return false;
 
-    // Iterate through all pointers and check if all distances are
-    // unique multiple of Dist.
-    SmallSet<int64_t, 4> Dists;
-    for (Value *Ptr : PointerOps) {
-      int64_t Dist = 0;
-      if (Ptr == PtrN)
-        Dist = Diff;
-      else if (Ptr != Ptr0)
-        Dist = *getPointersDiff(ScalarTy, Ptr0, ScalarTy, Ptr, *DL, *SE);
-      // If the strides are not the same or repeated, we can't
-      // vectorize.
-      if (((Dist / Stride) * Stride) != Dist || !Dists.insert(Dist).second)
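+    // Check that consecutive groups start at a constant distance from each
+    // other; that distance becomes the stride of the widened load.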
+    unsigned PrevGroupStartIdx = 0;
+    unsigned CurrentGroupStartIdx = GroupSize;
+    int64_t StrideBetweenGroups =
+        SortedOffsetsFromBase[GroupSize] - SortedOffsetsFromBase[0];
+    StrideIntVal = StrideBetweenGroups;
+    while (CurrentGroupStartIdx != Sz) {
+      if (SortedOffsetsFromBase[CurrentGroupStartIdx] -
+              SortedOffsetsFromBase[PrevGroupStartIdx] !=
+          StrideBetweenGroups)
         break;
+      PrevGroupStartIdx = CurrentGroupStartIdx;
+      CurrentGroupStartIdx += GroupSize;
     }
-    if (Dists.size() == Sz) {
-      Type *StrideTy = DL->getIndexType(Ptr0->getType());
-      SPtrInfo.StrideVal = ConstantInt::get(StrideTy, Stride);
-      SPtrInfo.Ty = getWidenedType(ScalarTy, Sz);
-      return true;
+    if (CurrentGroupStartIdx != Sz)
+      return false;
+
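+    // Finally, verify that every group has exactly GroupSize elements with
+    // the expected stride within the group.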
+    auto CheckGroup = [&](unsigned StartIdx, unsigned GroupSize0,
+                          int64_t StrideWithinGroup0) -> bool {
+      unsigned GroupEndIdx = StartIdx + 1;
+      for (; GroupEndIdx != Sz; ++GroupEndIdx) {
+        if (SortedOffsetsFromBase[GroupEndIdx] -
+                SortedOffsetsFromBase[GroupEndIdx - 1] !=
+            StrideWithinGroup0)
+          break;
+      }
+      return GroupEndIdx - StartIdx == GroupSize0;
+    };
+    for (unsigned I = 0; I < Sz; I += GroupSize) {
+      if (!CheckGroup(I, GroupSize, StrideWithinGroup))
+        return false;
     }
   }
-  return false;
+
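+  // Reuse the common profitability and legality checks for the resulting
+  // strided load.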
+  if (!isStridedLoad(PointerOps, ScalarTy, CommonAlignment, Diff, VecSz))
+    return false;
+
+  Type *StrideTy = DL->getIndexType(Ptr0->getType());
+  SPtrInfo.StrideVal = ConstantInt::get(StrideTy, StrideIntVal);
+  SPtrInfo.Ty = StridedLoadTy;
+  return true;
 }
 
 bool BoUpSLP::analyzeRtStrideCandidate(ArrayRef<Value *> PointerOps,
@@ -6990,8 +7074,8 @@ BoUpSLP::LoadsState BoUpSLP::canVectorizeLoads(
       Align Alignment =
           cast<LoadInst>(Order.empty() ? VL.front() : VL[Order.front()])
               ->getAlign();
-      if (isStridedLoad(PointerOps, ScalarTy, Alignment, *Diff, Ptr0, PtrN,
-                        SPtrInfo))
+      if (analyzeConstantStrideCandidate(PointerOps, ScalarTy, Alignment,
+                                         Order, *Diff, Ptr0, PtrN, SPtrInfo))
         return LoadsState::StridedVectorize;
     }
     if (!TTI->isLegalMaskedGather(VecTy, CommonAlignment) ||
@@ -14902,11 +14986,19 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
     }
     break;
   case TreeEntry::StridedVectorize: {
+    const StridedPtrInfo &SPtrInfo = TreeEntryToStridedPtrInfoMap.at(E);
+    FixedVectorType *StridedLoadTy = SPtrInfo.Ty;
+    assert(StridedLoadTy && "Missing StridedPtrInfo for tree entry.");
     Align CommonAlignment =
         computeCommonAlignment<LoadInst>(UniqueValues.getArrayRef());
     VecLdCost = TTI->getStridedMemoryOpCost(
-        Instruction::Load, VecTy, LI0->getPointerOperand(),
+        Instruction::Load, StridedLoadTy, LI0->getPointerOperand(),
         /*VariableMask=*/false, CommonAlignment, CostKind);
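+    // If the analysis chose a widened element type, the strided load result
+    // must be cast back to the original vector type; account for that cost.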
+    if (StridedLoadTy != VecTy)
+      VecLdCost +=
+          TTI->getCastInstrCost(Instruction::BitCast, VecTy, StridedLoadTy,
+                                getCastContextHint(*E), CostKind);
+
     break;
   }
   case TreeEntry::CompressVectorize: {
@@ -19670,6 +19762,8 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
                  ? NewLI
                  : ::propagateMetadata(NewLI, E->Scalars);
 
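+      // Cast the widened strided-load result back to the vector type the
+      // rest of the tree expects.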
+      if (StridedLoadTy != VecTy)
+        V = Builder.CreateBitOrPointerCast(V, VecTy);
       V = FinalShuffle(V, E);
       E->VectorizedValue = V;
       ++NumVectorInstructions;