@@ -2242,8 +2242,29 @@ class BoUpSLP {
   /// may not be necessary.
   bool isLoadCombineCandidate(ArrayRef<Value *> Stores) const;
   bool isStridedLoad(ArrayRef<Value *> PointerOps, Type *ScalarTy,
-                     Align Alignment, const int64_t Diff, Value *Ptr0,
-                     Value *PtrN, StridedPtrInfo &SPtrInfo) const;
+                     Align Alignment, int64_t Diff, size_t VecSz) const;
+  /// Given a set of pointers, check if they can be rearranged as follows (%s
+  /// is a constant):
+  /// %b + 0 * %s + 0
+  /// %b + 0 * %s + 1
+  /// %b + 0 * %s + 2
+  /// ...
+  /// %b + 0 * %s + (w - 1)
+  ///
+  /// %b + 1 * %s + 0
+  /// %b + 1 * %s + 1
+  /// %b + 1 * %s + 2
+  /// ...
+  /// %b + 1 * %s + (w - 1)
+  /// ...
+  ///
+  /// If the pointers can be rearranged in the above pattern, it means that
+  /// the memory can be accessed with strided loads of width `w` and stride
+  /// `%s`.
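+  /// For example, with %s = 8 and w = 4, the offsets 0, 1, 2, 3, 8, 9, 10, 11
+  /// match the pattern: two groups of four consecutive elements whose group
+  /// starts are 8 elements apart.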
+  bool analyzeConstantStrideCandidate(ArrayRef<Value *> PointerOps,
+                                      Type *ElemTy, Align CommonAlignment,
+                                      SmallVectorImpl<unsigned> &SortedIndices,
+                                      int64_t Diff, Value *Ptr0, Value *PtrN,
+                                      StridedPtrInfo &SPtrInfo) const;
 
   /// Return true if an array of scalar loads can be replaced with a strided
   /// load (with run-time stride).
@@ -6849,12 +6870,7 @@ isMaskedLoadCompress(ArrayRef<Value *> VL, ArrayRef<Value *> PointerOps,
 /// current graph (for masked gathers extra extractelement instructions
 /// might be required).
 bool BoUpSLP::isStridedLoad(ArrayRef<Value *> PointerOps, Type *ScalarTy,
-                            Align Alignment, const int64_t Diff, Value *Ptr0,
-                            Value *PtrN, StridedPtrInfo &SPtrInfo) const {
-  const size_t Sz = PointerOps.size();
-  if (Diff % (Sz - 1) != 0)
-    return false;
-
+                            Align Alignment, int64_t Diff, size_t VecSz) const {
   // Try to generate strided load node.
   auto IsAnyPointerUsedOutGraph = any_of(PointerOps, [&](Value *V) {
     return isa<Instruction>(V) && any_of(V->users(), [&](User *U) {
@@ -6863,41 +6879,109 @@ bool BoUpSLP::isStridedLoad(ArrayRef<Value *> PointerOps, Type *ScalarTy,
   });
 
   const uint64_t AbsoluteDiff = std::abs(Diff);
-  auto *VecTy = getWidenedType(ScalarTy, Sz);
+  auto *VecTy = getWidenedType(ScalarTy, VecSz);
   if (IsAnyPointerUsedOutGraph ||
-      (AbsoluteDiff > Sz &&
-       (Sz > MinProfitableStridedLoads ||
-        (AbsoluteDiff <= MaxProfitableLoadStride * Sz &&
-         AbsoluteDiff % Sz == 0 && has_single_bit(AbsoluteDiff / Sz)))) ||
-      Diff == -(static_cast<int64_t>(Sz) - 1)) {
-    int64_t Stride = Diff / static_cast<int64_t>(Sz - 1);
-    if (Diff != Stride * static_cast<int64_t>(Sz - 1))
+      (AbsoluteDiff > VecSz &&
+       (VecSz > MinProfitableStridedLoads ||
+        (AbsoluteDiff <= MaxProfitableLoadStride * VecSz &&
+         AbsoluteDiff % VecSz == 0 && has_single_bit(AbsoluteDiff / VecSz)))) ||
+      Diff == -(static_cast<int64_t>(VecSz) - 1)) {
+    int64_t Stride = Diff / static_cast<int64_t>(VecSz - 1);
+    if (Diff != Stride * static_cast<int64_t>(VecSz - 1))
       return false;
     if (!TTI->isLegalStridedLoadStore(VecTy, Alignment))
       return false;
+  }
+  return true;
+}
+
+bool BoUpSLP::analyzeConstantStrideCandidate(
+    ArrayRef<Value *> PointerOps, Type *ElemTy, Align CommonAlignment,
+    SmallVectorImpl<unsigned> &SortedIndices, int64_t Diff, Value *Ptr0,
+    Value *PtrN, StridedPtrInfo &SPtrInfo) const {
+  const unsigned Sz = PointerOps.size();
+  SmallVector<int64_t> SortedOffsetsFromBase;
+  SortedOffsetsFromBase.resize(Sz);
+  for (unsigned I : seq<unsigned>(Sz)) {
+    Value *Ptr =
+        SortedIndices.empty() ? PointerOps[I] : PointerOps[SortedIndices[I]];
+    SortedOffsetsFromBase[I] =
+        *getPointersDiff(ElemTy, Ptr0, ElemTy, Ptr, *DL, *SE);
+  }
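+  // SortedOffsetsFromBase[I] is now the distance (in ElemTy elements) from
+  // Ptr0 to the I-th pointer in sorted order.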
+  assert(SortedOffsetsFromBase.size() > 1 &&
+         "Trying to generate strided load for less than 2 loads");
+
+  // Find where the first group ends.
+  int64_t StrideWithinGroup =
+      SortedOffsetsFromBase[1] - SortedOffsetsFromBase[0];
+  unsigned GroupSize = 1;
+  for (; GroupSize != SortedOffsetsFromBase.size(); ++GroupSize) {
+    if (SortedOffsetsFromBase[GroupSize] -
+            SortedOffsetsFromBase[GroupSize - 1] !=
+        StrideWithinGroup)
+      break;
+  }
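+  // GroupSize now counts how many leading offsets advance by
+  // StrideWithinGroup.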
+  unsigned VecSz = Sz;
+  Type *ScalarTy = ElemTy;
+  int64_t StrideIntVal = StrideWithinGroup;
+  FixedVectorType *StridedLoadTy = getWidenedType(ScalarTy, VecSz);
+
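+  // If the pointers split into several equally sized groups of consecutive
+  // elements, try to fold each group into a single wider element.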
+  if (Sz != GroupSize) {
+    if (Sz % GroupSize != 0)
+      return false;
+    VecSz = Sz / GroupSize;
+
+    if (StrideWithinGroup != 1)
+      return false;
+    ScalarTy = Type::getIntNTy(SE->getContext(),
+                               DL->getTypeSizeInBits(ElemTy).getFixedValue() *
+                                   GroupSize);
+    StridedLoadTy = getWidenedType(ScalarTy, VecSz);
+    if (!TTI->isTypeLegal(StridedLoadTy) ||
+        !TTI->isLegalStridedLoadStore(StridedLoadTy, CommonAlignment))
+      return false;
 
-    // Iterate through all pointers and check if all distances are
-    // unique multiple of Dist.
-    SmallSet<int64_t, 4> Dists;
-    for (Value *Ptr : PointerOps) {
-      int64_t Dist = 0;
-      if (Ptr == PtrN)
-        Dist = Diff;
-      else if (Ptr != Ptr0)
-        Dist = *getPointersDiff(ScalarTy, Ptr0, ScalarTy, Ptr, *DL, *SE);
-      // If the strides are not the same or repeated, we can't
-      // vectorize.
-      if (((Dist / Stride) * Stride) != Dist || !Dists.insert(Dist).second)
+    unsigned PrevGroupStartIdx = 0;
+    unsigned CurrentGroupStartIdx = GroupSize;
+    int64_t StrideBetweenGroups =
+        SortedOffsetsFromBase[GroupSize] - SortedOffsetsFromBase[0];
+    StrideIntVal = StrideBetweenGroups;
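+    // Require a uniform stride between the start offsets of consecutive
+    // groups.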
+    while (CurrentGroupStartIdx != Sz) {
+      if (SortedOffsetsFromBase[CurrentGroupStartIdx] -
+              SortedOffsetsFromBase[PrevGroupStartIdx] !=
+          StrideBetweenGroups)
         break;
+      PrevGroupStartIdx = CurrentGroupStartIdx;
+      CurrentGroupStartIdx += GroupSize;
     }
-    if (Dists.size() == Sz) {
-      Type *StrideTy = DL->getIndexType(Ptr0->getType());
-      SPtrInfo.StrideVal = ConstantInt::get(StrideTy, Stride);
-      SPtrInfo.Ty = getWidenedType(ScalarTy, Sz);
-      return true;
+    if (CurrentGroupStartIdx != Sz)
+      return false;
+
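+    // Each group must replicate the internal layout of the first group.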
+    auto CheckGroup = [&](unsigned StartIdx, unsigned GroupSize0,
+                          int64_t StrideWithinGroup) -> bool {
+      unsigned GroupEndIdx = StartIdx + 1;
+      for (; GroupEndIdx != Sz; ++GroupEndIdx) {
+        if (SortedOffsetsFromBase[GroupEndIdx] -
+                SortedOffsetsFromBase[GroupEndIdx - 1] !=
+            StrideWithinGroup)
+          break;
+      }
+      return GroupEndIdx - StartIdx == GroupSize0;
+    };
+    for (unsigned I = 0; I < Sz; I += GroupSize) {
+      if (!CheckGroup(I, GroupSize, StrideWithinGroup))
+        return false;
     }
   }
-  return false;
+
+  if (!isStridedLoad(PointerOps, ScalarTy, CommonAlignment, Diff, VecSz))
+    return false;
+
+  Type *StrideTy = DL->getIndexType(Ptr0->getType());
+  SPtrInfo.StrideVal = ConstantInt::get(StrideTy, StrideIntVal);
+  SPtrInfo.Ty = StridedLoadTy;
+  return true;
 }
 
 bool BoUpSLP::analyzeRtStrideCandidate(ArrayRef<Value *> PointerOps,
@@ -6995,8 +7079,8 @@ BoUpSLP::LoadsState BoUpSLP::canVectorizeLoads(
       Align Alignment =
           cast<LoadInst>(Order.empty() ? VL.front() : VL[Order.front()])
               ->getAlign();
-      if (isStridedLoad(PointerOps, ScalarTy, Alignment, *Diff, Ptr0, PtrN,
-                        SPtrInfo))
+      if (analyzeConstantStrideCandidate(PointerOps, ScalarTy, Alignment,
+                                         Order, *Diff, Ptr0, PtrN, SPtrInfo))
         return LoadsState::StridedVectorize;
     }
     if (!TTI->isLegalMaskedGather(VecTy, CommonAlignment) ||
@@ -14916,11 +15000,19 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
       }
       break;
     case TreeEntry::StridedVectorize: {
+      const StridedPtrInfo &SPtrInfo = TreeEntryToStridedPtrInfoMap.at(E);
+      FixedVectorType *StridedLoadTy = SPtrInfo.Ty;
+      assert(StridedLoadTy && "Missing StridedPtrInfo for tree entry.");
       Align CommonAlignment =
           computeCommonAlignment<LoadInst>(UniqueValues.getArrayRef());
       VecLdCost = TTI->getStridedMemoryOpCost(
-          Instruction::Load, VecTy, LI0->getPointerOperand(),
+          Instruction::Load, StridedLoadTy, LI0->getPointerOperand(),
           /*VariableMask=*/false, CommonAlignment, CostKind);
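+      // If the load is costed with a widened element type, also account for
+      // the cast back to the vector type the tree expects.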
+      if (StridedLoadTy != VecTy)
+        VecLdCost +=
+            TTI->getCastInstrCost(Instruction::BitCast, VecTy, StridedLoadTy,
+                                  getCastContextHint(*E), CostKind);
+
       break;
     }
     case TreeEntry::CompressVectorize: {
@@ -19685,6 +19777,8 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
                      ? NewLI
                      : ::propagateMetadata(NewLI, E->Scalars);
 
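+      // The strided load may have been emitted with a widened element type;
+      // cast the result back to the vector type this tree entry expects.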
+      if (StridedLoadTy != VecTy)
+        V = Builder.CreateBitOrPointerCast(V, VecTy);
       V = FinalShuffle(V, E);
       E->VectorizedValue = V;
       ++NumVectorInstructions;