
Commit 0df255c

Authored and committed by Mikhail Gudim
[SLPVectorizer] Widen constant strided loads.
Given a set of pointers, check if they can be rearranged as follows (%s is a
constant):

  %b + 0 * %s + 0
  %b + 0 * %s + 1
  %b + 0 * %s + 2
  ...
  %b + 0 * %s + w

  %b + 1 * %s + 0
  %b + 1 * %s + 1
  %b + 1 * %s + 2
  ...
  %b + 1 * %s + w
  ...

If the pointers can be rearranged in the above pattern, the memory can be
accessed with strided loads of width `w` and stride `%s`.
1 parent 975fba1 commit 0df255c
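
For illustration only, here is a minimal standalone sketch of the pattern the commit message describes (hypothetical helper, not code from this commit; names are made up):

// Assuming the offsets from the base pointer %b are already sorted, check
// that they form groups of `GroupSize` consecutive elements spaced `Stride`
// apart, i.e. Offsets[i * GroupSize + j] == i * Stride + j.
#include <cstdint>
#include <vector>

bool matchesConstantStridePattern(const std::vector<int64_t> &Offsets,
                                  unsigned GroupSize, int64_t Stride) {
  if (GroupSize == 0 || Offsets.size() % GroupSize != 0)
    return false;
  for (size_t I = 0, E = Offsets.size(); I != E; ++I) {
    int64_t Group = static_cast<int64_t>(I / GroupSize);
    int64_t Lane = static_cast<int64_t>(I % GroupSize);
    if (Offsets[I] != Group * Stride + Lane)
      return false;
  }
  return true;
}

For example, byte offsets 0..3, 100..103, 200..203, 300..303 match GroupSize = 4 and Stride = 100, which is the shape turned into a widened strided load in the RISC-V test updated by this commit.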

File tree

llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
llvm/test/Transforms/SLPVectorizer/RISCV/basic-strided-loads.ll

2 files changed: +134 -52 lines

llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Lines changed: 131 additions & 37 deletions
@@ -2242,8 +2242,29 @@ class BoUpSLP {
   /// may not be necessary.
   bool isLoadCombineCandidate(ArrayRef<Value *> Stores) const;
   bool isStridedLoad(ArrayRef<Value *> PointerOps, Type *ScalarTy,
-                     Align Alignment, const int64_t Diff, Value *Ptr0,
-                     Value *PtrN, StridedPtrInfo &SPtrInfo) const;
+                     Align Alignment, int64_t Diff, size_t VecSz) const;
+  /// Given a set of pointers, check if they can be rearranged as follows (%s is
+  /// a constant):
+  ///  %b + 0 * %s + 0
+  ///  %b + 0 * %s + 1
+  ///  %b + 0 * %s + 2
+  ///  ...
+  ///  %b + 0 * %s + w
+  ///
+  ///  %b + 1 * %s + 0
+  ///  %b + 1 * %s + 1
+  ///  %b + 1 * %s + 2
+  ///  ...
+  ///  %b + 1 * %s + w
+  ///  ...
+  ///
+  /// If the pointers can be rearanged in the above pattern, it means that the
+  /// memory can be accessed with a strided loads of width `w` and stride `%s`.
+  bool analyzeConstantStrideCandidate(ArrayRef<Value *> PointerOps,
+                                      Type *ElemTy, Align CommonAlignment,
+                                      SmallVectorImpl<unsigned> &SortedIndices,
+                                      int64_t Diff, Value *Ptr0, Value *PtrN,
+                                      StridedPtrInfo &SPtrInfo) const;
 
   /// Return true if an array of scalar loads can be replaced with a strided
   /// load (with run-time stride).
@@ -6849,12 +6870,7 @@ isMaskedLoadCompress(ArrayRef<Value *> VL, ArrayRef<Value *> PointerOps,
 /// current graph (for masked gathers extra extractelement instructions
 /// might be required).
 bool BoUpSLP::isStridedLoad(ArrayRef<Value *> PointerOps, Type *ScalarTy,
-                            Align Alignment, const int64_t Diff, Value *Ptr0,
-                            Value *PtrN, StridedPtrInfo &SPtrInfo) const {
-  const size_t Sz = PointerOps.size();
-  if (Diff % (Sz - 1) != 0)
-    return false;
-
+                            Align Alignment, int64_t Diff, size_t VecSz) const {
   // Try to generate strided load node.
   auto IsAnyPointerUsedOutGraph = any_of(PointerOps, [&](Value *V) {
     return isa<Instruction>(V) && any_of(V->users(), [&](User *U) {
@@ -6863,41 +6879,109 @@ bool BoUpSLP::isStridedLoad(ArrayRef<Value *> PointerOps, Type *ScalarTy,
   });
 
   const uint64_t AbsoluteDiff = std::abs(Diff);
-  auto *VecTy = getWidenedType(ScalarTy, Sz);
+  auto *VecTy = getWidenedType(ScalarTy, VecSz);
   if (IsAnyPointerUsedOutGraph ||
-      (AbsoluteDiff > Sz &&
-       (Sz > MinProfitableStridedLoads ||
-        (AbsoluteDiff <= MaxProfitableLoadStride * Sz &&
-         AbsoluteDiff % Sz == 0 && has_single_bit(AbsoluteDiff / Sz)))) ||
-      Diff == -(static_cast<int64_t>(Sz) - 1)) {
-    int64_t Stride = Diff / static_cast<int64_t>(Sz - 1);
-    if (Diff != Stride * static_cast<int64_t>(Sz - 1))
+      (AbsoluteDiff > VecSz &&
+       (VecSz > MinProfitableStridedLoads ||
+        (AbsoluteDiff <= MaxProfitableLoadStride * VecSz &&
+         AbsoluteDiff % VecSz == 0 && has_single_bit(AbsoluteDiff / VecSz)))) ||
+      Diff == -(static_cast<int64_t>(VecSz) - 1)) {
+    int64_t Stride = Diff / static_cast<int64_t>(VecSz - 1);
+    if (Diff != Stride * static_cast<int64_t>(VecSz - 1))
       return false;
     if (!TTI->isLegalStridedLoadStore(VecTy, Alignment))
       return false;
+  }
+  return true;
+}
+
+bool BoUpSLP::analyzeConstantStrideCandidate(
+    ArrayRef<Value *> PointerOps, Type *ElemTy, Align CommonAlignment,
+    SmallVectorImpl<unsigned> &SortedIndices, int64_t Diff, Value *Ptr0,
+    Value *PtrN, StridedPtrInfo &SPtrInfo) const {
+  const unsigned Sz = PointerOps.size();
+  SmallVector<int64_t> SortedOffsetsFromBase;
+  SortedOffsetsFromBase.resize(Sz);
+  for (unsigned I : seq<unsigned>(Sz)) {
+    Value *Ptr =
+        SortedIndices.empty() ? PointerOps[I] : PointerOps[SortedIndices[I]];
+    SortedOffsetsFromBase[I] =
+        *getPointersDiff(ElemTy, Ptr0, ElemTy, Ptr, *DL, *SE);
+  }
+  assert(SortedOffsetsFromBase.size() > 1 &&
+         "Trying to generate strided load for less than 2 loads");
+  //
+  // Find where the first group ends.
+  int64_t StrideWithinGroup =
+      SortedOffsetsFromBase[1] - SortedOffsetsFromBase[0];
+  unsigned GroupSize = 1;
+  for (; GroupSize != SortedOffsetsFromBase.size(); ++GroupSize) {
+    if (SortedOffsetsFromBase[GroupSize] -
+            SortedOffsetsFromBase[GroupSize - 1] !=
+        StrideWithinGroup)
+      break;
+  }
+  unsigned VecSz = Sz;
+  Type *ScalarTy = ElemTy;
+  int64_t StrideIntVal = StrideWithinGroup;
+  FixedVectorType *StridedLoadTy = getWidenedType(ScalarTy, VecSz);
+
+  if (Sz != GroupSize) {
+    if (Sz % GroupSize != 0)
+      return false;
+    VecSz = Sz / GroupSize;
+
+    if (StrideWithinGroup != 1)
+      return false;
+    unsigned VecSz = Sz / GroupSize;
+    ScalarTy = Type::getIntNTy(SE->getContext(),
+                               DL->getTypeSizeInBits(ElemTy).getFixedValue() *
+                                   GroupSize);
+    StridedLoadTy = getWidenedType(ScalarTy, VecSz);
+    if (!TTI->isTypeLegal(StridedLoadTy) ||
+        !TTI->isLegalStridedLoadStore(StridedLoadTy, CommonAlignment))
+      return false;
 
-    // Iterate through all pointers and check if all distances are
-    // unique multiple of Dist.
-    SmallSet<int64_t, 4> Dists;
-    for (Value *Ptr : PointerOps) {
-      int64_t Dist = 0;
-      if (Ptr == PtrN)
-        Dist = Diff;
-      else if (Ptr != Ptr0)
-        Dist = *getPointersDiff(ScalarTy, Ptr0, ScalarTy, Ptr, *DL, *SE);
-      // If the strides are not the same or repeated, we can't
-      // vectorize.
-      if (((Dist / Stride) * Stride) != Dist || !Dists.insert(Dist).second)
+    unsigned PrevGroupStartIdx = 0;
+    unsigned CurrentGroupStartIdx = GroupSize;
+    int64_t StrideBetweenGroups =
+        SortedOffsetsFromBase[GroupSize] - SortedOffsetsFromBase[0];
+    StrideIntVal = StrideBetweenGroups;
+    while (CurrentGroupStartIdx != Sz) {
+      if (SortedOffsetsFromBase[CurrentGroupStartIdx] -
+              SortedOffsetsFromBase[PrevGroupStartIdx] !=
+          StrideBetweenGroups)
        break;
+      PrevGroupStartIdx = CurrentGroupStartIdx;
+      CurrentGroupStartIdx += GroupSize;
     }
-    if (Dists.size() == Sz) {
-      Type *StrideTy = DL->getIndexType(Ptr0->getType());
-      SPtrInfo.StrideVal = ConstantInt::get(StrideTy, Stride);
-      SPtrInfo.Ty = getWidenedType(ScalarTy, Sz);
-      return true;
+    if (CurrentGroupStartIdx != Sz)
+      return false;
+
+    auto CheckGroup = [&](unsigned StartIdx, unsigned GroupSize0,
+                          int64_t StrideWithinGroup) -> bool {
+      unsigned GroupEndIdx = StartIdx + 1;
+      for (; GroupEndIdx != Sz; ++GroupEndIdx) {
+        if (SortedOffsetsFromBase[GroupEndIdx] -
+                SortedOffsetsFromBase[GroupEndIdx - 1] !=
+            StrideWithinGroup)
+          break;
+      }
+      return GroupEndIdx - StartIdx == GroupSize0;
+    };
+    for (unsigned I = 0; I < Sz; I += GroupSize) {
+      if (!CheckGroup(I, GroupSize, StrideWithinGroup))
+        return false;
     }
   }
-  return false;
+
+  if (!isStridedLoad(PointerOps, ScalarTy, CommonAlignment, Diff, VecSz))
+    return false;
+
+  Type *StrideTy = DL->getIndexType(Ptr0->getType());
+  SPtrInfo.StrideVal = ConstantInt::get(StrideTy, StrideIntVal);
+  SPtrInfo.Ty = StridedLoadTy;
+  return true;
 }
 
 bool BoUpSLP::analyzeRtStrideCandidate(ArrayRef<Value *> PointerOps,
@@ -6995,8 +7079,8 @@ BoUpSLP::LoadsState BoUpSLP::canVectorizeLoads(
       Align Alignment =
           cast<LoadInst>(Order.empty() ? VL.front() : VL[Order.front()])
               ->getAlign();
-      if (isStridedLoad(PointerOps, ScalarTy, Alignment, *Diff, Ptr0, PtrN,
-                        SPtrInfo))
+      if (analyzeConstantStrideCandidate(PointerOps, ScalarTy, Alignment, Order,
+                                         *Diff, Ptr0, PtrN, SPtrInfo))
         return LoadsState::StridedVectorize;
     }
     if (!TTI->isLegalMaskedGather(VecTy, CommonAlignment) ||
@@ -14916,11 +15000,19 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
       }
       break;
     case TreeEntry::StridedVectorize: {
+      const StridedPtrInfo &SPtrInfo = TreeEntryToStridedPtrInfoMap.at(E);
+      FixedVectorType *StridedLoadTy = SPtrInfo.Ty;
+      assert(StridedLoadTy && "Missing StridedPoinerInfo for tree entry.");
       Align CommonAlignment =
           computeCommonAlignment<LoadInst>(UniqueValues.getArrayRef());
       VecLdCost = TTI->getStridedMemoryOpCost(
-          Instruction::Load, VecTy, LI0->getPointerOperand(),
+          Instruction::Load, StridedLoadTy, LI0->getPointerOperand(),
           /*VariableMask=*/false, CommonAlignment, CostKind);
+      if (StridedLoadTy != VecTy)
+        VecLdCost +=
+            TTI->getCastInstrCost(Instruction::BitCast, VecTy, StridedLoadTy,
+                                  getCastContextHint(*E), CostKind);
+
       break;
     }
     case TreeEntry::CompressVectorize: {
@@ -19685,6 +19777,8 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
                   ? NewLI
                   : ::propagateMetadata(NewLI, E->Scalars);
 
+      if (StridedLoadTy != VecTy)
+        V = Builder.CreateBitOrPointerCast(V, VecTy);
       V = FinalShuffle(V, E);
       E->VectorizedValue = V;
       ++NumVectorInstructions;
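
The widening arithmetic applied by the new analyzeConstantStrideCandidate path can be summarized with a small standalone sketch (hypothetical helper names, assuming within-group stride 1; this is not LLVM API, only an illustration of the shape computation):

#include <cstdio>

// GroupSize contiguous elements of ElemBits bits each are reinterpreted as one
// integer element of GroupSize * ElemBits bits, and the strided load then
// covers Sz / GroupSize of those widened elements.
struct WidenedShape {
  unsigned ElemBits; // bit width of each widened element
  unsigned VecSz;    // number of widened elements in the strided load
};

static WidenedShape widenGroups(unsigned Sz, unsigned GroupSize,
                                unsigned ElemBits) {
  if (GroupSize == 0 || Sz % GroupSize != 0)
    return {0, 0}; // the pattern does not apply
  return {ElemBits * GroupSize, Sz / GroupSize};
}

int main() {
  // 16 scalar i8 loads in groups of 4 -> a <4 x i32> strided load that is
  // bitcast back to <16 x i8>, matching the RISC-V test change below.
  WidenedShape S = widenGroups(/*Sz=*/16, /*GroupSize=*/4, /*ElemBits=*/8);
  std::printf("<%u x i%u>\n", S.VecSz, S.ElemBits); // prints "<4 x i32>"
}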

llvm/test/Transforms/SLPVectorizer/RISCV/basic-strided-loads.ll

Lines changed: 3 additions & 15 deletions
@@ -621,22 +621,10 @@ define void @constant_stride_widen_no_reordering(ptr %pl, i64 %stride, ptr %ps)
 ; CHECK-LABEL: define void @constant_stride_widen_no_reordering(
 ; CHECK-SAME: ptr [[PL:%.*]], i64 [[STRIDE:%.*]], ptr [[PS:%.*]]) #[[ATTR0]] {
 ; CHECK-NEXT: [[GEP_L0:%.*]] = getelementptr inbounds i8, ptr [[PL]], i64 0
-; CHECK-NEXT: [[GEP_L4:%.*]] = getelementptr inbounds i8, ptr [[PL]], i64 100
-; CHECK-NEXT: [[GEP_L8:%.*]] = getelementptr inbounds i8, ptr [[PL]], i64 200
-; CHECK-NEXT: [[GEP_L12:%.*]] = getelementptr inbounds i8, ptr [[PL]], i64 300
 ; CHECK-NEXT: [[GEP_S0:%.*]] = getelementptr inbounds i8, ptr [[PS]], i64 0
-; CHECK-NEXT: [[TMP1:%.*]] = load <4 x i8>, ptr [[GEP_L0]], align 1
-; CHECK-NEXT: [[TMP2:%.*]] = load <4 x i8>, ptr [[GEP_L4]], align 1
-; CHECK-NEXT: [[TMP3:%.*]] = load <4 x i8>, ptr [[GEP_L8]], align 1
-; CHECK-NEXT: [[TMP4:%.*]] = load <4 x i8>, ptr [[GEP_L12]], align 1
-; CHECK-NEXT: [[TMP5:%.*]] = shufflevector <4 x i8> [[TMP1]], <4 x i8> poison, <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
-; CHECK-NEXT: [[TMP6:%.*]] = shufflevector <4 x i8> [[TMP2]], <4 x i8> poison, <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
-; CHECK-NEXT: [[TMP7:%.*]] = shufflevector <4 x i8> [[TMP1]], <4 x i8> [[TMP2]], <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
-; CHECK-NEXT: [[TMP11:%.*]] = shufflevector <4 x i8> [[TMP3]], <4 x i8> poison, <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
-; CHECK-NEXT: [[TMP9:%.*]] = shufflevector <16 x i8> [[TMP7]], <16 x i8> [[TMP11]], <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 16, i32 17, i32 18, i32 19, i32 poison, i32 poison, i32 poison, i32 poison>
-; CHECK-NEXT: [[TMP10:%.*]] = shufflevector <4 x i8> [[TMP4]], <4 x i8> poison, <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
-; CHECK-NEXT: [[TMP8:%.*]] = shufflevector <16 x i8> [[TMP9]], <16 x i8> [[TMP10]], <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11, i32 16, i32 17, i32 18, i32 19>
-; CHECK-NEXT: store <16 x i8> [[TMP8]], ptr [[GEP_S0]], align 1
+; CHECK-NEXT: [[TMP1:%.*]] = call <4 x i32> @llvm.experimental.vp.strided.load.v4i32.p0.i64(ptr align 16 [[GEP_L0]], i64 100, <4 x i1> splat (i1 true), i32 4)
+; CHECK-NEXT: [[TMP8:%.*]] = bitcast <4 x i32> [[TMP1]] to <16 x i8>
+; CHECK-NEXT: store <16 x i8> [[TMP8]], ptr [[GEP_S0]], align 16
 ; CHECK-NEXT: ret void
 ;
   %gep_l0 = getelementptr inbounds i8, ptr %pl, i64 0
