
Commit 12131d5

[SLPVectorizer] Widen constant strided loads. (#162324)
Given a set of pointers, check whether they can be rearranged as follows (%s is a constant):

%b + 0 * %s + 0
%b + 0 * %s + 1
%b + 0 * %s + 2
...
%b + 0 * %s + w
%b + 1 * %s + 0
%b + 1 * %s + 1
%b + 1 * %s + 2
...
%b + 1 * %s + w
...

If the pointers can be rearranged into the above pattern, the memory can be accessed with a strided load of width `w` and stride `%s`.
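To make the pattern concrete, here is a small, self-contained sketch of the check. It is an illustration only: names such as `matchWidenedStridePattern` and `WidenedStride` are invented here, and the in-tree implementation works on `PointerOps` and LLVM types rather than plain offsets.

```cpp
#include <cstdint>
#include <cstdio>
#include <optional>
#include <vector>

// Hypothetical helper: given element offsets from a common base pointer,
// already sorted ascending, decide whether they form groups of Width
// consecutive elements with a constant stride between group starts.
struct WidenedStride {
  unsigned Width; // elements per group (w)
  int64_t Stride; // distance between group starts (%s)
};

static std::optional<WidenedStride>
matchWidenedStridePattern(const std::vector<int64_t> &Offsets) {
  const size_t Sz = Offsets.size();
  if (Sz < 2)
    return std::nullopt;
  // Width of the first group: run of offsets increasing by exactly 1.
  unsigned Width = 1;
  while (Width < Sz && Offsets[Width] - Offsets[Width - 1] == 1)
    ++Width;
  if (Width == Sz || Sz % Width != 0)
    return std::nullopt; // fully contiguous, or groups don't divide evenly
  const int64_t Stride = Offsets[Width] - Offsets[0];
  // Every group must be contiguous and start Stride past the previous one.
  for (size_t G = 0; G < Sz; G += Width) {
    if (G > 0 && Offsets[G] - Offsets[G - Width] != Stride)
      return std::nullopt;
    for (size_t I = 1; I < Width; ++I)
      if (Offsets[G + I] - Offsets[G + I - 1] != 1)
        return std::nullopt;
  }
  return WidenedStride{Width, Stride};
}

int main() {
  // Offsets from the test case below: 4 groups of 4 bytes, 100 bytes apart.
  std::vector<int64_t> Offsets = {0,   1,   2,   3,   100, 101, 102, 103,
                                  200, 201, 202, 203, 300, 301, 302, 303};
  if (auto Res = matchWidenedStridePattern(Offsets))
    std::printf("width = %u, stride = %lld\n", Res->Width,
                (long long)Res->Stride);
}
```

On the offsets used by the `constant_stride_widen_no_reordering` test below this reports width 4 and stride 100, matching the `<4 x i32>` strided load with stride 100 that the patch now emits.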
Parent commit: e148d2d

2 files changed (+97, -39 lines)


llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Lines changed: 94 additions & 24 deletions
@@ -2245,7 +2245,6 @@ class slpvectorizer::BoUpSLP {
   /// Return true if an array of scalar loads can be replaced with a strided
   /// load (with constant stride).
   ///
-  /// TODO:
   /// It is possible that the load gets "widened". Suppose that originally each
   /// load loads `k` bytes and `PointerOps` can be arranged as follows (`%s` is
   /// constant): %b + 0 * %s + 0 %b + 0 * %s + 1 %b + 0 * %s + 2
@@ -7008,32 +7007,93 @@ bool BoUpSLP::analyzeConstantStrideCandidate(
     const SmallVectorImpl<unsigned> &SortedIndices, const int64_t Diff,
     Value *Ptr0, Value *PtrN, StridedPtrInfo &SPtrInfo) const {
   const size_t Sz = PointerOps.size();
-  if (!isStridedLoad(PointerOps, ScalarTy, Alignment, Diff, Sz))
+  SmallVector<int64_t> SortedOffsetsFromBase(Sz);
+  // Go through `PointerOps` in sorted order and record offsets from `Ptr0`.
+  for (unsigned I : seq<unsigned>(Sz)) {
+    Value *Ptr =
+        SortedIndices.empty() ? PointerOps[I] : PointerOps[SortedIndices[I]];
+    SortedOffsetsFromBase[I] =
+        *getPointersDiff(ScalarTy, Ptr0, ScalarTy, Ptr, *DL, *SE);
+  }
+
+  // The code below checks that `SortedOffsetsFromBase` looks as follows:
+  // ```
+  // [
+  // (e_{0, 0}, e_{0, 1}, ..., e_{0, GroupSize - 1}), // first group
+  // (e_{1, 0}, e_{1, 1}, ..., e_{1, GroupSize - 1}), // second group
+  // ...
+  // (e_{NumGroups - 1, 0}, e_{NumGroups - 1, 1}, ..., e_{NumGroups - 1,
+  // GroupSize - 1}), // last group
+  // ]
+  // ```
+  // The distance between consecutive elements within each group should all be
+  // the same `StrideWithinGroup`. The distance between the first elements of
+  // consecutive groups should all be the same `StrideBetweenGroups`.
+
+  int64_t StrideWithinGroup =
+      SortedOffsetsFromBase[1] - SortedOffsetsFromBase[0];
+  // Determine size of the first group. Later we will check that all other
+  // groups have the same size.
+  auto IsEndOfGroupIndex = [=, &SortedOffsetsFromBase](unsigned Idx) {
+    return SortedOffsetsFromBase[Idx] - SortedOffsetsFromBase[Idx - 1] !=
+           StrideWithinGroup;
+  };
+  auto Indices = seq<unsigned>(1, Sz);
+  auto FoundIt = llvm::find_if(Indices, IsEndOfGroupIndex);
+  unsigned GroupSize = FoundIt != Indices.end() ? *FoundIt : Sz;
+
+  unsigned VecSz = Sz;
+  Type *NewScalarTy = ScalarTy;
+
+  // Quick detour: at this point we can say what the type of strided load would
+  // be if all the checks pass. Check if this type is legal for the target.
+  bool NeedsWidening = Sz != GroupSize;
+  if (NeedsWidening) {
+    if (Sz % GroupSize != 0)
+      return false;
+
+    if (StrideWithinGroup != 1)
+      return false;
+    VecSz = Sz / GroupSize;
+    NewScalarTy = Type::getIntNTy(
+        SE->getContext(),
+        DL->getTypeSizeInBits(ScalarTy).getFixedValue() * GroupSize);
+  }
+
+  if (!isStridedLoad(PointerOps, NewScalarTy, Alignment, Diff, VecSz))
     return false;
 
-  int64_t Stride = Diff / static_cast<int64_t>(Sz - 1);
+  int64_t StrideIntVal = StrideWithinGroup;
+  if (NeedsWidening) {
+    // Continue with checking the "shape" of `SortedOffsetsFromBase`.
+    // Check that the strides between groups are all the same.
+    unsigned CurrentGroupStartIdx = GroupSize;
+    int64_t StrideBetweenGroups =
+        SortedOffsetsFromBase[GroupSize] - SortedOffsetsFromBase[0];
+    StrideIntVal = StrideBetweenGroups;
+    for (; CurrentGroupStartIdx < Sz; CurrentGroupStartIdx += GroupSize) {
+      if (SortedOffsetsFromBase[CurrentGroupStartIdx] -
+              SortedOffsetsFromBase[CurrentGroupStartIdx - GroupSize] !=
+          StrideBetweenGroups)
+        return false;
+    }
 
-  // Iterate through all pointers and check if all distances are
-  // unique multiple of Dist.
-  SmallSet<int64_t, 4> Dists;
-  for (Value *Ptr : PointerOps) {
-    int64_t Dist = 0;
-    if (Ptr == PtrN)
-      Dist = Diff;
-    else if (Ptr != Ptr0)
-      Dist = *getPointersDiff(ScalarTy, Ptr0, ScalarTy, Ptr, *DL, *SE);
-    // If the strides are not the same or repeated, we can't
-    // vectorize.
-    if (((Dist / Stride) * Stride) != Dist || !Dists.insert(Dist).second)
-      break;
-  }
-  if (Dists.size() == Sz) {
-    Type *StrideTy = DL->getIndexType(Ptr0->getType());
-    SPtrInfo.StrideVal = ConstantInt::get(StrideTy, Stride);
-    SPtrInfo.Ty = getWidenedType(ScalarTy, Sz);
-    return true;
+    auto CheckGroup = [=](const unsigned StartIdx) -> bool {
+      auto Indices = seq<unsigned>(StartIdx + 1, Sz);
+      auto FoundIt = llvm::find_if(Indices, IsEndOfGroupIndex);
+      unsigned GroupEndIdx = FoundIt != Indices.end() ? *FoundIt : Sz;
+      return GroupEndIdx - StartIdx == GroupSize;
+    };
+    for (unsigned I = 0; I < Sz; I += GroupSize) {
+      if (!CheckGroup(I))
+        return false;
+    }
   }
-  return false;
+
+  Type *StrideTy = DL->getIndexType(Ptr0->getType());
+  SPtrInfo.StrideVal = ConstantInt::get(StrideTy, StrideIntVal);
+  SPtrInfo.Ty = getWidenedType(NewScalarTy, VecSz);
+  return true;
 }
 
 bool BoUpSLP::analyzeRtStrideCandidate(ArrayRef<Value *> PointerOps,
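As a worked example of the new analysis, take the `constant_stride_widen_no_reordering` test below: 16 scalar `i8` loads at byte offsets 0-3, 100-103, 200-203, 300-303 from the base. `StrideWithinGroup` is 1, `GroupSize` is 4, so `VecSz` becomes 16 / 4 = 4, the widened element type becomes `i32` (8 bits x 4), and `StrideBetweenGroups`, and hence `SPtrInfo.StrideVal`, is 100. The sketch below only mirrors how the widened vector type falls out of those numbers; it is an assumption-laden illustration (`FixedVectorType::get` is used as a plain stand-in for the pass's `getWidenedType` helper), not code from the patch:

```cpp
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Type.h"

using namespace llvm;

int main() {
  LLVMContext Ctx;

  // Values taken from the constant_stride_widen_no_reordering test below.
  unsigned Sz = 16;        // number of scalar i8 loads
  unsigned GroupSize = 4;  // consecutive bytes per group
  unsigned ScalarBits = 8; // original element type is i8
  int64_t StrideBetweenGroups = 100; // bytes between group starts

  unsigned VecSz = Sz / GroupSize;                                  // 4
  Type *NewScalarTy = Type::getIntNTy(Ctx, ScalarBits * GroupSize); // i32
  // The pass records the equivalent of this as SPtrInfo.Ty; the stride value
  // (100) becomes SPtrInfo.StrideVal, and the strided load result is later
  // bitcast back to the original <16 x i8> value type.
  auto *StridedLoadTy = FixedVectorType::get(NewScalarTy, VecSz); // <4 x i32>
  (void)StridedLoadTy;
  (void)StrideBetweenGroups;
  return 0;
}
```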
@@ -15061,11 +15121,19 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
       }
       break;
     case TreeEntry::StridedVectorize: {
+      const StridedPtrInfo &SPtrInfo = TreeEntryToStridedPtrInfoMap.at(E);
+      FixedVectorType *StridedLoadTy = SPtrInfo.Ty;
+      assert(StridedLoadTy && "Missing StridedPointerInfo for tree entry.");
       Align CommonAlignment =
           computeCommonAlignment<LoadInst>(UniqueValues.getArrayRef());
       VecLdCost = TTI->getStridedMemoryOpCost(
-          Instruction::Load, VecTy, LI0->getPointerOperand(),
+          Instruction::Load, StridedLoadTy, LI0->getPointerOperand(),
           /*VariableMask=*/false, CommonAlignment, CostKind);
+      if (StridedLoadTy != VecTy)
+        VecLdCost +=
+            TTI->getCastInstrCost(Instruction::BitCast, VecTy, StridedLoadTy,
+                                  getCastContextHint(*E), CostKind);
+
       break;
     }
     case TreeEntry::CompressVectorize: {
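The cost-model hunk above prices the strided load at the widened type and, when that type differs from the tree's vector type, adds the cost of casting back, so widening is only chosen when the strided load plus the cast is actually cheaper. A hedged sketch of how the two TTI queries compose follows; the call shape is copied from the hunk, but `TTI`, `VecTy`, `StridedLoadTy`, `Ptr`, `CommonAlignment`, and `CostKind` are assumed to be in scope, and `TargetTransformInfo::CastContextHint::None` stands in for the pass's `getCastContextHint(*E)`:

```cpp
// Sketch: widened strided load cost plus the bitcast back to the original
// vector type (e.g. a <4 x i32> load bitcast to <16 x i8> in the test below).
InstructionCost VecLdCost = TTI->getStridedMemoryOpCost(
    Instruction::Load, StridedLoadTy, Ptr,
    /*VariableMask=*/false, CommonAlignment, CostKind);
if (StridedLoadTy != VecTy)
  VecLdCost += TTI->getCastInstrCost(
      Instruction::BitCast, VecTy, StridedLoadTy,
      TargetTransformInfo::CastContextHint::None, CostKind);
```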
@@ -19870,6 +19938,8 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
                     ? NewLI
                     : ::propagateMetadata(NewLI, E->Scalars);
 
+      if (StridedLoadTy != VecTy)
+        V = Builder.CreateBitOrPointerCast(V, VecTy);
       V = FinalShuffle(V, E);
       E->VectorizedValue = V;
       ++NumVectorInstructions;

llvm/test/Transforms/SLPVectorizer/RISCV/basic-strided-loads.ll

Lines changed: 3 additions & 15 deletions
@@ -1,6 +1,6 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
 
-; RUN: opt -mtriple=riscv64 -mattr=+m,+v -passes=slp-vectorizer -S < %s | FileCheck %s
+; RUN: opt -mtriple=riscv64 -mattr=+m,+v,+unaligned-vector-mem -passes=slp-vectorizer -S < %s | FileCheck %s
 
 define void @const_stride_1_no_reordering(ptr %pl, ptr %ps) {
 ; CHECK-LABEL: define void @const_stride_1_no_reordering(
@@ -621,21 +621,9 @@ define void @constant_stride_widen_no_reordering(ptr %pl, i64 %stride, ptr %ps)
 ; CHECK-LABEL: define void @constant_stride_widen_no_reordering(
 ; CHECK-SAME: ptr [[PL:%.*]], i64 [[STRIDE:%.*]], ptr [[PS:%.*]]) #[[ATTR0]] {
 ; CHECK-NEXT:    [[GEP_L0:%.*]] = getelementptr inbounds i8, ptr [[PL]], i64 0
-; CHECK-NEXT:    [[GEP_L4:%.*]] = getelementptr inbounds i8, ptr [[PL]], i64 100
-; CHECK-NEXT:    [[GEP_L8:%.*]] = getelementptr inbounds i8, ptr [[PL]], i64 200
-; CHECK-NEXT:    [[GEP_L12:%.*]] = getelementptr inbounds i8, ptr [[PL]], i64 300
 ; CHECK-NEXT:    [[GEP_S0:%.*]] = getelementptr inbounds i8, ptr [[PS]], i64 0
-; CHECK-NEXT:    [[TMP1:%.*]] = load <4 x i8>, ptr [[GEP_L0]], align 1
-; CHECK-NEXT:    [[TMP2:%.*]] = load <4 x i8>, ptr [[GEP_L4]], align 1
-; CHECK-NEXT:    [[TMP3:%.*]] = load <4 x i8>, ptr [[GEP_L8]], align 1
-; CHECK-NEXT:    [[TMP4:%.*]] = load <4 x i8>, ptr [[GEP_L12]], align 1
-; CHECK-NEXT:    [[TMP5:%.*]] = shufflevector <4 x i8> [[TMP1]], <4 x i8> poison, <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
-; CHECK-NEXT:    [[TMP6:%.*]] = shufflevector <4 x i8> [[TMP2]], <4 x i8> poison, <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
-; CHECK-NEXT:    [[TMP7:%.*]] = shufflevector <4 x i8> [[TMP1]], <4 x i8> [[TMP2]], <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
-; CHECK-NEXT:    [[TMP11:%.*]] = shufflevector <4 x i8> [[TMP3]], <4 x i8> poison, <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
-; CHECK-NEXT:    [[TMP9:%.*]] = shufflevector <16 x i8> [[TMP7]], <16 x i8> [[TMP11]], <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 16, i32 17, i32 18, i32 19, i32 poison, i32 poison, i32 poison, i32 poison>
-; CHECK-NEXT:    [[TMP10:%.*]] = shufflevector <4 x i8> [[TMP4]], <4 x i8> poison, <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
-; CHECK-NEXT:    [[TMP8:%.*]] = shufflevector <16 x i8> [[TMP9]], <16 x i8> [[TMP10]], <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11, i32 16, i32 17, i32 18, i32 19>
+; CHECK-NEXT:    [[TMP1:%.*]] = call <4 x i32> @llvm.experimental.vp.strided.load.v4i32.p0.i64(ptr align 1 [[GEP_L0]], i64 100, <4 x i1> splat (i1 true), i32 4)
+; CHECK-NEXT:    [[TMP8:%.*]] = bitcast <4 x i32> [[TMP1]] to <16 x i8>
 ; CHECK-NEXT:    store <16 x i8> [[TMP8]], ptr [[GEP_S0]], align 1
 ; CHECK-NEXT:    ret void
 ;
