@@ -2242,8 +2242,29 @@ class BoUpSLP {
   /// may not be necessary.
   bool isLoadCombineCandidate(ArrayRef<Value *> Stores) const;
   bool isStridedLoad(ArrayRef<Value *> PointerOps, Type *ScalarTy,
-                     Align Alignment, const int64_t Diff, Value *Ptr0,
-                     Value *PtrN, StridedPtrInfo &SPtrInfo) const;
+                     Align Alignment, int64_t Diff, size_t VecSz) const;
+  /// Given a set of pointers, check if they can be rearranged as follows (%s is
+  /// a constant):
+  /// %b + 0 * %s + 0
+  /// %b + 0 * %s + 1
+  /// %b + 0 * %s + 2
+  /// ...
+  /// %b + 0 * %s + w
+  ///
+  /// %b + 1 * %s + 0
+  /// %b + 1 * %s + 1
+  /// %b + 1 * %s + 2
+  /// ...
+  /// %b + 1 * %s + w
+  /// ...
+  ///
2261+ /// If the pointers can be rearanged in the above pattern, it means that the
2262+ /// memory can be accessed with a strided loads of width `w` and stride `%s`.
+  bool analyzeConstantStrideCandidate(ArrayRef<Value *> PointerOps,
+                                      Type *ElemTy, Align CommonAlignment,
+                                      SmallVectorImpl<unsigned> &SortedIndices,
+                                      int64_t Diff, Value *Ptr0, Value *PtrN,
+                                      StridedPtrInfo &SPtrInfo) const;
 
   /// Checks if the given array of loads can be represented as a vectorized,
   /// scatter or just simple gather.
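To make the pattern in the new `analyzeConstantStrideCandidate` doc comment concrete: with `w = 4` and `%s = 16`, offsets {0, 1, 2, 3, 16, 17, 18, 19} from the base qualify. Below is a minimal, self-contained sketch of the property being tested; the function name and shape are the editor's, not part of the patch.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical stand-in for the check the doc comment above describes: Offs
// holds pointer offsets from the base, sorted ascending and measured in
// elements; W is the group width `w` and Stride the element stride `%s`.
static bool matchesGroupedStridePattern(const std::vector<int64_t> &Offs,
                                        unsigned W, int64_t Stride) {
  if (W == 0 || Offs.size() % W != 0)
    return false;
  for (size_t I = 0; I < Offs.size(); ++I) {
    // Element I sits at position I % W inside group I / W, so its offset
    // from the first pointer must be (I / W) * Stride + (I % W).
    int64_t Expected =
        static_cast<int64_t>(I / W) * Stride + static_cast<int64_t>(I % W);
    if (Offs[I] - Offs[0] != Expected)
      return false;
  }
  return true;
}

int main() {
  // Two groups of four contiguous elements, group bases 16 elements apart.
  std::vector<int64_t> Offs = {0, 1, 2, 3, 16, 17, 18, 19};
  return matchesGroupedStridePattern(Offs, /*W=*/4, /*Stride=*/16) ? 0 : 1;
}
```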
@@ -6824,12 +6845,7 @@ isMaskedLoadCompress(ArrayRef<Value *> VL, ArrayRef<Value *> PointerOps,
 /// current graph (for masked gathers extra extractelement instructions
 /// might be required).
 bool BoUpSLP::isStridedLoad(ArrayRef<Value *> PointerOps, Type *ScalarTy,
-                            Align Alignment, const int64_t Diff, Value *Ptr0,
-                            Value *PtrN, StridedPtrInfo &SPtrInfo) const {
-  const size_t Sz = PointerOps.size();
-  if (Diff % (Sz - 1) != 0)
-    return false;
-
+                            Align Alignment, int64_t Diff, size_t VecSz) const {
   // Try to generate strided load node.
   auto IsAnyPointerUsedOutGraph = any_of(PointerOps, [&](Value *V) {
     return isa<Instruction>(V) && any_of(V->users(), [&](User *U) {
@@ -6838,41 +6854,109 @@ bool BoUpSLP::isStridedLoad(ArrayRef<Value *> PointerOps, Type *ScalarTy,
   });
 
   const uint64_t AbsoluteDiff = std::abs(Diff);
-  auto *VecTy = getWidenedType(ScalarTy, Sz);
+  auto *VecTy = getWidenedType(ScalarTy, VecSz);
   if (IsAnyPointerUsedOutGraph ||
-      (AbsoluteDiff > Sz &&
-       (Sz > MinProfitableStridedLoads ||
-        (AbsoluteDiff <= MaxProfitableLoadStride * Sz &&
-         AbsoluteDiff % Sz == 0 && has_single_bit(AbsoluteDiff / Sz)))) ||
-      Diff == -(static_cast<int64_t>(Sz) - 1)) {
-    int64_t Stride = Diff / static_cast<int64_t>(Sz - 1);
-    if (Diff != Stride * static_cast<int64_t>(Sz - 1))
+      (AbsoluteDiff > VecSz &&
+       (VecSz > MinProfitableStridedLoads ||
+        (AbsoluteDiff <= MaxProfitableLoadStride * VecSz &&
+         AbsoluteDiff % VecSz == 0 && has_single_bit(AbsoluteDiff / VecSz)))) ||
+      Diff == -(static_cast<int64_t>(VecSz) - 1)) {
+    int64_t Stride = Diff / static_cast<int64_t>(VecSz - 1);
+    if (Diff != Stride * static_cast<int64_t>(VecSz - 1))
       return false;
     if (!TTI->isLegalStridedLoadStore(VecTy, Alignment))
       return false;
+  }
+  return true;
+}
+
+bool BoUpSLP::analyzeConstantStrideCandidate(
+    ArrayRef<Value *> PointerOps, Type *ElemTy, Align CommonAlignment,
+    SmallVectorImpl<unsigned> &SortedIndices, int64_t Diff, Value *Ptr0,
+    Value *PtrN, StridedPtrInfo &SPtrInfo) const {
+  const unsigned Sz = PointerOps.size();
+  SmallVector<int64_t> SortedOffsetsFromBase;
+  SortedOffsetsFromBase.resize(Sz);
+  for (unsigned I : seq<unsigned>(Sz)) {
+    Value *Ptr =
+        SortedIndices.empty() ? PointerOps[I] : PointerOps[SortedIndices[I]];
+    SortedOffsetsFromBase[I] =
+        *getPointersDiff(ElemTy, Ptr0, ElemTy, Ptr, *DL, *SE);
+  }
+  assert(SortedOffsetsFromBase.size() > 1 &&
+         "Trying to generate strided load for less than 2 loads");
+  // Find where the first group ends.
+  int64_t StrideWithinGroup =
+      SortedOffsetsFromBase[1] - SortedOffsetsFromBase[0];
+  unsigned GroupSize = 1;
+  for (; GroupSize != SortedOffsetsFromBase.size(); ++GroupSize) {
+    if (SortedOffsetsFromBase[GroupSize] -
+            SortedOffsetsFromBase[GroupSize - 1] !=
+        StrideWithinGroup)
+      break;
+  }
+  unsigned VecSz = Sz;
+  Type *ScalarTy = ElemTy;
+  int64_t StrideIntVal = StrideWithinGroup;
+  FixedVectorType *StridedLoadTy = getWidenedType(ScalarTy, VecSz);
+
+  if (Sz != GroupSize) {
+    if (Sz % GroupSize != 0)
+      return false;
+    VecSz = Sz / GroupSize;
+
+    if (StrideWithinGroup != 1)
+      return false;
+    ScalarTy = Type::getIntNTy(SE->getContext(),
+                               DL->getTypeSizeInBits(ElemTy).getFixedValue() *
+                                   GroupSize);
+    StridedLoadTy = getWidenedType(ScalarTy, VecSz);
+    if (!TTI->isTypeLegal(StridedLoadTy) ||
+        !TTI->isLegalStridedLoadStore(StridedLoadTy, CommonAlignment))
+      return false;
 
-    // Iterate through all pointers and check if all distances are
-    // unique multiple of Dist.
-    SmallSet<int64_t, 4> Dists;
-    for (Value *Ptr : PointerOps) {
-      int64_t Dist = 0;
-      if (Ptr == PtrN)
-        Dist = Diff;
-      else if (Ptr != Ptr0)
-        Dist = *getPointersDiff(ScalarTy, Ptr0, ScalarTy, Ptr, *DL, *SE);
-      // If the strides are not the same or repeated, we can't
-      // vectorize.
-      if (((Dist / Stride) * Stride) != Dist || !Dists.insert(Dist).second)
+    unsigned PrevGroupStartIdx = 0;
+    unsigned CurrentGroupStartIdx = GroupSize;
+    int64_t StrideBetweenGroups =
+        SortedOffsetsFromBase[GroupSize] - SortedOffsetsFromBase[0];
+    StrideIntVal = StrideBetweenGroups;
+    while (CurrentGroupStartIdx != Sz) {
+      if (SortedOffsetsFromBase[CurrentGroupStartIdx] -
+              SortedOffsetsFromBase[PrevGroupStartIdx] !=
+          StrideBetweenGroups)
         break;
+      PrevGroupStartIdx = CurrentGroupStartIdx;
+      CurrentGroupStartIdx += GroupSize;
     }
-    if (Dists.size() == Sz) {
-      Type *StrideTy = DL->getIndexType(Ptr0->getType());
-      SPtrInfo.StrideVal = ConstantInt::get(StrideTy, Stride);
-      SPtrInfo.Ty = getWidenedType(ScalarTy, Sz);
-      return true;
+    if (CurrentGroupStartIdx != Sz)
+      return false;
+
+    auto CheckGroup = [&](unsigned StartIdx, unsigned GroupSize0,
+                          int64_t StrideWithinGroup) -> bool {
+      unsigned GroupEndIdx = StartIdx + 1;
+      for (; GroupEndIdx != Sz; ++GroupEndIdx) {
+        if (SortedOffsetsFromBase[GroupEndIdx] -
+                SortedOffsetsFromBase[GroupEndIdx - 1] !=
+            StrideWithinGroup)
+          break;
+      }
+      return GroupEndIdx - StartIdx == GroupSize0;
+    };
+    for (unsigned I = 0; I < Sz; I += GroupSize) {
+      if (!CheckGroup(I, GroupSize, StrideWithinGroup))
+        return false;
     }
   }
-  return false;
+
+  if (!isStridedLoad(PointerOps, ScalarTy, CommonAlignment, Diff, VecSz))
+    return false;
+
+  Type *StrideTy = DL->getIndexType(Ptr0->getType());
+  SPtrInfo.StrideVal = ConstantInt::get(StrideTy, StrideIntVal);
+  SPtrInfo.Ty = StridedLoadTy;
+  return true;
 }
 
 BoUpSLP::LoadsState BoUpSLP::canVectorizeLoads(
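The interesting case in the hunk above is `Sz != GroupSize`: each run of `GroupSize` contiguous elements is fused into one wide integer lane, and the strided load is formed over those lanes. A sketch with concrete numbers (editor's illustration; the offset values are assumptions, while the type APIs are the ones the patch itself uses):

```cpp
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/LLVMContext.h"
using namespace llvm;

int main() {
  LLVMContext Ctx;
  // Eight i32 loads with sorted offsets {0,1,2,3, 100,101,102,103}: the
  // first contiguous run gives GroupSize = 4, so VecSz = 8 / 4 = 2 and the
  // stride between group bases is 100 elements.
  unsigned GroupSize = 4;
  unsigned VecSz = 2;
  Type *ElemTy = Type::getInt32Ty(Ctx);
  // Fuse one whole group into a single wide integer lane, mirroring the
  // Type::getIntNTy call in analyzeConstantStrideCandidate.
  unsigned ElemBits = ElemTy->getPrimitiveSizeInBits().getFixedValue();
  Type *ScalarTy = Type::getIntNTy(Ctx, ElemBits * GroupSize); // i128
  auto *StridedLoadTy = FixedVectorType::get(ScalarTy, VecSz); // <2 x i128>
  // In the patch, this type must still pass TTI->isTypeLegal and
  // TTI->isLegalStridedLoadStore before the candidate is accepted.
  return StridedLoadTy->getNumElements() == VecSz ? 0 : 1;
}
```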
@@ -6958,8 +7042,8 @@ BoUpSLP::LoadsState BoUpSLP::canVectorizeLoads(
     Align Alignment =
         cast<LoadInst>(Order.empty() ? VL.front() : VL[Order.front()])
             ->getAlign();
-    if (isStridedLoad(PointerOps, ScalarTy, Alignment, *Diff, Ptr0, PtrN,
-                      SPtrInfo))
+    if (analyzeConstantStrideCandidate(PointerOps, ScalarTy, Alignment, Order,
+                                       *Diff, Ptr0, PtrN, SPtrInfo))
       return LoadsState::StridedVectorize;
   }
   if (!TTI->isLegalMaskedGather(VecTy, CommonAlignment) ||
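The call-site change forwards `Order` (the sorting permutation computed earlier in `canVectorizeLoads`) so the analysis can read the pointers in sorted order without physically reordering `PointerOps`. A self-contained illustration of that indexing scheme, with made-up offsets:

```cpp
#include <cstdint>
#include <vector>

int main() {
  // Offsets of PointerOps from Ptr0 in program order (hypothetical values).
  std::vector<int64_t> OffsetFromBase = {100, 0, 101, 1};
  // Sorting permutation: Order[I] is the index of the I-th smallest pointer.
  std::vector<unsigned> Order = {1, 3, 0, 2};
  std::vector<int64_t> Sorted;
  for (unsigned Idx : Order)
    Sorted.push_back(OffsetFromBase[Idx]); // yields {0, 1, 100, 101}
  return Sorted.front() == 0 ? 0 : 1;
}
```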
@@ -14865,11 +14949,19 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
     }
     break;
   case TreeEntry::StridedVectorize: {
+    const StridedPtrInfo &SPtrInfo = TreeEntryToStridedPtrInfoMap.at(E);
+    FixedVectorType *StridedLoadTy = SPtrInfo.Ty;
+    assert(StridedLoadTy && "Missing StridedPtrInfo for tree entry.");
     Align CommonAlignment =
         computeCommonAlignment<LoadInst>(UniqueValues.getArrayRef());
     VecLdCost = TTI->getStridedMemoryOpCost(
-        Instruction::Load, VecTy, LI0->getPointerOperand(),
+        Instruction::Load, StridedLoadTy, LI0->getPointerOperand(),
         /*VariableMask=*/false, CommonAlignment, CostKind);
+    if (StridedLoadTy != VecTy)
+      VecLdCost +=
+          TTI->getCastInstrCost(Instruction::BitCast, VecTy, StridedLoadTy,
+                                getCastContextHint(*E), CostKind);
+
     break;
   }
   case TreeEntry::CompressVectorize: {
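Because the strided load may now be emitted with the widened lane type, the entry's cost is the strided-load cost plus one vector bitcast back to the entry's type. A toy version of that accounting (the numbers are invented; real values come from the target's `getStridedMemoryOpCost` and `getCastInstrCost`):

```cpp
#include "llvm/Support/InstructionCost.h"
using namespace llvm;

int main() {
  // Continuing the <2 x i128> vs. <8 x i32> example above.
  InstructionCost StridedLoadCost = 4; // hypothetical TTI answer
  InstructionCost BitCastCost = 1;     // hypothetical TTI answer
  InstructionCost VecLdCost = StridedLoadCost;
  bool StridedLoadTyDiffersFromVecTy = true; // StridedLoadTy != VecTy
  if (StridedLoadTyDiffersFromVecTy)
    VecLdCost += BitCastCost;
  return VecLdCost == 5 ? 0 : 1;
}
```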
@@ -19633,6 +19725,8 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
                ? NewLI
                : ::propagateMetadata(NewLI, E->Scalars);
 
+      if (StridedLoadTy != VecTy)
+        V = Builder.CreateBitOrPointerCast(V, VecTy);
       V = FinalShuffle(V, E);
       E->VectorizedValue = V;
       ++NumVectorInstructions;
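The codegen side matches the cost model: when the load was widened, its result is reinterpreted back to the tree's vector type before the final shuffle. A hedged sketch of just that step (`castBackIfWidened` is an editor's name, not a function in the patch; `NewLI` stands for the freshly built strided load):

```cpp
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/IRBuilder.h"
using namespace llvm;

// Sketch only: reinterpret e.g. the <2 x i128> strided-load result as the
// <8 x i32> the rest of the tree expects. CreateBitOrPointerCast folds to a
// plain bitcast here because both types have the same total bit width, and
// is a no-op when the two types already match.
static Value *castBackIfWidened(IRBuilder<> &Builder, Value *NewLI,
                                FixedVectorType *StridedLoadTy,
                                FixedVectorType *VecTy) {
  Value *V = NewLI;
  if (StridedLoadTy != VecTy)
    V = Builder.CreateBitOrPointerCast(V, VecTy);
  return V;
}
```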