
Commit 6ec2679

Author: Mikhail Gudim (authored and committed)
[SLPVectorizer] Widen constant strided loads.
Given a set of pointers, check if they can be rearranged as follows (%s is a constant):

%b + 0 * %s + 0
%b + 0 * %s + 1
%b + 0 * %s + 2
...
%b + 0 * %s + w

%b + 1 * %s + 0
%b + 1 * %s + 1
%b + 1 * %s + 2
...
%b + 1 * %s + w
...

If the pointers can be rearranged in the above pattern, it means that the memory can be accessed with strided loads of width `w` and stride `%s`.
1 parent 7e7c923 commit 6ec2679
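
As a quick, standalone illustration of the check described in the commit message (not the LLVM implementation itself): given element offsets from a common base pointer, already sorted, the sketch below looks for equal-sized contiguous groups whose starts are spaced by one constant stride. The names StrideInfo and matchConstantStride are invented for this example.

#include <cstdint>
#include <optional>
#include <vector>

// Result of matching the pattern from the commit message on sorted offsets.
struct StrideInfo {
  unsigned GroupWidth; // contiguous elements per group (w)
  unsigned NumGroups;  // number of groups, i.e. lanes of the strided load
  int64_t Stride;      // distance between group starts (%s), in elements
};

std::optional<StrideInfo>
matchConstantStride(const std::vector<int64_t> &Offsets) {
  const unsigned Sz = Offsets.size();
  if (Sz < 2)
    return std::nullopt;

  // Find where the first group ends: a maximal run with a common difference.
  int64_t StrideWithinGroup = Offsets[1] - Offsets[0];
  unsigned GroupSize = 1;
  while (GroupSize < Sz &&
         Offsets[GroupSize] - Offsets[GroupSize - 1] == StrideWithinGroup)
    ++GroupSize;

  // A single group means the offsets are already uniformly strided.
  if (GroupSize == Sz)
    return StrideInfo{1, Sz, StrideWithinGroup};

  // Widening case: equal-sized contiguous groups (element distance 1 inside
  // a group) whose starts are spaced by one constant stride.
  if (Sz % GroupSize != 0 || StrideWithinGroup != 1)
    return std::nullopt;
  const int64_t StrideBetweenGroups = Offsets[GroupSize] - Offsets[0];
  for (unsigned Start = 0; Start < Sz; Start += GroupSize) {
    if (Start != 0 &&
        Offsets[Start] - Offsets[Start - GroupSize] != StrideBetweenGroups)
      return std::nullopt;
    for (unsigned I = Start + 1; I != Start + GroupSize; ++I)
      if (Offsets[I] - Offsets[I - 1] != 1)
        return std::nullopt;
  }
  return StrideInfo{GroupSize, Sz / GroupSize, StrideBetweenGroups};
}

For example, offsets 0-3, 100-103, 200-203, 300-303 match as four contiguous groups of width 4 spaced 100 elements apart, which is the shape the patch turns into a single strided load.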

File tree

2 files changed: +134 −52 lines

llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Lines changed: 131 additions & 37 deletions
@@ -2242,8 +2242,29 @@ class BoUpSLP {
   /// may not be necessary.
   bool isLoadCombineCandidate(ArrayRef<Value *> Stores) const;
   bool isStridedLoad(ArrayRef<Value *> PointerOps, Type *ScalarTy,
-                     Align Alignment, const int64_t Diff, Value *Ptr0,
-                     Value *PtrN, StridedPtrInfo &SPtrInfo) const;
+                     Align Alignment, int64_t Diff, size_t VecSz) const;
+  /// Given a set of pointers, check if they can be rearranged as follows (%s is
+  /// a constant):
+  /// %b + 0 * %s + 0
+  /// %b + 0 * %s + 1
+  /// %b + 0 * %s + 2
+  /// ...
+  /// %b + 0 * %s + w
+  ///
+  /// %b + 1 * %s + 0
+  /// %b + 1 * %s + 1
+  /// %b + 1 * %s + 2
+  /// ...
+  /// %b + 1 * %s + w
+  /// ...
+  ///
+  /// If the pointers can be rearranged in the above pattern, it means that the
+  /// memory can be accessed with strided loads of width `w` and stride `%s`.
+  bool analyzeConstantStrideCandidate(ArrayRef<Value *> PointerOps,
+                                      Type *ElemTy, Align CommonAlignment,
+                                      SmallVectorImpl<unsigned> &SortedIndices,
+                                      int64_t Diff, Value *Ptr0, Value *PtrN,
+                                      StridedPtrInfo &SPtrInfo) const;
 
   /// Return true if an array of scalar loads can be replaced with a strided
   /// load (with run-time stride).
@@ -6844,12 +6865,7 @@ isMaskedLoadCompress(ArrayRef<Value *> VL, ArrayRef<Value *> PointerOps,
 /// current graph (for masked gathers extra extractelement instructions
 /// might be required).
 bool BoUpSLP::isStridedLoad(ArrayRef<Value *> PointerOps, Type *ScalarTy,
-                            Align Alignment, const int64_t Diff, Value *Ptr0,
-                            Value *PtrN, StridedPtrInfo &SPtrInfo) const {
-  const size_t Sz = PointerOps.size();
-  if (Diff % (Sz - 1) != 0)
-    return false;
-
+                            Align Alignment, int64_t Diff, size_t VecSz) const {
   // Try to generate strided load node.
   auto IsAnyPointerUsedOutGraph = any_of(PointerOps, [&](Value *V) {
     return isa<Instruction>(V) && any_of(V->users(), [&](User *U) {
@@ -6858,41 +6874,109 @@ bool BoUpSLP::isStridedLoad(ArrayRef<Value *> PointerOps, Type *ScalarTy,
   });
 
   const uint64_t AbsoluteDiff = std::abs(Diff);
-  auto *VecTy = getWidenedType(ScalarTy, Sz);
+  auto *VecTy = getWidenedType(ScalarTy, VecSz);
   if (IsAnyPointerUsedOutGraph ||
-      (AbsoluteDiff > Sz &&
-       (Sz > MinProfitableStridedLoads ||
-        (AbsoluteDiff <= MaxProfitableLoadStride * Sz &&
-         AbsoluteDiff % Sz == 0 && has_single_bit(AbsoluteDiff / Sz)))) ||
-      Diff == -(static_cast<int64_t>(Sz) - 1)) {
-    int64_t Stride = Diff / static_cast<int64_t>(Sz - 1);
-    if (Diff != Stride * static_cast<int64_t>(Sz - 1))
+      (AbsoluteDiff > VecSz &&
+       (VecSz > MinProfitableStridedLoads ||
+        (AbsoluteDiff <= MaxProfitableLoadStride * VecSz &&
+         AbsoluteDiff % VecSz == 0 && has_single_bit(AbsoluteDiff / VecSz)))) ||
+      Diff == -(static_cast<int64_t>(VecSz) - 1)) {
+    int64_t Stride = Diff / static_cast<int64_t>(VecSz - 1);
+    if (Diff != Stride * static_cast<int64_t>(VecSz - 1))
       return false;
     if (!TTI->isLegalStridedLoadStore(VecTy, Alignment))
       return false;
+  }
+  return true;
+}
+
+bool BoUpSLP::analyzeConstantStrideCandidate(
+    ArrayRef<Value *> PointerOps, Type *ElemTy, Align CommonAlignment,
+    SmallVectorImpl<unsigned> &SortedIndices, int64_t Diff, Value *Ptr0,
+    Value *PtrN, StridedPtrInfo &SPtrInfo) const {
+  const unsigned Sz = PointerOps.size();
+  SmallVector<int64_t> SortedOffsetsFromBase;
+  SortedOffsetsFromBase.resize(Sz);
+  for (unsigned I : seq<unsigned>(Sz)) {
+    Value *Ptr =
+        SortedIndices.empty() ? PointerOps[I] : PointerOps[SortedIndices[I]];
+    SortedOffsetsFromBase[I] =
+        *getPointersDiff(ElemTy, Ptr0, ElemTy, Ptr, *DL, *SE);
+  }
+  assert(SortedOffsetsFromBase.size() > 1 &&
+         "Trying to generate strided load for less than 2 loads");
+
+  // Find where the first group ends.
+  int64_t StrideWithinGroup =
+      SortedOffsetsFromBase[1] - SortedOffsetsFromBase[0];
+  unsigned GroupSize = 1;
+  for (; GroupSize != SortedOffsetsFromBase.size(); ++GroupSize) {
+    if (SortedOffsetsFromBase[GroupSize] -
+            SortedOffsetsFromBase[GroupSize - 1] !=
+        StrideWithinGroup)
+      break;
+  }
+  unsigned VecSz = Sz;
+  Type *ScalarTy = ElemTy;
+  int64_t StrideIntVal = StrideWithinGroup;
+  FixedVectorType *StridedLoadTy = getWidenedType(ScalarTy, VecSz);
+
+  if (Sz != GroupSize) {
+    if (Sz % GroupSize != 0)
+      return false;
+    VecSz = Sz / GroupSize;
+
+    if (StrideWithinGroup != 1)
+      return false;
+    unsigned VecSz = Sz / GroupSize;
+    ScalarTy = Type::getIntNTy(SE->getContext(),
+                               DL->getTypeSizeInBits(ElemTy).getFixedValue() *
+                                   GroupSize);
+    StridedLoadTy = getWidenedType(ScalarTy, VecSz);
+    if (!TTI->isTypeLegal(StridedLoadTy) ||
+        !TTI->isLegalStridedLoadStore(StridedLoadTy, CommonAlignment))
+      return false;
 
-    // Iterate through all pointers and check if all distances are
-    // unique multiple of Dist.
-    SmallSet<int64_t, 4> Dists;
-    for (Value *Ptr : PointerOps) {
-      int64_t Dist = 0;
-      if (Ptr == PtrN)
-        Dist = Diff;
-      else if (Ptr != Ptr0)
-        Dist = *getPointersDiff(ScalarTy, Ptr0, ScalarTy, Ptr, *DL, *SE);
-      // If the strides are not the same or repeated, we can't
-      // vectorize.
-      if (((Dist / Stride) * Stride) != Dist || !Dists.insert(Dist).second)
+    unsigned PrevGroupStartIdx = 0;
+    unsigned CurrentGroupStartIdx = GroupSize;
+    int64_t StrideBetweenGroups =
+        SortedOffsetsFromBase[GroupSize] - SortedOffsetsFromBase[0];
+    StrideIntVal = StrideBetweenGroups;
+    while (CurrentGroupStartIdx != Sz) {
+      if (SortedOffsetsFromBase[CurrentGroupStartIdx] -
+              SortedOffsetsFromBase[PrevGroupStartIdx] !=
+          StrideBetweenGroups)
         break;
+      PrevGroupStartIdx = CurrentGroupStartIdx;
+      CurrentGroupStartIdx += GroupSize;
     }
-    if (Dists.size() == Sz) {
-      Type *StrideTy = DL->getIndexType(Ptr0->getType());
-      SPtrInfo.StrideVal = ConstantInt::get(StrideTy, Stride);
-      SPtrInfo.Ty = getWidenedType(ScalarTy, Sz);
-      return true;
+    if (CurrentGroupStartIdx != Sz)
+      return false;
+
+    auto CheckGroup = [&](unsigned StartIdx, unsigned GroupSize0,
+                          int64_t StrideWithinGroup) -> bool {
+      unsigned GroupEndIdx = StartIdx + 1;
+      for (; GroupEndIdx != Sz; ++GroupEndIdx) {
+        if (SortedOffsetsFromBase[GroupEndIdx] -
+                SortedOffsetsFromBase[GroupEndIdx - 1] !=
+            StrideWithinGroup)
+          break;
+      }
+      return GroupEndIdx - StartIdx == GroupSize0;
+    };
+    for (unsigned I = 0; I < Sz; I += GroupSize) {
+      if (!CheckGroup(I, GroupSize, StrideWithinGroup))
+        return false;
     }
   }
-  return false;
+
+  if (!isStridedLoad(PointerOps, ScalarTy, CommonAlignment, Diff, VecSz))
+    return false;
+
+  Type *StrideTy = DL->getIndexType(Ptr0->getType());
+  SPtrInfo.StrideVal = ConstantInt::get(StrideTy, StrideIntVal);
+  SPtrInfo.Ty = StridedLoadTy;
+  return true;
 }
 
 bool BoUpSLP::analyzeRtStrideCandidate(ArrayRef<Value *> PointerOps,
@@ -6990,8 +7074,8 @@ BoUpSLP::LoadsState BoUpSLP::canVectorizeLoads(
       Align Alignment =
           cast<LoadInst>(Order.empty() ? VL.front() : VL[Order.front()])
               ->getAlign();
-      if (isStridedLoad(PointerOps, ScalarTy, Alignment, *Diff, Ptr0, PtrN,
-                        SPtrInfo))
+      if (analyzeConstantStrideCandidate(PointerOps, ScalarTy, Alignment, Order,
+                                         *Diff, Ptr0, PtrN, SPtrInfo))
         return LoadsState::StridedVectorize;
     }
     if (!TTI->isLegalMaskedGather(VecTy, CommonAlignment) ||
@@ -14902,11 +14986,19 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
       }
       break;
     case TreeEntry::StridedVectorize: {
+      const StridedPtrInfo &SPtrInfo = TreeEntryToStridedPtrInfoMap.at(E);
+      FixedVectorType *StridedLoadTy = SPtrInfo.Ty;
+      assert(StridedLoadTy && "Missing StridedPtrInfo for tree entry.");
       Align CommonAlignment =
           computeCommonAlignment<LoadInst>(UniqueValues.getArrayRef());
       VecLdCost = TTI->getStridedMemoryOpCost(
-          Instruction::Load, VecTy, LI0->getPointerOperand(),
+          Instruction::Load, StridedLoadTy, LI0->getPointerOperand(),
           /*VariableMask=*/false, CommonAlignment, CostKind);
+      if (StridedLoadTy != VecTy)
+        VecLdCost +=
+            TTI->getCastInstrCost(Instruction::BitCast, VecTy, StridedLoadTy,
+                                  getCastContextHint(*E), CostKind);
+
       break;
     }
     case TreeEntry::CompressVectorize: {
@@ -19670,6 +19762,8 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
                    ? NewLI
                    : ::propagateMetadata(NewLI, E->Scalars);
 
+      if (StridedLoadTy != VecTy)
+        V = Builder.CreateBitOrPointerCast(V, VecTy);
       V = FinalShuffle(V, E);
       E->VectorizedValue = V;
       ++NumVectorInstructions;
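
Judging from the test update below, this is how the pieces fit together: the sixteen i8 loads in @constant_stride_widen_no_reordering read from byte offsets 0-3, 100-103, 200-203 and 300-303, so the analysis finds a group size of 4 with a within-group stride of 1, widens the scalar type to i32, and records a stride of 100; the cost-model and codegen changes then produce a 4-element strided load (@llvm.experimental.vp.strided.load.v4i32) followed by a bitcast back to <16 x i8>.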

llvm/test/Transforms/SLPVectorizer/RISCV/basic-strided-loads.ll

Lines changed: 3 additions & 15 deletions
@@ -621,22 +621,10 @@ define void @constant_stride_widen_no_reordering(ptr %pl, i64 %stride, ptr %ps)
 ; CHECK-LABEL: define void @constant_stride_widen_no_reordering(
 ; CHECK-SAME: ptr [[PL:%.*]], i64 [[STRIDE:%.*]], ptr [[PS:%.*]]) #[[ATTR0]] {
 ; CHECK-NEXT: [[GEP_L0:%.*]] = getelementptr inbounds i8, ptr [[PL]], i64 0
-; CHECK-NEXT: [[GEP_L4:%.*]] = getelementptr inbounds i8, ptr [[PL]], i64 100
-; CHECK-NEXT: [[GEP_L8:%.*]] = getelementptr inbounds i8, ptr [[PL]], i64 200
-; CHECK-NEXT: [[GEP_L12:%.*]] = getelementptr inbounds i8, ptr [[PL]], i64 300
 ; CHECK-NEXT: [[GEP_S0:%.*]] = getelementptr inbounds i8, ptr [[PS]], i64 0
-; CHECK-NEXT: [[TMP1:%.*]] = load <4 x i8>, ptr [[GEP_L0]], align 1
-; CHECK-NEXT: [[TMP2:%.*]] = load <4 x i8>, ptr [[GEP_L4]], align 1
-; CHECK-NEXT: [[TMP3:%.*]] = load <4 x i8>, ptr [[GEP_L8]], align 1
-; CHECK-NEXT: [[TMP4:%.*]] = load <4 x i8>, ptr [[GEP_L12]], align 1
-; CHECK-NEXT: [[TMP5:%.*]] = shufflevector <4 x i8> [[TMP1]], <4 x i8> poison, <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
-; CHECK-NEXT: [[TMP6:%.*]] = shufflevector <4 x i8> [[TMP2]], <4 x i8> poison, <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
-; CHECK-NEXT: [[TMP7:%.*]] = shufflevector <4 x i8> [[TMP1]], <4 x i8> [[TMP2]], <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
-; CHECK-NEXT: [[TMP11:%.*]] = shufflevector <4 x i8> [[TMP3]], <4 x i8> poison, <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
-; CHECK-NEXT: [[TMP9:%.*]] = shufflevector <16 x i8> [[TMP7]], <16 x i8> [[TMP11]], <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 16, i32 17, i32 18, i32 19, i32 poison, i32 poison, i32 poison, i32 poison>
-; CHECK-NEXT: [[TMP10:%.*]] = shufflevector <4 x i8> [[TMP4]], <4 x i8> poison, <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
-; CHECK-NEXT: [[TMP8:%.*]] = shufflevector <16 x i8> [[TMP9]], <16 x i8> [[TMP10]], <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11, i32 16, i32 17, i32 18, i32 19>
-; CHECK-NEXT: store <16 x i8> [[TMP8]], ptr [[GEP_S0]], align 1
+; CHECK-NEXT: [[TMP1:%.*]] = call <4 x i32> @llvm.experimental.vp.strided.load.v4i32.p0.i64(ptr align 16 [[GEP_L0]], i64 100, <4 x i1> splat (i1 true), i32 4)
+; CHECK-NEXT: [[TMP8:%.*]] = bitcast <4 x i32> [[TMP1]] to <16 x i8>
+; CHECK-NEXT: store <16 x i8> [[TMP8]], ptr [[GEP_S0]], align 16
 ; CHECK-NEXT: ret void
 ;
 %gep_l0 = getelementptr inbounds i8, ptr %pl, i64 0
