
Commit e8113fb

Author: Mikhail Gudim (committed)
[SLPVectorizer][NFC] Save stride in a map.
To avoid recalculating the stride of a strided load twice, save it in a map.
1 parent ee3a4f4 commit e8113fb
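
The change follows a compute-once, cache, reuse pattern: the stride is determined during load analysis, stored in TreeEntryToStridedPtrInfoMap keyed by the tree entry, and looked up again when the strided load is emitted. The sketch below restates that pattern with toy stand-in types; it is a minimal illustration, not the LLVM code itself, and only the names StridedPtrInfo and TreeEntryToStridedPtrInfoMap are taken from the patch.

    #include <cstdint>
    #include <unordered_map>

    struct TreeEntry; // stand-in for the SLP graph node
    struct SCEV;      // stand-in for a symbolic (runtime-computed) stride

    // Mirrors the struct added by the patch: the stride is recorded either as
    // a constant (in elements) or as a symbolic expression, plus the width.
    struct StridedPtrInfo {
      int64_t StrideVal = 0;            // constant stride, in elements
      const SCEV *StrideSCEV = nullptr; // symbolic stride when not constant
      unsigned NumElements = 0;         // number of lanes of the strided load
    };

    // Analysis phase: compute the stride once and remember it per tree entry.
    std::unordered_map<TreeEntry *, StridedPtrInfo> TreeEntryToStridedPtrInfoMap;

    void rememberStride(TreeEntry *TE, const StridedPtrInfo &Info) {
      TreeEntryToStridedPtrInfoMap[TE] = Info;
    }

    // Codegen phase: look the result up instead of recomputing pointer
    // differences or re-running SCEV analysis a second time.
    const StridedPtrInfo &lookupStride(TreeEntry *TE) {
      return TreeEntryToStridedPtrInfoMap.at(TE);
    }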

File tree: 1 file changed (+103, -67 lines)

llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Lines changed: 103 additions & 67 deletions
@@ -1926,6 +1926,19 @@ class BoUpSLP {
   class ShuffleCostEstimator;
   class ShuffleInstructionBuilder;

+  /// If we decide to generate a strided load / store, this struct contains all
+  /// the necessary info. Its fields are calculated by analyzeRtStrideCandidate
+  /// and analyzeConstantStrideCandidate. Note that the stride can be given
+  /// either as a SCEV or as a Value if it already exists. To get the stride in
+  /// bytes, StrideVal (or the value obtained from StrideSCEV) has to be
+  /// multiplied by the size of an element of FixedVectorType.
+  struct StridedPtrInfo {
+    Value *StrideVal = nullptr;
+    const SCEV *StrideSCEV = nullptr;
+    FixedVectorType *Ty = nullptr;
+  };
+  SmallDenseMap<TreeEntry *, StridedPtrInfo> TreeEntryToStridedPtrInfoMap;
+
 public:
   /// Tracks the state we can represent the loads in the given sequence.
   enum class LoadsState {
@@ -2221,6 +2234,11 @@ class BoUpSLP {
   /// TODO: If load combining is allowed in the IR optimizer, this analysis
   /// may not be necessary.
   bool isLoadCombineCandidate(ArrayRef<Value *> Stores) const;
+  bool isStridedLoad(ArrayRef<Value *> VL, ArrayRef<Value *> PointerOps,
+                     ArrayRef<unsigned> Order, const TargetTransformInfo &TTI,
+                     const DataLayout &DL, ScalarEvolution &SE,
+                     const bool IsAnyPointerUsedOutGraph, const int64_t Diff,
+                     StridedPtrInfo &SPtrInfo) const;

   /// Checks if the given array of loads can be represented as a vectorized,
   /// scatter or just simple gather.
@@ -2235,6 +2253,7 @@ class BoUpSLP {
   LoadsState canVectorizeLoads(ArrayRef<Value *> VL, const Value *VL0,
                                SmallVectorImpl<unsigned> &Order,
                                SmallVectorImpl<Value *> &PointerOps,
+                               StridedPtrInfo &SPtrInfo,
                                unsigned *BestVF = nullptr,
                                bool TryRecursiveCheck = true) const;

@@ -4479,11 +4498,10 @@ class BoUpSLP {

   /// Checks if the specified list of the instructions/values can be vectorized
   /// and fills required data before actual scheduling of the instructions.
-  TreeEntry::EntryState
-  getScalarsVectorizationState(const InstructionsState &S, ArrayRef<Value *> VL,
-                               bool IsScatterVectorizeUserTE,
-                               OrdersType &CurrentOrder,
-                               SmallVectorImpl<Value *> &PointerOps);
+  TreeEntry::EntryState getScalarsVectorizationState(
+      const InstructionsState &S, ArrayRef<Value *> VL,
+      bool IsScatterVectorizeUserTE, OrdersType &CurrentOrder,
+      SmallVectorImpl<Value *> &PointerOps, StridedPtrInfo &SPtrInfo);

   /// Maps a specific scalar to its tree entry(ies).
   SmallDenseMap<Value *, SmallVector<TreeEntry *>> ScalarToTreeEntries;
@@ -6456,6 +6474,7 @@ static const SCEV *calculateRtStride(ArrayRef<Value *> PointerOps, Type *ElemTy,
       ++Cnt;
     }
   }
+
   return Stride;
 }

@@ -6799,12 +6818,13 @@ isMaskedLoadCompress(ArrayRef<Value *> VL, ArrayRef<Value *> PointerOps,
 /// 4. Any pointer operand is an instruction with the users outside of the
 /// current graph (for masked gathers extra extractelement instructions
 /// might be required).
-static bool isStridedLoad(ArrayRef<Value *> VL, ArrayRef<Value *> PointerOps,
-                          ArrayRef<unsigned> Order,
-                          const TargetTransformInfo &TTI, const DataLayout &DL,
-                          ScalarEvolution &SE,
-                          const bool IsAnyPointerUsedOutGraph,
-                          const int64_t Diff) {
+bool BoUpSLP::isStridedLoad(ArrayRef<Value *> VL, ArrayRef<Value *> PointerOps,
+                            ArrayRef<unsigned> Order,
+                            const TargetTransformInfo &TTI,
+                            const DataLayout &DL, ScalarEvolution &SE,
+                            const bool IsAnyPointerUsedOutGraph,
+                            const int64_t Diff,
+                            StridedPtrInfo &SPtrInfo) const {
   const size_t Sz = VL.size();
   const uint64_t AbsoluteDiff = std::abs(Diff);
   Type *ScalarTy = VL.front()->getType();
@@ -6846,17 +6866,20 @@ static bool isStridedLoad(ArrayRef<Value *> VL, ArrayRef<Value *> PointerOps,
       if (((Dist / Stride) * Stride) != Dist || !Dists.insert(Dist).second)
         break;
     }
-    if (Dists.size() == Sz)
+    if (Dists.size() == Sz) {
+      Type *StrideTy = DL.getIndexType(Ptr0->getType());
+      SPtrInfo.StrideVal = ConstantInt::get(StrideTy, Stride);
+      SPtrInfo.Ty = getWidenedType(ScalarTy, Sz);
       return true;
+    }
   }
   return false;
 }

-BoUpSLP::LoadsState
-BoUpSLP::canVectorizeLoads(ArrayRef<Value *> VL, const Value *VL0,
-                           SmallVectorImpl<unsigned> &Order,
-                           SmallVectorImpl<Value *> &PointerOps,
-                           unsigned *BestVF, bool TryRecursiveCheck) const {
+BoUpSLP::LoadsState BoUpSLP::canVectorizeLoads(
+    ArrayRef<Value *> VL, const Value *VL0, SmallVectorImpl<unsigned> &Order,
+    SmallVectorImpl<Value *> &PointerOps, StridedPtrInfo &SPtrInfo,
+    unsigned *BestVF, bool TryRecursiveCheck) const {
   // Check that a vectorized load would load the same memory as a scalar
   // load. For example, we don't want to vectorize loads that are smaller
   // than 8-bit. Even though we have a packed struct {<i2, i2, i2, i2>} LLVM
@@ -6894,9 +6917,13 @@ BoUpSLP::canVectorizeLoads(ArrayRef<Value *> VL, const Value *VL0,
   Align CommonAlignment = computeCommonAlignment<LoadInst>(VL);
   if (!IsSorted) {
     if (Sz > MinProfitableStridedLoads && TTI->isTypeLegal(VecTy)) {
-      if (TTI->isLegalStridedLoadStore(VecTy, CommonAlignment) &&
-          calculateRtStride(PointerOps, ScalarTy, *DL, *SE, Order))
+      if (const SCEV *Stride =
+              calculateRtStride(PointerOps, ScalarTy, *DL, *SE, Order);
+          Stride && TTI->isLegalStridedLoadStore(VecTy, CommonAlignment)) {
+        SPtrInfo.Ty = getWidenedType(ScalarTy, PointerOps.size());
+        SPtrInfo.StrideSCEV = Stride;
         return LoadsState::StridedVectorize;
+      }
     }

     if (!TTI->isLegalMaskedGather(VecTy, CommonAlignment) ||
@@ -6940,7 +6967,7 @@ BoUpSLP::canVectorizeLoads(ArrayRef<Value *> VL, const Value *VL0,
         });
     if (IsPossibleStrided &&
         isStridedLoad(VL, PointerOps, Order, *TTI, *DL, *SE,
-                      IsAnyPointerUsedOutGraph, *Diff))
+                      IsAnyPointerUsedOutGraph, *Diff, SPtrInfo))
       return LoadsState::StridedVectorize;
   }
   if (!TTI->isLegalMaskedGather(VecTy, CommonAlignment) ||
@@ -7024,9 +7051,9 @@ BoUpSLP::canVectorizeLoads(ArrayRef<Value *> VL, const Value *VL0,
       ArrayRef<Value *> Slice = VL.slice(Cnt, VF);
       SmallVector<unsigned> Order;
       SmallVector<Value *> PointerOps;
-      LoadsState LS =
-          canVectorizeLoads(Slice, Slice.front(), Order, PointerOps, BestVF,
-                            /*TryRecursiveCheck=*/false);
+      LoadsState LS = canVectorizeLoads(Slice, Slice.front(), Order,
+                                        PointerOps, SPtrInfo, BestVF,
+                                        /*TryRecursiveCheck=*/false);
       // Check that the sorted loads are consecutive.
       if (LS == LoadsState::Gather) {
         if (BestVF) {
@@ -7698,9 +7725,10 @@ BoUpSLP::getReorderingData(const TreeEntry &TE, bool TopToBottom,
   // extra analysis later, so include such nodes into a special list.
   if (TE.hasState() && TE.getOpcode() == Instruction::Load) {
     SmallVector<Value *> PointerOps;
+    StridedPtrInfo SPtrInfo;
     OrdersType CurrentOrder;
     LoadsState Res = canVectorizeLoads(TE.Scalars, TE.Scalars.front(),
-                                       CurrentOrder, PointerOps);
+                                       CurrentOrder, PointerOps, SPtrInfo);
     if (Res == LoadsState::Vectorize || Res == LoadsState::StridedVectorize ||
         Res == LoadsState::CompressVectorize)
       return std::move(CurrentOrder);
@@ -9206,8 +9234,9 @@ void BoUpSLP::tryToVectorizeGatheredLoads(
         // Try to build vector load.
         ArrayRef<Value *> Values(
             reinterpret_cast<Value *const *>(Slice.begin()), Slice.size());
+        StridedPtrInfo SPtrInfo;
         LoadsState LS = canVectorizeLoads(Values, Slice.front(), CurrentOrder,
-                                          PointerOps, &BestVF);
+                                          PointerOps, SPtrInfo, &BestVF);
         if (LS != LoadsState::Gather ||
             (BestVF > 1 && static_cast<unsigned>(NumElts) == 2 * BestVF)) {
           if (LS == LoadsState::ScatterVectorize) {
@@ -9401,6 +9430,7 @@ void BoUpSLP::tryToVectorizeGatheredLoads(
             unsigned VF = *CommonVF;
             OrdersType Order;
             SmallVector<Value *> PointerOps;
+            StridedPtrInfo SPtrInfo;
             // Segmented load detected - vectorize at maximum vector factor.
             if (InterleaveFactor <= Slice.size() &&
                 TTI.isLegalInterleavedAccessType(
@@ -9409,8 +9439,8 @@ void BoUpSLP::tryToVectorizeGatheredLoads(
                     cast<LoadInst>(Slice.front())->getAlign(),
                     cast<LoadInst>(Slice.front())
                         ->getPointerAddressSpace()) &&
-                canVectorizeLoads(Slice, Slice.front(), Order,
-                                  PointerOps) == LoadsState::Vectorize) {
+                canVectorizeLoads(Slice, Slice.front(), Order, PointerOps,
+                                  SPtrInfo) == LoadsState::Vectorize) {
               UserMaxVF = InterleaveFactor * VF;
             } else {
               InterleaveFactor = 0;
@@ -9432,8 +9462,9 @@ void BoUpSLP::tryToVectorizeGatheredLoads(
               ArrayRef<Value *> VL = TE.Scalars;
               OrdersType Order;
               SmallVector<Value *> PointerOps;
+              StridedPtrInfo SPtrInfo;
               LoadsState State = canVectorizeLoads(
-                  VL, VL.front(), Order, PointerOps);
+                  VL, VL.front(), Order, PointerOps, SPtrInfo);
               if (State == LoadsState::ScatterVectorize ||
                   State == LoadsState::CompressVectorize)
                 return false;
@@ -9451,11 +9482,11 @@ void BoUpSLP::tryToVectorizeGatheredLoads(
                 [&, Slice = Slice](unsigned Idx) {
                   OrdersType Order;
                   SmallVector<Value *> PointerOps;
+                  StridedPtrInfo SPtrInfo;
                   return canVectorizeLoads(
                              Slice.slice(Idx * UserMaxVF, UserMaxVF),
-                             Slice[Idx * UserMaxVF], Order,
-                             PointerOps) ==
-                         LoadsState::ScatterVectorize;
+                             Slice[Idx * UserMaxVF], Order, PointerOps,
+                             SPtrInfo) == LoadsState::ScatterVectorize;
                 }))
             UserMaxVF = MaxVF;
           if (Slice.size() != ConsecutiveNodesSize)
@@ -9812,7 +9843,7 @@ getVectorCallCosts(CallInst *CI, FixedVectorType *VecTy,
 BoUpSLP::TreeEntry::EntryState BoUpSLP::getScalarsVectorizationState(
     const InstructionsState &S, ArrayRef<Value *> VL,
     bool IsScatterVectorizeUserTE, OrdersType &CurrentOrder,
-    SmallVectorImpl<Value *> &PointerOps) {
+    SmallVectorImpl<Value *> &PointerOps, StridedPtrInfo &SPtrInfo) {
   assert(S.getMainOp() &&
          "Expected instructions with same/alternate opcodes only.");

@@ -9914,7 +9945,7 @@ BoUpSLP::TreeEntry::EntryState BoUpSLP::getScalarsVectorizationState(
           });
         });
     };
-    switch (canVectorizeLoads(VL, VL0, CurrentOrder, PointerOps)) {
+    switch (canVectorizeLoads(VL, VL0, CurrentOrder, PointerOps, SPtrInfo)) {
     case LoadsState::Vectorize:
       return TreeEntry::Vectorize;
     case LoadsState::CompressVectorize:
@@ -11384,8 +11415,9 @@ void BoUpSLP::buildTreeRec(ArrayRef<Value *> VLRef, unsigned Depth,
       UserTreeIdx.UserTE->State == TreeEntry::ScatterVectorize;
   OrdersType CurrentOrder;
   SmallVector<Value *> PointerOps;
+  StridedPtrInfo SPtrInfo;
   TreeEntry::EntryState State = getScalarsVectorizationState(
-      S, VL, IsScatterVectorizeUserTE, CurrentOrder, PointerOps);
+      S, VL, IsScatterVectorizeUserTE, CurrentOrder, PointerOps, SPtrInfo);
   if (State == TreeEntry::NeedToGather) {
     newGatherTreeEntry(VL, S, UserTreeIdx, ReuseShuffleIndices);
     return;
@@ -11545,6 +11577,7 @@ void BoUpSLP::buildTreeRec(ArrayRef<Value *> VLRef, unsigned Depth,
       // Vectorizing non-consecutive loads with `llvm.masked.gather`.
       TE = newTreeEntry(VL, TreeEntry::StridedVectorize, Bundle, S,
                         UserTreeIdx, ReuseShuffleIndices, CurrentOrder);
+      TreeEntryToStridedPtrInfoMap[TE] = SPtrInfo;
       LLVM_DEBUG(dbgs() << "SLP: added a new TreeEntry (strided LoadInst).\n";
                  TE->dump());
      break;
@@ -12933,8 +12966,9 @@ void BoUpSLP::transformNodes() {
         if (S.getOpcode() == Instruction::Load) {
           OrdersType Order;
           SmallVector<Value *> PointerOps;
-          LoadsState Res =
-              canVectorizeLoads(Slice, Slice.front(), Order, PointerOps);
+          StridedPtrInfo SPtrInfo;
+          LoadsState Res = canVectorizeLoads(Slice, Slice.front(), Order,
+                                             PointerOps, SPtrInfo);
           AllStrided &= Res == LoadsState::StridedVectorize ||
                         Res == LoadsState::ScatterVectorize ||
                         Res == LoadsState::Gather;
@@ -13043,7 +13077,15 @@ void BoUpSLP::transformNodes() {
-        if (StridedCost < OriginalVecCost || ForceStridedLoads)
+        if (StridedCost < OriginalVecCost || ForceStridedLoads) {
           // Strided load is more profitable than consecutive load + reverse -
           // transform the node to strided load.
+          Type *StrideTy = DL->getIndexType(cast<LoadInst>(E.Scalars.front())
+                                                ->getPointerOperand()
+                                                ->getType());
+          StridedPtrInfo SPtrInfo;
+          SPtrInfo.StrideVal = ConstantInt::get(StrideTy, 1);
+          SPtrInfo.Ty = VecTy;
+          TreeEntryToStridedPtrInfoMap[&E] = SPtrInfo;
           E.State = TreeEntry::StridedVectorize;
+        }
       }
       break;
     }
@@ -19484,6 +19526,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {

       LoadInst *LI = cast<LoadInst>(VL0);
       Instruction *NewLI;
+      FixedVectorType *StridedLoadTy = nullptr;
       Value *PO = LI->getPointerOperand();
       if (E->State == TreeEntry::Vectorize) {
         NewLI = Builder.CreateAlignedLoad(VecTy, PO, LI->getAlign());
@@ -19521,43 +19564,36 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) {
         Value *Ptr0 = cast<LoadInst>(E->Scalars.front())->getPointerOperand();
         Value *PtrN = cast<LoadInst>(E->Scalars.back())->getPointerOperand();
         PO = IsReverseOrder ? PtrN : Ptr0;
-        std::optional<int64_t> Diff = getPointersDiff(
-            VL0->getType(), Ptr0, VL0->getType(), PtrN, *DL, *SE);
         Type *StrideTy = DL->getIndexType(PO->getType());
         Value *StrideVal;
-        if (Diff) {
-          int64_t Stride =
-              *Diff / (static_cast<int64_t>(E->Scalars.size()) - 1);
-          StrideVal =
-              ConstantInt::get(StrideTy, (IsReverseOrder ? -1 : 1) * Stride *
-                                             DL->getTypeAllocSize(ScalarTy));
-        } else {
-          SmallVector<Value *> PointerOps(E->Scalars.size(), nullptr);
-          transform(E->Scalars, PointerOps.begin(), [](Value *V) {
-            return cast<LoadInst>(V)->getPointerOperand();
-          });
-          OrdersType Order;
-          const SCEV *StrideSCEV =
-              calculateRtStride(PointerOps, ScalarTy, *DL, *SE, Order);
-          assert(StrideSCEV && "At this point stride should be known");
+        const StridedPtrInfo &SPtrInfo = TreeEntryToStridedPtrInfoMap.at(E);
+        StridedLoadTy = SPtrInfo.Ty;
+        assert(StridedLoadTy && "Missing StridedPtrInfo for tree entry.");
+        unsigned StridedLoadEC =
+            StridedLoadTy->getElementCount().getKnownMinValue();
+
+        Value *Stride = SPtrInfo.StrideVal;
+        if (!Stride) {
+          const SCEV *StrideSCEV = SPtrInfo.StrideSCEV;
+          assert(StrideSCEV && "Neither StrideVal nor StrideSCEV were set.");
           SCEVExpander Expander(*SE, *DL, "strided-load-vec");
-          Value *Stride = Expander.expandCodeFor(
-              StrideSCEV, StrideSCEV->getType(), &*Builder.GetInsertPoint());
-          Value *NewStride =
-              Builder.CreateIntCast(Stride, StrideTy, /*isSigned=*/true);
-          StrideVal = Builder.CreateMul(
-              NewStride,
-              ConstantInt::get(
-                  StrideTy,
-                  (IsReverseOrder ? -1 : 1) *
-                      static_cast<int>(DL->getTypeAllocSize(ScalarTy))));
-        }
+          Stride = Expander.expandCodeFor(StrideSCEV, StrideSCEV->getType(),
+                                          &*Builder.GetInsertPoint());
+        }
+        Value *NewStride =
+            Builder.CreateIntCast(Stride, StrideTy, /*isSigned=*/true);
+        StrideVal = Builder.CreateMul(
+            NewStride, ConstantInt::get(
+                           StrideTy, (IsReverseOrder ? -1 : 1) *
+                                         static_cast<int>(
+                                             DL->getTypeAllocSize(ScalarTy))));
         Align CommonAlignment = computeCommonAlignment<LoadInst>(E->Scalars);
         auto *Inst = Builder.CreateIntrinsic(
             Intrinsic::experimental_vp_strided_load,
-            {VecTy, PO->getType(), StrideTy},
-            {PO, StrideVal, Builder.getAllOnesMask(VecTy->getElementCount()),
-             Builder.getInt32(E->Scalars.size())});
+            {StridedLoadTy, PO->getType(), StrideTy},
+            {PO, StrideVal,
+             Builder.getAllOnesMask(ElementCount::getFixed(StridedLoadEC)),
+             Builder.getInt32(StridedLoadEC)});
         Inst->addParamAttr(
             /*ArgNo=*/0,
             Attribute::getWithAlignment(Inst->getContext(), CommonAlignment));
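
One detail worth spelling out from the last hunk: the cached stride is expressed in elements, so before it is passed to llvm.experimental.vp.strided.load it is multiplied by the element's allocation size and negated when the loads are emitted in reverse order (the CreateMul call above). A simplified restatement of that arithmetic, covering only the constant-stride case and using illustrative names, is:

    #include <cstdint>

    // Convert a stride in elements into the byte stride expected by the
    // strided-load intrinsic; IsReverseOrder flips the direction of traversal.
    int64_t byteStride(int64_t StrideInElements, int64_t ElementAllocSize,
                       bool IsReverseOrder) {
      return (IsReverseOrder ? -1 : 1) * StrideInElements * ElementAllocSize;
    }

    // Example: i32 loads taken every other element, in source order:
    // byteStride(2, 4, false) == 8, i.e. the pointer advances 8 bytes per lane.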
