
Commit 1802d76

mgudim authored and github-actions[bot] committed
Automerge: [SLPVectorizer][NFC] Save stride in a map. (#157706)
In order to avoid recalculating the stride of a strided load twice, save it in a map.
2 parents 9cda423 + 66a8f47 commit 1802d76
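
Below is a minimal, self-contained C++ sketch of the pattern this commit introduces (illustrative only, not the SLPVectorizer code: Node, VecTy, recordConstantStride, and lookupStride are made-up stand-ins for TreeEntry, FixedVectorType, and the analysis/codegen sides). The analysis that proves a load group is strided fills a StridedPtrInfo record once, keyed by the tree node; codegen later looks the record up instead of recomputing the stride.

#include <cassert>
#include <cstdint>
#include <unordered_map>

struct Node {};   // stand-in for BoUpSLP::TreeEntry
struct VecTy {};  // stand-in for FixedVectorType

// Mirrors the fields of the StridedPtrInfo struct added by this commit.
struct StridedPtrInfo {
  int64_t StrideVal = 0;            // constant stride in elements (0 = unset)
  const void *StrideSCEV = nullptr; // run-time stride expression, if any
  VecTy *Ty = nullptr;              // widened vector type of the load group
};

// The cache added by this commit, keyed by tree entry.
std::unordered_map<Node *, StridedPtrInfo> TreeEntryToStridedPtrInfoMap;

// Analysis side: record the stride once it has been proven.
void recordConstantStride(Node *TE, int64_t Stride, VecTy *Ty) {
  StridedPtrInfo Info;
  Info.StrideVal = Stride;
  Info.Ty = Ty;
  TreeEntryToStridedPtrInfoMap[TE] = Info;
}

// Codegen side: reuse the cached result instead of recalculating it.
const StridedPtrInfo &lookupStride(Node *TE) {
  auto It = TreeEntryToStridedPtrInfoMap.find(TE);
  assert(It != TreeEntryToStridedPtrInfoMap.end() && "stride never recorded");
  return It->second;
}

int main() {
  Node N;
  VecTy Ty;
  recordConstantStride(&N, 3, &Ty);               // analysis proves stride 3
  return lookupStride(&N).StrideVal == 3 ? 0 : 1; // codegen reuses it
}

In the actual patch the cache is a SmallDenseMap<TreeEntry *, StridedPtrInfo> member of BoUpSLP; it is filled in isStridedLoad, canVectorizeLoads, and transformNodes, and read back in vectorizeTree.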

File tree: 1 file changed, +103 -68 lines


llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Lines changed: 103 additions & 68 deletions
@@ -1926,6 +1926,19 @@ class BoUpSLP {
   class ShuffleCostEstimator;
   class ShuffleInstructionBuilder;
 
+  /// If we decide to generate strided load / store, this struct contains all
+  /// the necessary info. Its fields are calculated by analyzeRtStrideCandidate
+  /// and analyzeConstantStrideCandidate. Note that Stride can be given either
+  /// as a SCEV or as a Value if it already exists. To get the stride in bytes,
+  /// StrideVal (or the value obtained from StrideSCEV) has to be multiplied by
+  /// the size of an element of FixedVectorType.
+  struct StridedPtrInfo {
+    Value *StrideVal = nullptr;
+    const SCEV *StrideSCEV = nullptr;
+    FixedVectorType *Ty = nullptr;
+  };
+  SmallDenseMap<TreeEntry *, StridedPtrInfo> TreeEntryToStridedPtrInfoMap;
+
 public:
   /// Tracks the state we can represent the loads in the given sequence.
   enum class LoadsState {
@@ -2221,6 +2234,11 @@ class BoUpSLP {
   /// TODO: If load combining is allowed in the IR optimizer, this analysis
   /// may not be necessary.
   bool isLoadCombineCandidate(ArrayRef<Value *> Stores) const;
+  bool isStridedLoad(ArrayRef<Value *> VL, ArrayRef<Value *> PointerOps,
+                     ArrayRef<unsigned> Order, const TargetTransformInfo &TTI,
+                     const DataLayout &DL, ScalarEvolution &SE,
+                     const bool IsAnyPointerUsedOutGraph, const int64_t Diff,
+                     StridedPtrInfo &SPtrInfo) const;
 
   /// Checks if the given array of loads can be represented as a vectorized,
   /// scatter or just simple gather.
@@ -2235,6 +2253,7 @@ class BoUpSLP {
   LoadsState canVectorizeLoads(ArrayRef<Value *> VL, const Value *VL0,
                                SmallVectorImpl<unsigned> &Order,
                                SmallVectorImpl<Value *> &PointerOps,
+                               StridedPtrInfo &SPtrInfo,
                                unsigned *BestVF = nullptr,
                                bool TryRecursiveCheck = true) const;
 
@@ -4479,11 +4498,10 @@ class BoUpSLP {
 
   /// Checks if the specified list of the instructions/values can be vectorized
   /// and fills required data before actual scheduling of the instructions.
-  TreeEntry::EntryState
-  getScalarsVectorizationState(const InstructionsState &S, ArrayRef<Value *> VL,
-                               bool IsScatterVectorizeUserTE,
-                               OrdersType &CurrentOrder,
-                               SmallVectorImpl<Value *> &PointerOps);
+  TreeEntry::EntryState getScalarsVectorizationState(
+      const InstructionsState &S, ArrayRef<Value *> VL,
+      bool IsScatterVectorizeUserTE, OrdersType &CurrentOrder,
+      SmallVectorImpl<Value *> &PointerOps, StridedPtrInfo &SPtrInfo);
 
   /// Maps a specific scalar to its tree entry(ies).
   SmallDenseMap<Value *, SmallVector<TreeEntry *>> ScalarToTreeEntries;
@@ -6800,12 +6818,13 @@ isMaskedLoadCompress(ArrayRef<Value *> VL, ArrayRef<Value *> PointerOps,
 /// 4. Any pointer operand is an instruction with the users outside of the
 ///    current graph (for masked gathers extra extractelement instructions
 ///    might be required).
-static bool isStridedLoad(ArrayRef<Value *> VL, ArrayRef<Value *> PointerOps,
-                          ArrayRef<unsigned> Order,
-                          const TargetTransformInfo &TTI, const DataLayout &DL,
-                          ScalarEvolution &SE,
-                          const bool IsAnyPointerUsedOutGraph,
-                          const int64_t Diff) {
+bool BoUpSLP::isStridedLoad(ArrayRef<Value *> VL, ArrayRef<Value *> PointerOps,
+                            ArrayRef<unsigned> Order,
+                            const TargetTransformInfo &TTI,
+                            const DataLayout &DL, ScalarEvolution &SE,
+                            const bool IsAnyPointerUsedOutGraph,
+                            const int64_t Diff,
+                            StridedPtrInfo &SPtrInfo) const {
   const size_t Sz = VL.size();
   const uint64_t AbsoluteDiff = std::abs(Diff);
   Type *ScalarTy = VL.front()->getType();
@@ -6847,17 +6866,20 @@ static bool isStridedLoad(ArrayRef<Value *> VL, ArrayRef<Value *> PointerOps,
         if (((Dist / Stride) * Stride) != Dist || !Dists.insert(Dist).second)
           break;
       }
-      if (Dists.size() == Sz)
+      if (Dists.size() == Sz) {
+        Type *StrideTy = DL.getIndexType(Ptr0->getType());
+        SPtrInfo.StrideVal = ConstantInt::get(StrideTy, Stride);
+        SPtrInfo.Ty = getWidenedType(ScalarTy, Sz);
         return true;
+      }
     }
   return false;
 }
 
-BoUpSLP::LoadsState
-BoUpSLP::canVectorizeLoads(ArrayRef<Value *> VL, const Value *VL0,
-                           SmallVectorImpl<unsigned> &Order,
-                           SmallVectorImpl<Value *> &PointerOps,
-                           unsigned *BestVF, bool TryRecursiveCheck) const {
+BoUpSLP::LoadsState BoUpSLP::canVectorizeLoads(
+    ArrayRef<Value *> VL, const Value *VL0, SmallVectorImpl<unsigned> &Order,
+    SmallVectorImpl<Value *> &PointerOps, StridedPtrInfo &SPtrInfo,
+    unsigned *BestVF, bool TryRecursiveCheck) const {
   // Check that a vectorized load would load the same memory as a scalar
   // load. For example, we don't want to vectorize loads that are smaller
   // than 8-bit. Even though we have a packed struct {<i2, i2, i2, i2>} LLVM
@@ -6895,9 +6917,13 @@ BoUpSLP::canVectorizeLoads(ArrayRef<Value *> VL, const Value *VL0,
   Align CommonAlignment = computeCommonAlignment<LoadInst>(VL);
   if (!IsSorted) {
     if (Sz > MinProfitableStridedLoads && TTI->isTypeLegal(VecTy)) {
-      if (TTI->isLegalStridedLoadStore(VecTy, CommonAlignment) &&
-          calculateRtStride(PointerOps, ScalarTy, *DL, *SE, Order))
+      if (const SCEV *Stride =
+              calculateRtStride(PointerOps, ScalarTy, *DL, *SE, Order);
+          Stride && TTI->isLegalStridedLoadStore(VecTy, CommonAlignment)) {
+        SPtrInfo.Ty = getWidenedType(ScalarTy, PointerOps.size());
+        SPtrInfo.StrideSCEV = Stride;
         return LoadsState::StridedVectorize;
+      }
     }
 
     if (!TTI->isLegalMaskedGather(VecTy, CommonAlignment) ||
@@ -6941,7 +6967,7 @@ BoUpSLP::canVectorizeLoads(ArrayRef<Value *> VL, const Value *VL0,
       });
       if (IsPossibleStrided &&
           isStridedLoad(VL, PointerOps, Order, *TTI, *DL, *SE,
-                        IsAnyPointerUsedOutGraph, *Diff))
+                        IsAnyPointerUsedOutGraph, *Diff, SPtrInfo))
         return LoadsState::StridedVectorize;
     }
     if (!TTI->isLegalMaskedGather(VecTy, CommonAlignment) ||
@@ -7025,9 +7051,9 @@ BoUpSLP::canVectorizeLoads(ArrayRef<Value *> VL, const Value *VL0,
         ArrayRef<Value *> Slice = VL.slice(Cnt, VF);
         SmallVector<unsigned> Order;
         SmallVector<Value *> PointerOps;
-        LoadsState LS =
-            canVectorizeLoads(Slice, Slice.front(), Order, PointerOps, BestVF,
-                              /*TryRecursiveCheck=*/false);
+        LoadsState LS = canVectorizeLoads(Slice, Slice.front(), Order,
+                                          PointerOps, SPtrInfo, BestVF,
+                                          /*TryRecursiveCheck=*/false);
         // Check that the sorted loads are consecutive.
         if (LS == LoadsState::Gather) {
           if (BestVF) {
@@ -7699,9 +7725,10 @@ BoUpSLP::getReorderingData(const TreeEntry &TE, bool TopToBottom,
   // extra analysis later, so include such nodes into a special list.
   if (TE.hasState() && TE.getOpcode() == Instruction::Load) {
     SmallVector<Value *> PointerOps;
+    StridedPtrInfo SPtrInfo;
     OrdersType CurrentOrder;
     LoadsState Res = canVectorizeLoads(TE.Scalars, TE.Scalars.front(),
-                                       CurrentOrder, PointerOps);
+                                       CurrentOrder, PointerOps, SPtrInfo);
     if (Res == LoadsState::Vectorize || Res == LoadsState::StridedVectorize ||
         Res == LoadsState::CompressVectorize)
       return std::move(CurrentOrder);
@@ -9207,8 +9234,9 @@ void BoUpSLP::tryToVectorizeGatheredLoads(
       // Try to build vector load.
       ArrayRef<Value *> Values(
           reinterpret_cast<Value *const *>(Slice.begin()), Slice.size());
+      StridedPtrInfo SPtrInfo;
       LoadsState LS = canVectorizeLoads(Values, Slice.front(), CurrentOrder,
-                                        PointerOps, &BestVF);
+                                        PointerOps, SPtrInfo, &BestVF);
       if (LS != LoadsState::Gather ||
           (BestVF > 1 && static_cast<unsigned>(NumElts) == 2 * BestVF)) {
         if (LS == LoadsState::ScatterVectorize) {
@@ -9402,6 +9430,7 @@ void BoUpSLP::tryToVectorizeGatheredLoads(
             unsigned VF = *CommonVF;
             OrdersType Order;
             SmallVector<Value *> PointerOps;
+            StridedPtrInfo SPtrInfo;
             // Segmented load detected - vectorize at maximum vector factor.
             if (InterleaveFactor <= Slice.size() &&
                 TTI.isLegalInterleavedAccessType(
@@ -9410,8 +9439,8 @@ void BoUpSLP::tryToVectorizeGatheredLoads(
                     cast<LoadInst>(Slice.front())->getAlign(),
                     cast<LoadInst>(Slice.front())
                         ->getPointerAddressSpace()) &&
-                canVectorizeLoads(Slice, Slice.front(), Order,
-                                  PointerOps) == LoadsState::Vectorize) {
+                canVectorizeLoads(Slice, Slice.front(), Order, PointerOps,
+                                  SPtrInfo) == LoadsState::Vectorize) {
               UserMaxVF = InterleaveFactor * VF;
             } else {
               InterleaveFactor = 0;
@@ -9433,8 +9462,9 @@ void BoUpSLP::tryToVectorizeGatheredLoads(
                 ArrayRef<Value *> VL = TE.Scalars;
                 OrdersType Order;
                 SmallVector<Value *> PointerOps;
+                StridedPtrInfo SPtrInfo;
                 LoadsState State = canVectorizeLoads(
-                    VL, VL.front(), Order, PointerOps);
+                    VL, VL.front(), Order, PointerOps, SPtrInfo);
                 if (State == LoadsState::ScatterVectorize ||
                     State == LoadsState::CompressVectorize)
                   return false;
@@ -9452,11 +9482,11 @@ void BoUpSLP::tryToVectorizeGatheredLoads(
               [&, Slice = Slice](unsigned Idx) {
                 OrdersType Order;
                 SmallVector<Value *> PointerOps;
+                StridedPtrInfo SPtrInfo;
                 return canVectorizeLoads(
                            Slice.slice(Idx * UserMaxVF, UserMaxVF),
-                           Slice[Idx * UserMaxVF], Order,
-                           PointerOps) ==
-                       LoadsState::ScatterVectorize;
+                           Slice[Idx * UserMaxVF], Order, PointerOps,
+                           SPtrInfo) == LoadsState::ScatterVectorize;
               }))
             UserMaxVF = MaxVF;
           if (Slice.size() != ConsecutiveNodesSize)
@@ -9813,7 +9843,7 @@ getVectorCallCosts(CallInst *CI, FixedVectorType *VecTy,
 BoUpSLP::TreeEntry::EntryState BoUpSLP::getScalarsVectorizationState(
     const InstructionsState &S, ArrayRef<Value *> VL,
     bool IsScatterVectorizeUserTE, OrdersType &CurrentOrder,
-    SmallVectorImpl<Value *> &PointerOps) {
+    SmallVectorImpl<Value *> &PointerOps, StridedPtrInfo &SPtrInfo) {
   assert(S.getMainOp() &&
          "Expected instructions with same/alternate opcodes only.");
 
@@ -9915,7 +9945,7 @@ BoUpSLP::TreeEntry::EntryState BoUpSLP::getScalarsVectorizationState(
         });
       });
     };
-    switch (canVectorizeLoads(VL, VL0, CurrentOrder, PointerOps)) {
+    switch (canVectorizeLoads(VL, VL0, CurrentOrder, PointerOps, SPtrInfo)) {
     case LoadsState::Vectorize:
       return TreeEntry::Vectorize;
     case LoadsState::CompressVectorize:
@@ -11385,8 +11415,9 @@ void BoUpSLP::buildTreeRec(ArrayRef<Value *> VLRef, unsigned Depth,
       UserTreeIdx.UserTE->State == TreeEntry::ScatterVectorize;
   OrdersType CurrentOrder;
   SmallVector<Value *> PointerOps;
+  StridedPtrInfo SPtrInfo;
   TreeEntry::EntryState State = getScalarsVectorizationState(
-      S, VL, IsScatterVectorizeUserTE, CurrentOrder, PointerOps);
+      S, VL, IsScatterVectorizeUserTE, CurrentOrder, PointerOps, SPtrInfo);
   if (State == TreeEntry::NeedToGather) {
     newGatherTreeEntry(VL, S, UserTreeIdx, ReuseShuffleIndices);
     return;
@@ -11546,6 +11577,7 @@ void BoUpSLP::buildTreeRec(ArrayRef<Value *> VLRef, unsigned Depth,
       // Vectorizing non-consecutive loads with `llvm.masked.gather`.
       TE = newTreeEntry(VL, TreeEntry::StridedVectorize, Bundle, S,
                         UserTreeIdx, ReuseShuffleIndices, CurrentOrder);
+      TreeEntryToStridedPtrInfoMap[TE] = SPtrInfo;
       LLVM_DEBUG(dbgs() << "SLP: added a new TreeEntry (strided LoadInst).\n";
                  TE->dump());
       break;
@@ -12934,8 +12966,9 @@ void BoUpSLP::transformNodes() {
       if (S.getOpcode() == Instruction::Load) {
         OrdersType Order;
         SmallVector<Value *> PointerOps;
-        LoadsState Res =
-            canVectorizeLoads(Slice, Slice.front(), Order, PointerOps);
+        StridedPtrInfo SPtrInfo;
+        LoadsState Res = canVectorizeLoads(Slice, Slice.front(), Order,
+                                           PointerOps, SPtrInfo);
         AllStrided &= Res == LoadsState::StridedVectorize ||
                       Res == LoadsState::ScatterVectorize ||
                       Res == LoadsState::Gather;
@@ -13041,10 +13074,18 @@ void BoUpSLP::transformNodes() {
         InstructionCost StridedCost = TTI->getStridedMemoryOpCost(
             Instruction::Load, VecTy, BaseLI->getPointerOperand(),
             /*VariableMask=*/false, CommonAlignment, CostKind, BaseLI);
-        if (StridedCost < OriginalVecCost || ForceStridedLoads)
+        if (StridedCost < OriginalVecCost || ForceStridedLoads) {
           // Strided load is more profitable than consecutive load + reverse -
           // transform the node to strided load.
+          Type *StrideTy = DL->getIndexType(cast<LoadInst>(E.Scalars.front())
+                                                ->getPointerOperand()
+                                                ->getType());
+          StridedPtrInfo SPtrInfo;
+          SPtrInfo.StrideVal = ConstantInt::get(StrideTy, 1);
+          SPtrInfo.Ty = VecTy;
+          TreeEntryToStridedPtrInfoMap[&E] = SPtrInfo;
           E.State = TreeEntry::StridedVectorize;
+        }
       }
       break;
     }
@@ -19485,6 +19526,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
 
       LoadInst *LI = cast<LoadInst>(VL0);
       Instruction *NewLI;
+      FixedVectorType *StridedLoadTy = nullptr;
       Value *PO = LI->getPointerOperand();
       if (E->State == TreeEntry::Vectorize) {
         NewLI = Builder.CreateAlignedLoad(VecTy, PO, LI->getAlign());
@@ -19522,43 +19564,36 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
       Value *Ptr0 = cast<LoadInst>(E->Scalars.front())->getPointerOperand();
       Value *PtrN = cast<LoadInst>(E->Scalars.back())->getPointerOperand();
       PO = IsReverseOrder ? PtrN : Ptr0;
-      std::optional<int64_t> Diff = getPointersDiff(
-          VL0->getType(), Ptr0, VL0->getType(), PtrN, *DL, *SE);
       Type *StrideTy = DL->getIndexType(PO->getType());
       Value *StrideVal;
-      if (Diff) {
-        int64_t Stride =
-            *Diff / (static_cast<int64_t>(E->Scalars.size()) - 1);
-        StrideVal =
-            ConstantInt::get(StrideTy, (IsReverseOrder ? -1 : 1) * Stride *
-                                           DL->getTypeAllocSize(ScalarTy));
-      } else {
-        SmallVector<Value *> PointerOps(E->Scalars.size(), nullptr);
-        transform(E->Scalars, PointerOps.begin(), [](Value *V) {
-          return cast<LoadInst>(V)->getPointerOperand();
-        });
-        OrdersType Order;
-        const SCEV *StrideSCEV =
-            calculateRtStride(PointerOps, ScalarTy, *DL, *SE, Order);
-        assert(StrideSCEV && "At this point stride should be known");
+      const StridedPtrInfo &SPtrInfo = TreeEntryToStridedPtrInfoMap.at(E);
+      StridedLoadTy = SPtrInfo.Ty;
+      assert(StridedLoadTy && "Missing StridedPtrInfo for tree entry.");
+      unsigned StridedLoadEC =
+          StridedLoadTy->getElementCount().getKnownMinValue();
+
+      Value *Stride = SPtrInfo.StrideVal;
+      if (!Stride) {
+        const SCEV *StrideSCEV = SPtrInfo.StrideSCEV;
+        assert(StrideSCEV && "Neither StrideVal nor StrideSCEV were set.");
         SCEVExpander Expander(*SE, *DL, "strided-load-vec");
-        Value *Stride = Expander.expandCodeFor(
-            StrideSCEV, StrideSCEV->getType(), &*Builder.GetInsertPoint());
-        Value *NewStride =
-            Builder.CreateIntCast(Stride, StrideTy, /*isSigned=*/true);
-        StrideVal = Builder.CreateMul(
-            NewStride,
-            ConstantInt::get(
-                StrideTy,
-                (IsReverseOrder ? -1 : 1) *
-                    static_cast<int>(DL->getTypeAllocSize(ScalarTy))));
-      }
+        Stride = Expander.expandCodeFor(StrideSCEV, StrideSCEV->getType(),
+                                        &*Builder.GetInsertPoint());
+      }
+      Value *NewStride =
+          Builder.CreateIntCast(Stride, StrideTy, /*isSigned=*/true);
+      StrideVal = Builder.CreateMul(
+          NewStride, ConstantInt::get(
+                         StrideTy, (IsReverseOrder ? -1 : 1) *
+                                       static_cast<int>(
+                                           DL->getTypeAllocSize(ScalarTy))));
       Align CommonAlignment = computeCommonAlignment<LoadInst>(E->Scalars);
       auto *Inst = Builder.CreateIntrinsic(
           Intrinsic::experimental_vp_strided_load,
-          {VecTy, PO->getType(), StrideTy},
-          {PO, StrideVal, Builder.getAllOnesMask(VecTy->getElementCount()),
-           Builder.getInt32(E->Scalars.size())});
+          {StridedLoadTy, PO->getType(), StrideTy},
+          {PO, StrideVal,
+           Builder.getAllOnesMask(ElementCount::getFixed(StridedLoadEC)),
+           Builder.getInt32(StridedLoadEC)});
       Inst->addParamAttr(
           /*ArgNo=*/0,
           Attribute::getWithAlignment(Inst->getContext(), CommonAlignment));
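
As the comment on the new struct notes, the cached stride is measured in elements; the codegen hunk above multiplies it by the element allocation size to get the byte stride expected by the llvm.experimental.vp.strided.load intrinsic, and negates it for reverse-order groups. A minimal sketch of that arithmetic (byteStride is a hypothetical helper, not part of the patch):

#include <cstdint>
#include <iostream>

// Hypothetical helper: element stride to byte stride, negated when the
// pointers of the group are visited in reverse order.
int64_t byteStride(int64_t ElementStride, uint64_t ElementSizeInBytes,
                   bool IsReverseOrder) {
  return (IsReverseOrder ? -1 : 1) * ElementStride *
         static_cast<int64_t>(ElementSizeInBytes);
}

int main() {
  // i32 loads (4 bytes each) taken from every 3rd element:
  std::cout << byteStride(3, 4, /*IsReverseOrder=*/false) << "\n"; // 12
  std::cout << byteStride(3, 4, /*IsReverseOrder=*/true) << "\n";  // -12
}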

0 commit comments
