
Commit ca08ba0

Authored by mgudim, committed by github-actions[bot]
Automerge: [SLPVectorizer] Refactor isStridedLoad, NFC. (#163844)
Move the checks that all strides are the same from `isStridedLoad` to a new function, `analyzeConstantStrideCandidate`. This reduces the diff for the follow-up MRs, which will modify the logic in `analyzeConstantStrideCandidate` to cover widening of the strided load. All the checks left in `isStridedLoad` will be reused.
2 parents aa3097d + eb5de5c commit ca08ba0
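
To make the division of responsibilities easier to follow, here is a minimal standalone sketch of the refactored flow. It is an illustration only, not the SLPVectorizer API: plain byte offsets stand in for the `Value *` pointers and `StridedPtrInfo`, the target-legality and profitability checks are omitted, and the `...Sketch` names are invented for this example. The point it mirrors is the split in the patch: the per-candidate stride divisibility check stays in `isStridedLoad`, while the "every distance is a unique multiple of the stride" check lives in the `analyzeConstantStrideCandidate` wrapper, which then reports the constant stride.

// Standalone illustration (assumed simplification, not SLPVectorizer code);
// offsets stand in for the pointer differences computed via getPointersDiff,
// and are assumed to be sorted in ascending order.
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <unordered_set>
#include <vector>

// Checks that stay in isStridedLoad: the total distance must split evenly
// into Sz - 1 equal steps. (Target legality checks are omitted here.)
static bool isStridedLoadSketch(int64_t Diff, size_t Sz) {
  return Sz > 1 && Diff % static_cast<int64_t>(Sz - 1) == 0;
}

// Checks that move into analyzeConstantStrideCandidate: every distance from
// the first offset must be a distinct multiple of the stride; only then is
// the constant stride reported (standing in for filling StridedPtrInfo).
static std::optional<int64_t>
analyzeConstantStrideCandidateSketch(const std::vector<int64_t> &Offsets) {
  const size_t Sz = Offsets.size();
  const int64_t Diff = Offsets.back() - Offsets.front();
  if (!isStridedLoadSketch(Diff, Sz))
    return std::nullopt;
  const int64_t Stride = Diff / static_cast<int64_t>(Sz - 1);
  if (Stride == 0)
    return std::nullopt;
  std::unordered_set<int64_t> Dists;
  for (int64_t Off : Offsets) {
    const int64_t Dist = Off - Offsets.front();
    if (Dist % Stride != 0 || !Dists.insert(Dist).second)
      return std::nullopt; // not the same stride everywhere, or repeated
  }
  return Stride;
}

int main() {
  // 0, 8, 16, 24: distances are unique multiples of 8 -> stride 8 accepted.
  assert(analyzeConstantStrideCandidateSketch({0, 8, 16, 24}) == 8);
  // 0, 8, 8, 24: the total distance still divides evenly, but the repeated
  // distance 8 fails the uniqueness check that now lives in the wrapper.
  assert(!analyzeConstantStrideCandidateSketch({0, 8, 8, 24}).has_value());
  return 0;
}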

File tree: 1 file changed (+80, -27 lines)


llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Lines changed: 80 additions & 27 deletions
@@ -2242,8 +2242,49 @@ class BoUpSLP {
   /// may not be necessary.
   bool isLoadCombineCandidate(ArrayRef<Value *> Stores) const;
   bool isStridedLoad(ArrayRef<Value *> PointerOps, Type *ScalarTy,
-                     Align Alignment, const int64_t Diff, Value *Ptr0,
-                     Value *PtrN, StridedPtrInfo &SPtrInfo) const;
+                     Align Alignment, const int64_t Diff,
+                     const size_t Sz) const;
+
+  /// Return true if an array of scalar loads can be replaced with a strided
+  /// load (with constant stride).
+  ///
+  /// TODO:
+  /// It is possible that the load gets "widened". Suppose that originally each
+  /// load loads `k` bytes and `PointerOps` can be arranged as follows (`%s` is
+  /// constant): %b + 0 * %s + 0 %b + 0 * %s + 1 %b + 0 * %s + 2
+  /// ...
+  /// %b + 0 * %s + (w - 1)
+  ///
+  /// %b + 1 * %s + 0
+  /// %b + 1 * %s + 1
+  /// %b + 1 * %s + 2
+  /// ...
+  /// %b + 1 * %s + (w - 1)
+  /// ...
+  ///
+  /// %b + (n - 1) * %s + 0
+  /// %b + (n - 1) * %s + 1
+  /// %b + (n - 1) * %s + 2
+  /// ...
+  /// %b + (n - 1) * %s + (w - 1)
+  ///
+  /// In this case we will generate a strided load of type `<n x (k * w)>`.
+  ///
+  /// \param PointerOps list of pointer arguments of loads.
+  /// \param ElemTy original scalar type of loads.
+  /// \param Alignment alignment of the first load.
+  /// \param SortedIndices is the order of PointerOps as returned by
+  /// `sortPtrAccesses`.
+  /// \param Diff Pointer difference between the lowest and the highest pointer
+  /// in `PointerOps` as returned by `getPointersDiff`.
+  /// \param Ptr0 first pointer in `PointerOps`.
+  /// \param PtrN last pointer in `PointerOps`.
+  /// \param SPtrInfo If the function returns `true`, it also sets all the
+  /// fields of `SPtrInfo` necessary to generate the strided load later.
+  bool analyzeConstantStrideCandidate(
+      const ArrayRef<Value *> PointerOps, Type *ElemTy, Align Alignment,
+      const SmallVectorImpl<unsigned> &SortedIndices, const int64_t Diff,
+      Value *Ptr0, Value *PtrN, StridedPtrInfo &SPtrInfo) const;
 
   /// Return true if an array of scalar loads can be replaced with a strided
   /// load (with run-time stride).
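
As a concrete reading of the widening TODO added in the hunk above, the toy program below enumerates the scalar-load offsets in that layout and the equivalent grouping into n wide elements of k * w bytes, i.e. the `<n x (k * w)>` strided load the comment describes. The numeric values are assumptions picked for illustration, and the `+ j` terms of the comment are read here as j consecutive k-byte elements; none of this is SLPVectorizer code.

// Toy illustration of the widening layout from the comment above.
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t s = 64; // %s: constant stride between groups, in bytes
  const int64_t k = 4;  // bytes loaded by each original scalar load
  const int64_t n = 3;  // number of groups (wide vector elements)
  const int64_t w = 4;  // scalar loads per group

  // Offsets of the original scalar loads relative to %b, as in the comment:
  // group i, element j sits at %b + i * %s + j (j counted in k-byte elements).
  for (int64_t i = 0; i < n; ++i)
    for (int64_t j = 0; j < w; ++j)
      std::printf("scalar load: %%b + %lld * %%s + %lld -> byte offset %lld\n",
                  (long long)i, (long long)j, (long long)(i * s + j * k));

  // The same bytes are covered by one strided load <n x (k * w)>: wide
  // element i covers bytes [i * %s, i * %s + k * w) relative to %b.
  for (int64_t i = 0; i < n; ++i)
    std::printf("wide element %lld: bytes [%lld, %lld)\n", (long long)i,
                (long long)(i * s), (long long)(i * s + k * w));
  return 0;
}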
@@ -6849,9 +6890,8 @@ isMaskedLoadCompress(ArrayRef<Value *> VL, ArrayRef<Value *> PointerOps,
 /// current graph (for masked gathers extra extractelement instructions
 /// might be required).
 bool BoUpSLP::isStridedLoad(ArrayRef<Value *> PointerOps, Type *ScalarTy,
-                            Align Alignment, const int64_t Diff, Value *Ptr0,
-                            Value *PtrN, StridedPtrInfo &SPtrInfo) const {
-  const size_t Sz = PointerOps.size();
+                            Align Alignment, const int64_t Diff,
+                            const size_t Sz) const {
   if (Diff % (Sz - 1) != 0)
     return false;

@@ -6875,27 +6915,40 @@ bool BoUpSLP::isStridedLoad(ArrayRef<Value *> PointerOps, Type *ScalarTy,
       return false;
     if (!TTI->isLegalStridedLoadStore(VecTy, Alignment))
       return false;
+    return true;
+  }
+  return false;
+}
 
-    // Iterate through all pointers and check if all distances are
-    // unique multiple of Dist.
-    SmallSet<int64_t, 4> Dists;
-    for (Value *Ptr : PointerOps) {
-      int64_t Dist = 0;
-      if (Ptr == PtrN)
-        Dist = Diff;
-      else if (Ptr != Ptr0)
-        Dist = *getPointersDiff(ScalarTy, Ptr0, ScalarTy, Ptr, *DL, *SE);
-      // If the strides are not the same or repeated, we can't
-      // vectorize.
-      if (((Dist / Stride) * Stride) != Dist || !Dists.insert(Dist).second)
-        break;
-    }
-    if (Dists.size() == Sz) {
-      Type *StrideTy = DL->getIndexType(Ptr0->getType());
-      SPtrInfo.StrideVal = ConstantInt::get(StrideTy, Stride);
-      SPtrInfo.Ty = getWidenedType(ScalarTy, Sz);
-      return true;
-    }
+bool BoUpSLP::analyzeConstantStrideCandidate(
+    const ArrayRef<Value *> PointerOps, Type *ScalarTy, Align Alignment,
+    const SmallVectorImpl<unsigned> &SortedIndices, const int64_t Diff,
+    Value *Ptr0, Value *PtrN, StridedPtrInfo &SPtrInfo) const {
+  const size_t Sz = PointerOps.size();
+  if (!isStridedLoad(PointerOps, ScalarTy, Alignment, Diff, Sz))
+    return false;
+
+  int64_t Stride = Diff / static_cast<int64_t>(Sz - 1);
+
+  // Iterate through all pointers and check if all distances are
+  // unique multiple of Dist.
+  SmallSet<int64_t, 4> Dists;
+  for (Value *Ptr : PointerOps) {
+    int64_t Dist = 0;
+    if (Ptr == PtrN)
+      Dist = Diff;
+    else if (Ptr != Ptr0)
+      Dist = *getPointersDiff(ScalarTy, Ptr0, ScalarTy, Ptr, *DL, *SE);
+    // If the strides are not the same or repeated, we can't
+    // vectorize.
+    if (((Dist / Stride) * Stride) != Dist || !Dists.insert(Dist).second)
+      break;
+  }
+  if (Dists.size() == Sz) {
+    Type *StrideTy = DL->getIndexType(Ptr0->getType());
+    SPtrInfo.StrideVal = ConstantInt::get(StrideTy, Stride);
+    SPtrInfo.Ty = getWidenedType(ScalarTy, Sz);
+    return true;
   }
   return false;
 }
@@ -6995,8 +7048,8 @@ BoUpSLP::LoadsState BoUpSLP::canVectorizeLoads(
       Align Alignment =
          cast<LoadInst>(Order.empty() ? VL.front() : VL[Order.front()])
              ->getAlign();
-      if (isStridedLoad(PointerOps, ScalarTy, Alignment, *Diff, Ptr0, PtrN,
-                        SPtrInfo))
+      if (analyzeConstantStrideCandidate(PointerOps, ScalarTy, Alignment, Order,
+                                         *Diff, Ptr0, PtrN, SPtrInfo))
         return LoadsState::StridedVectorize;
     }
     if (!TTI->isLegalMaskedGather(VecTy, CommonAlignment) ||

0 commit comments
