
Commit dd828e7

Mikhail Gudim authored and committed
[SLPVectorizer] Widen constant strided loads.
Given a set of pointers, check if they can be rearranged as follows (%s is a constant):

%b + 0 * %s + 0
%b + 0 * %s + 1
%b + 0 * %s + 2
...
%b + 0 * %s + w

%b + 1 * %s + 0
%b + 1 * %s + 1
%b + 1 * %s + 2
...
%b + 1 * %s + w
...

If the pointers can be rearranged in the above pattern, the memory can be accessed with strided loads of width `w` and stride `%s`.
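To make the pattern concrete, here is a minimal standalone sketch of the check described above (plain C++, not the SLPVectorizer code; it assumes the offsets from the base pointer %b are already sorted and given in element units, and the names ConstantStridePattern/matchConstantStride are made up for illustration):

#include <cstdint>
#include <vector>

// Hypothetical types/names, for illustration only.
struct ConstantStridePattern {
  unsigned GroupSize; // contiguous elements per group ("w" above)
  int64_t Stride;     // distance between group starts ("%s" above)
};

// Returns true if the sorted element offsets split into equally sized groups
// of consecutive elements whose group starts are a constant stride apart.
static bool matchConstantStride(const std::vector<int64_t> &Offsets,
                                ConstantStridePattern &P) {
  if (Offsets.size() < 2)
    return false;
  // Step between neighbours inside the first group; the patch only widens
  // groups whose step is exactly one element.
  int64_t StepInGroup = Offsets[1] - Offsets[0];
  unsigned GroupSize = 1;
  while (GroupSize != Offsets.size() &&
         Offsets[GroupSize] - Offsets[GroupSize - 1] == StepInGroup)
    ++GroupSize;
  if (Offsets.size() % GroupSize != 0)
    return false;
  int64_t Stride = GroupSize == Offsets.size()
                       ? StepInGroup
                       : Offsets[GroupSize] - Offsets[0];
  // Every offset must equal %b + group * stride + lane * step.
  for (unsigned I = 0; I != Offsets.size(); ++I) {
    int64_t Expected = Offsets[0] + int64_t(I / GroupSize) * Stride +
                       int64_t(I % GroupSize) * StepInGroup;
    if (Offsets[I] != Expected)
      return false;
  }
  P = {GroupSize, Stride};
  return true;
}

For offsets {0, 1, 2, 3, 100, 101, 102, 103, 200, 201, 202, 203} this reports GroupSize = 4 and Stride = 100, which is exactly the shape the patch can cover with a single strided load.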
1 parent: 2512611 · commit: dd828e7

2 files changed: +134, -52 lines changed

llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Lines changed: 131 additions & 37 deletions
@@ -2242,8 +2242,29 @@ class BoUpSLP {
   /// may not be necessary.
   bool isLoadCombineCandidate(ArrayRef<Value *> Stores) const;
   bool isStridedLoad(ArrayRef<Value *> PointerOps, Type *ScalarTy,
-                     Align Alignment, const int64_t Diff, Value *Ptr0,
-                     Value *PtrN, StridedPtrInfo &SPtrInfo) const;
+                     Align Alignment, int64_t Diff, size_t VecSz) const;
+  /// Given a set of pointers, check if they can be rearranged as follows (%s is
+  /// a constant):
+  /// %b + 0 * %s + 0
+  /// %b + 0 * %s + 1
+  /// %b + 0 * %s + 2
+  /// ...
+  /// %b + 0 * %s + w
+  ///
+  /// %b + 1 * %s + 0
+  /// %b + 1 * %s + 1
+  /// %b + 1 * %s + 2
+  /// ...
+  /// %b + 1 * %s + w
+  /// ...
+  ///
+  /// If the pointers can be rearanged in the above pattern, it means that the
+  /// memory can be accessed with a strided loads of width `w` and stride `%s`.
+  bool analyzeConstantStrideCandidate(ArrayRef<Value *> PointerOps,
+                                      Type *ElemTy, Align CommonAlignment,
+                                      SmallVectorImpl<unsigned> &SortedIndices,
+                                      int64_t Diff, Value *Ptr0, Value *PtrN,
+                                      StridedPtrInfo &SPtrInfo) const;
 
   /// Checks if the given array of loads can be represented as a vectorized,
   /// scatter or just simple gather.
@@ -6824,12 +6845,7 @@ isMaskedLoadCompress(ArrayRef<Value *> VL, ArrayRef<Value *> PointerOps,
 /// current graph (for masked gathers extra extractelement instructions
 /// might be required).
 bool BoUpSLP::isStridedLoad(ArrayRef<Value *> PointerOps, Type *ScalarTy,
-                            Align Alignment, const int64_t Diff, Value *Ptr0,
-                            Value *PtrN, StridedPtrInfo &SPtrInfo) const {
-  const size_t Sz = PointerOps.size();
-  if (Diff % (Sz - 1) != 0)
-    return false;
-
+                            Align Alignment, int64_t Diff, size_t VecSz) const {
   // Try to generate strided load node.
   auto IsAnyPointerUsedOutGraph = any_of(PointerOps, [&](Value *V) {
     return isa<Instruction>(V) && any_of(V->users(), [&](User *U) {
@@ -6838,41 +6854,109 @@ bool BoUpSLP::isStridedLoad(ArrayRef<Value *> PointerOps, Type *ScalarTy,
   });
 
   const uint64_t AbsoluteDiff = std::abs(Diff);
-  auto *VecTy = getWidenedType(ScalarTy, Sz);
+  auto *VecTy = getWidenedType(ScalarTy, VecSz);
   if (IsAnyPointerUsedOutGraph ||
-      (AbsoluteDiff > Sz &&
-       (Sz > MinProfitableStridedLoads ||
-        (AbsoluteDiff <= MaxProfitableLoadStride * Sz &&
-         AbsoluteDiff % Sz == 0 && has_single_bit(AbsoluteDiff / Sz)))) ||
-      Diff == -(static_cast<int64_t>(Sz) - 1)) {
-    int64_t Stride = Diff / static_cast<int64_t>(Sz - 1);
-    if (Diff != Stride * static_cast<int64_t>(Sz - 1))
+      (AbsoluteDiff > VecSz &&
+       (VecSz > MinProfitableStridedLoads ||
+        (AbsoluteDiff <= MaxProfitableLoadStride * VecSz &&
+         AbsoluteDiff % VecSz == 0 && has_single_bit(AbsoluteDiff / VecSz)))) ||
+      Diff == -(static_cast<int64_t>(VecSz) - 1)) {
+    int64_t Stride = Diff / static_cast<int64_t>(VecSz - 1);
+    if (Diff != Stride * static_cast<int64_t>(VecSz - 1))
       return false;
     if (!TTI->isLegalStridedLoadStore(VecTy, Alignment))
       return false;
+  }
+  return true;
+}
+
+bool BoUpSLP::analyzeConstantStrideCandidate(
+    ArrayRef<Value *> PointerOps, Type *ElemTy, Align CommonAlignment,
+    SmallVectorImpl<unsigned> &SortedIndices, int64_t Diff, Value *Ptr0,
+    Value *PtrN, StridedPtrInfo &SPtrInfo) const {
+  const unsigned Sz = PointerOps.size();
+  SmallVector<int64_t> SortedOffsetsFromBase;
+  SortedOffsetsFromBase.resize(Sz);
+  for (unsigned I : seq<unsigned>(Sz)) {
+    Value *Ptr =
+        SortedIndices.empty() ? PointerOps[I] : PointerOps[SortedIndices[I]];
+    SortedOffsetsFromBase[I] =
+        *getPointersDiff(ElemTy, Ptr0, ElemTy, Ptr, *DL, *SE);
+  }
+  assert(SortedOffsetsFromBase.size() > 1 &&
+         "Trying to generate strided load for less than 2 loads");
+  //
+  // Find where the first group ends.
+  int64_t StrideWithinGroup =
+      SortedOffsetsFromBase[1] - SortedOffsetsFromBase[0];
+  unsigned GroupSize = 1;
+  for (; GroupSize != SortedOffsetsFromBase.size(); ++GroupSize) {
+    if (SortedOffsetsFromBase[GroupSize] -
+            SortedOffsetsFromBase[GroupSize - 1] !=
+        StrideWithinGroup)
+      break;
+  }
+  unsigned VecSz = Sz;
+  Type *ScalarTy = ElemTy;
+  int64_t StrideIntVal = StrideWithinGroup;
+  FixedVectorType *StridedLoadTy = getWidenedType(ScalarTy, VecSz);
+
+  if (Sz != GroupSize) {
+    if (Sz % GroupSize != 0)
+      return false;
+    VecSz = Sz / GroupSize;
+
+    if (StrideWithinGroup != 1)
+      return false;
+    unsigned VecSz = Sz / GroupSize;
+    ScalarTy = Type::getIntNTy(SE->getContext(),
+                               DL->getTypeSizeInBits(ElemTy).getFixedValue() *
+                                   GroupSize);
+    StridedLoadTy = getWidenedType(ScalarTy, VecSz);
+    if (!TTI->isTypeLegal(StridedLoadTy) ||
+        !TTI->isLegalStridedLoadStore(StridedLoadTy, CommonAlignment))
+      return false;
 
-    // Iterate through all pointers and check if all distances are
-    // unique multiple of Dist.
-    SmallSet<int64_t, 4> Dists;
-    for (Value *Ptr : PointerOps) {
-      int64_t Dist = 0;
-      if (Ptr == PtrN)
-        Dist = Diff;
-      else if (Ptr != Ptr0)
-        Dist = *getPointersDiff(ScalarTy, Ptr0, ScalarTy, Ptr, *DL, *SE);
-      // If the strides are not the same or repeated, we can't
-      // vectorize.
-      if (((Dist / Stride) * Stride) != Dist || !Dists.insert(Dist).second)
+    unsigned PrevGroupStartIdx = 0;
+    unsigned CurrentGroupStartIdx = GroupSize;
+    int64_t StrideBetweenGroups =
+        SortedOffsetsFromBase[GroupSize] - SortedOffsetsFromBase[0];
+    StrideIntVal = StrideBetweenGroups;
+    while (CurrentGroupStartIdx != Sz) {
+      if (SortedOffsetsFromBase[CurrentGroupStartIdx] -
+              SortedOffsetsFromBase[PrevGroupStartIdx] !=
+          StrideBetweenGroups)
         break;
+      PrevGroupStartIdx = CurrentGroupStartIdx;
+      CurrentGroupStartIdx += GroupSize;
     }
-    if (Dists.size() == Sz) {
-      Type *StrideTy = DL->getIndexType(Ptr0->getType());
-      SPtrInfo.StrideVal = ConstantInt::get(StrideTy, Stride);
-      SPtrInfo.Ty = getWidenedType(ScalarTy, Sz);
-      return true;
+    if (CurrentGroupStartIdx != Sz)
+      return false;
+
+    auto CheckGroup = [&](unsigned StartIdx, unsigned GroupSize0,
+                          int64_t StrideWithinGroup) -> bool {
+      unsigned GroupEndIdx = StartIdx + 1;
+      for (; GroupEndIdx != Sz; ++GroupEndIdx) {
+        if (SortedOffsetsFromBase[GroupEndIdx] -
+                SortedOffsetsFromBase[GroupEndIdx - 1] !=
+            StrideWithinGroup)
+          break;
+      }
+      return GroupEndIdx - StartIdx == GroupSize0;
+    };
+    for (unsigned I = 0; I < Sz; I += GroupSize) {
+      if (!CheckGroup(I, GroupSize, StrideWithinGroup))
+        return false;
     }
   }
-  return false;
+
+  if (!isStridedLoad(PointerOps, ScalarTy, CommonAlignment, Diff, VecSz))
+    return false;
+
+  Type *StrideTy = DL->getIndexType(Ptr0->getType());
+  SPtrInfo.StrideVal = ConstantInt::get(StrideTy, StrideIntVal);
+  SPtrInfo.Ty = StridedLoadTy;
+  return true;
 }
 
 BoUpSLP::LoadsState BoUpSLP::canVectorizeLoads(
@@ -6958,8 +7042,8 @@ BoUpSLP::LoadsState BoUpSLP::canVectorizeLoads(
       Align Alignment =
           cast<LoadInst>(Order.empty() ? VL.front() : VL[Order.front()])
              ->getAlign();
-      if (isStridedLoad(PointerOps, ScalarTy, Alignment, *Diff, Ptr0, PtrN,
-                        SPtrInfo))
+      if (analyzeConstantStrideCandidate(PointerOps, ScalarTy, Alignment, Order,
+                                         *Diff, Ptr0, PtrN, SPtrInfo))
         return LoadsState::StridedVectorize;
     }
     if (!TTI->isLegalMaskedGather(VecTy, CommonAlignment) ||
@@ -14865,11 +14949,19 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
       }
       break;
     case TreeEntry::StridedVectorize: {
+      const StridedPtrInfo &SPtrInfo = TreeEntryToStridedPtrInfoMap.at(E);
+      FixedVectorType *StridedLoadTy = SPtrInfo.Ty;
+      assert(StridedLoadTy && "Missing StridedPoinerInfo for tree entry.");
       Align CommonAlignment =
          computeCommonAlignment<LoadInst>(UniqueValues.getArrayRef());
      VecLdCost = TTI->getStridedMemoryOpCost(
-          Instruction::Load, VecTy, LI0->getPointerOperand(),
+          Instruction::Load, StridedLoadTy, LI0->getPointerOperand(),
          /*VariableMask=*/false, CommonAlignment, CostKind);
+      if (StridedLoadTy != VecTy)
+        VecLdCost +=
+            TTI->getCastInstrCost(Instruction::BitCast, VecTy, StridedLoadTy,
+                                  getCastContextHint(*E), CostKind);
+
      break;
     }
     case TreeEntry::CompressVectorize: {
@@ -19633,6 +19725,8 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
                   ? NewLI
                   : ::propagateMetadata(NewLI, E->Scalars);
 
+      if (StridedLoadTy != VecTy)
+        V = Builder.CreateBitOrPointerCast(V, VecTy);
       V = FinalShuffle(V, E);
       E->VectorizedValue = V;
       ++NumVectorInstructions;
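For intuition about the widening itself, here is a hedged sketch of the shape computation (widenedShape is a made-up helper, not an LLVM API): when Sz scalar loads split into groups of GroupSize contiguous elements, each group is reloaded as one integer of GroupSize times the element width, so the strided load has Sz / GroupSize lanes and is bitcast back to the original vector type afterwards, which is what the cost-model and codegen hunks above account for.

#include <cassert>
#include <utility>

// Hypothetical helper: returns {number of wide lanes, wide element bits}.
static std::pair<unsigned, unsigned> widenedShape(unsigned Sz,
                                                  unsigned GroupSize,
                                                  unsigned ElemBits) {
  assert(GroupSize != 0 && Sz % GroupSize == 0 &&
         "groups must tile the whole load set");
  return {Sz / GroupSize, ElemBits * GroupSize};
}

// Example matching the RISC-V test below: sixteen i8 loads in four contiguous
// 4-byte groups spaced 100 bytes apart -> widenedShape(16, 4, 8) == {4, 32},
// i.e. one <4 x i32> strided load with stride 100, bitcast back to <16 x i8>.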

llvm/test/Transforms/SLPVectorizer/RISCV/basic-strided-loads.ll

Lines changed: 3 additions & 15 deletions
@@ -621,22 +621,10 @@ define void @constant_stride_widen_no_reordering(ptr %pl, i64 %stride, ptr %ps)
 ; CHECK-LABEL: define void @constant_stride_widen_no_reordering(
 ; CHECK-SAME: ptr [[PL:%.*]], i64 [[STRIDE:%.*]], ptr [[PS:%.*]]) #[[ATTR0]] {
 ; CHECK-NEXT: [[GEP_L0:%.*]] = getelementptr inbounds i8, ptr [[PL]], i64 0
-; CHECK-NEXT: [[GEP_L4:%.*]] = getelementptr inbounds i8, ptr [[PL]], i64 100
-; CHECK-NEXT: [[GEP_L8:%.*]] = getelementptr inbounds i8, ptr [[PL]], i64 200
-; CHECK-NEXT: [[GEP_L12:%.*]] = getelementptr inbounds i8, ptr [[PL]], i64 300
 ; CHECK-NEXT: [[GEP_S0:%.*]] = getelementptr inbounds i8, ptr [[PS]], i64 0
-; CHECK-NEXT: [[TMP1:%.*]] = load <4 x i8>, ptr [[GEP_L0]], align 1
-; CHECK-NEXT: [[TMP2:%.*]] = load <4 x i8>, ptr [[GEP_L4]], align 1
-; CHECK-NEXT: [[TMP3:%.*]] = load <4 x i8>, ptr [[GEP_L8]], align 1
-; CHECK-NEXT: [[TMP4:%.*]] = load <4 x i8>, ptr [[GEP_L12]], align 1
-; CHECK-NEXT: [[TMP5:%.*]] = shufflevector <4 x i8> [[TMP1]], <4 x i8> poison, <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
-; CHECK-NEXT: [[TMP6:%.*]] = shufflevector <4 x i8> [[TMP2]], <4 x i8> poison, <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
-; CHECK-NEXT: [[TMP7:%.*]] = shufflevector <4 x i8> [[TMP1]], <4 x i8> [[TMP2]], <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
-; CHECK-NEXT: [[TMP11:%.*]] = shufflevector <4 x i8> [[TMP3]], <4 x i8> poison, <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
-; CHECK-NEXT: [[TMP9:%.*]] = shufflevector <16 x i8> [[TMP7]], <16 x i8> [[TMP11]], <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 16, i32 17, i32 18, i32 19, i32 poison, i32 poison, i32 poison, i32 poison>
-; CHECK-NEXT: [[TMP10:%.*]] = shufflevector <4 x i8> [[TMP4]], <4 x i8> poison, <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
-; CHECK-NEXT: [[TMP8:%.*]] = shufflevector <16 x i8> [[TMP9]], <16 x i8> [[TMP10]], <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11, i32 16, i32 17, i32 18, i32 19>
-; CHECK-NEXT: store <16 x i8> [[TMP8]], ptr [[GEP_S0]], align 1
+; CHECK-NEXT: [[TMP1:%.*]] = call <4 x i32> @llvm.experimental.vp.strided.load.v4i32.p0.i64(ptr align 16 [[GEP_L0]], i64 100, <4 x i1> splat (i1 true), i32 4)
+; CHECK-NEXT: [[TMP8:%.*]] = bitcast <4 x i32> [[TMP1]] to <16 x i8>
+; CHECK-NEXT: store <16 x i8> [[TMP8]], ptr [[GEP_S0]], align 16
 ; CHECK-NEXT: ret void
 ;
   %gep_l0 = getelementptr inbounds i8, ptr %pl, i64 0
