@@ -2242,8 +2242,49 @@ class BoUpSLP {
   /// may not be necessary.
   bool isLoadCombineCandidate(ArrayRef<Value *> Stores) const;
   bool isStridedLoad(ArrayRef<Value *> PointerOps, Type *ScalarTy,
-                     Align Alignment, const int64_t Diff, Value *Ptr0,
-                     Value *PtrN, StridedPtrInfo &SPtrInfo) const;
+                     Align Alignment, const int64_t Diff,
+                     const size_t Sz) const;
+
+  /// Return true if an array of scalar loads can be replaced with a strided
+  /// load (with constant stride).
+  ///
+  /// TODO:
+  /// It is possible that the load gets "widened". Suppose that originally each
+  /// load loads `k` bytes and `PointerOps` can be arranged as follows (`%s` is
+  /// constant):
+  /// %b + 0 * %s + 0
+  /// %b + 0 * %s + 1
+  /// %b + 0 * %s + 2
+  /// ...
+  /// %b + 0 * %s + (w - 1)
+  ///
+  /// %b + 1 * %s + 0
+  /// %b + 1 * %s + 1
+  /// %b + 1 * %s + 2
+  /// ...
+  /// %b + 1 * %s + (w - 1)
+  /// ...
+  ///
+  /// %b + (n - 1) * %s + 0
+  /// %b + (n - 1) * %s + 1
+  /// %b + (n - 1) * %s + 2
+  /// ...
+  /// %b + (n - 1) * %s + (w - 1)
+  ///
+  /// In this case we will generate a strided load of type `<n x (k * w)>`.
+  ///
+  /// \param PointerOps list of pointer arguments of loads.
+  /// \param ElemTy original scalar type of loads.
+  /// \param Alignment alignment of the first load.
+  /// \param SortedIndices is the order of PointerOps as returned by
+  /// `sortPtrAccesses`.
+  /// \param Diff Pointer difference between the lowest and the highest pointer
+  /// in `PointerOps` as returned by `getPointersDiff`.
+  /// \param Ptr0 first pointer in `PointerOps`.
+  /// \param PtrN last pointer in `PointerOps`.
+  /// \param SPtrInfo If the function returns `true`, it also sets all the
+  /// fields of `SPtrInfo` necessary to generate the strided load later.
+  bool analyzeConstantStrideCandidate(
+      const ArrayRef<Value *> PointerOps, Type *ElemTy, Align Alignment,
+      const SmallVectorImpl<unsigned> &SortedIndices, const int64_t Diff,
+      Value *Ptr0, Value *PtrN, StridedPtrInfo &SPtrInfo) const;
 
   /// Return true if an array of scalar loads can be replaced with a strided
   /// load (with run-time stride).
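For orientation, here is a minimal C++ sketch (an editorial illustration, not part of the patch; the function name and the stride of 3 elements are assumptions) of the scalar access pattern the new analyzeConstantStrideCandidate declaration targets: a group of loads whose pointers are spaced by one compile-time constant stride.

// Illustrative input only. For these four loads, PointerOps would be
// {&a[0], &a[3], &a[6], &a[9]}, so Diff = 9 (in elements), Sz = 4, and
// Stride = Diff / (Sz - 1) = 3; every distance {0, 3, 6, 9} is a unique
// multiple of the stride, which is exactly what the helper defined in the
// later hunk checks before forming a strided load.
void copyStrided(const int *a, int *out) {
  out[0] = a[0];
  out[1] = a[3];
  out[2] = a[6];
  out[3] = a[9];
}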
@@ -6849,9 +6890,8 @@ isMaskedLoadCompress(ArrayRef<Value *> VL, ArrayRef<Value *> PointerOps,
 /// current graph (for masked gathers extra extractelement instructions
 /// might be required).
 bool BoUpSLP::isStridedLoad(ArrayRef<Value *> PointerOps, Type *ScalarTy,
-                            Align Alignment, const int64_t Diff, Value *Ptr0,
-                            Value *PtrN, StridedPtrInfo &SPtrInfo) const {
-  const size_t Sz = PointerOps.size();
+                            Align Alignment, const int64_t Diff,
+                            const size_t Sz) const {
   if (Diff % (Sz - 1) != 0)
     return false;
 
@@ -6875,27 +6915,40 @@ bool BoUpSLP::isStridedLoad(ArrayRef<Value *> PointerOps, Type *ScalarTy,
       return false;
     if (!TTI->isLegalStridedLoadStore(VecTy, Alignment))
       return false;
+    return true;
+  }
+  return false;
+}
 
-    // Iterate through all pointers and check if all distances are
-    // unique multiple of Dist.
-    SmallSet<int64_t, 4> Dists;
-    for (Value *Ptr : PointerOps) {
-      int64_t Dist = 0;
-      if (Ptr == PtrN)
-        Dist = Diff;
-      else if (Ptr != Ptr0)
-        Dist = *getPointersDiff(ScalarTy, Ptr0, ScalarTy, Ptr, *DL, *SE);
-      // If the strides are not the same or repeated, we can't
-      // vectorize.
-      if (((Dist / Stride) * Stride) != Dist || !Dists.insert(Dist).second)
-        break;
-    }
-    if (Dists.size() == Sz) {
-      Type *StrideTy = DL->getIndexType(Ptr0->getType());
-      SPtrInfo.StrideVal = ConstantInt::get(StrideTy, Stride);
-      SPtrInfo.Ty = getWidenedType(ScalarTy, Sz);
-      return true;
-    }
+bool BoUpSLP::analyzeConstantStrideCandidate(
+    const ArrayRef<Value *> PointerOps, Type *ScalarTy, Align Alignment,
+    const SmallVectorImpl<unsigned> &SortedIndices, const int64_t Diff,
+    Value *Ptr0, Value *PtrN, StridedPtrInfo &SPtrInfo) const {
+  const size_t Sz = PointerOps.size();
+  if (!isStridedLoad(PointerOps, ScalarTy, Alignment, Diff, Sz))
+    return false;
+
+  int64_t Stride = Diff / static_cast<int64_t>(Sz - 1);
+
+  // Iterate through all pointers and check if all distances are
+  // unique multiple of Dist.
+  SmallSet<int64_t, 4> Dists;
+  for (Value *Ptr : PointerOps) {
+    int64_t Dist = 0;
+    if (Ptr == PtrN)
+      Dist = Diff;
+    else if (Ptr != Ptr0)
+      Dist = *getPointersDiff(ScalarTy, Ptr0, ScalarTy, Ptr, *DL, *SE);
+    // If the strides are not the same or repeated, we can't
+    // vectorize.
+    if (((Dist / Stride) * Stride) != Dist || !Dists.insert(Dist).second)
+      break;
+  }
+  if (Dists.size() == Sz) {
+    Type *StrideTy = DL->getIndexType(Ptr0->getType());
+    SPtrInfo.StrideVal = ConstantInt::get(StrideTy, Stride);
+    SPtrInfo.Ty = getWidenedType(ScalarTy, Sz);
+    return true;
   }
   return false;
 }
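To see the stride test above in isolation, here is a small self-contained C++ sketch (an editorial illustration, not code from this patch; the function name and the use of std::set in place of SmallSet are assumptions) of the same "every distance is a unique multiple of the stride" rule applied to plain integer distances:

#include <cstddef>
#include <cstdint>
#include <set>
#include <vector>

// The overall span Diff must split evenly into Sz - 1 steps, and every
// per-element distance must be a distinct multiple of that step (the stride).
bool hasUniqueConstantStride(const std::vector<int64_t> &Dists, int64_t Diff,
                             size_t Sz) {
  if (Sz < 2 || Diff % static_cast<int64_t>(Sz - 1) != 0)
    return false;
  int64_t Stride = Diff / static_cast<int64_t>(Sz - 1);
  if (Stride == 0)
    return false;
  std::set<int64_t> Seen;
  for (int64_t Dist : Dists)
    if (Dist % Stride != 0 || !Seen.insert(Dist).second)
      return false;
  return Seen.size() == Sz;
}

// Example: hasUniqueConstantStride({0, 3, 6, 9}, /*Diff=*/9, /*Sz=*/4) is true,
// while {0, 3, 3, 9} (repeated distance) and {0, 2, 6, 9} (2 is not a multiple
// of 3) both fail.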
@@ -6995,8 +7048,8 @@ BoUpSLP::LoadsState BoUpSLP::canVectorizeLoads(
       Align Alignment =
           cast<LoadInst>(Order.empty() ? VL.front() : VL[Order.front()])
               ->getAlign();
-      if (isStridedLoad(PointerOps, ScalarTy, Alignment, *Diff, Ptr0, PtrN,
-                        SPtrInfo))
+      if (analyzeConstantStrideCandidate(PointerOps, ScalarTy, Alignment,
+                                         Order, *Diff, Ptr0, PtrN, SPtrInfo))
         return LoadsState::StridedVectorize;
     }
     if (!TTI->isLegalMaskedGather(VecTy, CommonAlignment) ||