Skip to content

Commit 01bb75f

Browse files
committed
[LoopVectorize] Use predicated version of getSmallConstantMaxTripCount
There are a number of places where we call getSmallConstantMaxTripCount without passing a vector of predicates: getSmallBestKnownTC, isIndvarOverflowCheckKnownFalse, computeMaxVF, and isMoreProfitable. I've changed all of these to pass in a predicate vector, so that we get the benefit of making better vectorisation choices when we know the max trip count for loops that require SCEV predicate checks. I've tried to add tests that cover all the cases affected by these changes.
1 parent 861d786 commit 01bb75f

File tree

4 files changed

+84
-343
lines changed

4 files changed

+84
-343
lines changed

llvm/include/llvm/Analysis/ScalarEvolution.h

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2376,6 +2376,10 @@ class PredicatedScalarEvolution {
23762376
/// Get the (predicated) symbolic max backedge count for the analyzed loop.
23772377
const SCEV *getSymbolicMaxBackedgeTakenCount();
23782378

2379+
/// Returns the upper bound of the loop trip count as a normal unsigned
2380+
/// value, or 0 if the trip count is unknown.
2381+
unsigned getSmallConstantMaxTripCount();
2382+
23792383
/// Adds a new predicate.
23802384
void addPredicate(const SCEVPredicate &Pred);
23812385

@@ -2447,6 +2451,9 @@ class PredicatedScalarEvolution {
24472451

24482452
/// The symbolic backedge taken count.
24492453
const SCEV *SymbolicMaxBackedgeCount = nullptr;
2454+
2455+
/// The constant max trip count for the loop.
2456+
std::optional<unsigned> SmallConstantMaxTripCount;
24502457
};
24512458

24522459
template <> struct DenseMapInfo<ScalarEvolution::FoldID> {

llvm/lib/Analysis/ScalarEvolution.cpp

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -15051,6 +15051,16 @@ const SCEV *PredicatedScalarEvolution::getSymbolicMaxBackedgeTakenCount() {
1505115051
return SymbolicMaxBackedgeCount;
1505215052
}
1505315053

15054+
unsigned PredicatedScalarEvolution::getSmallConstantMaxTripCount() {
15055+
if (!SmallConstantMaxTripCount) {
15056+
SmallVector<const SCEVPredicate *, 4> Preds;
15057+
SmallConstantMaxTripCount = SE.getSmallConstantMaxTripCount(&L, &Preds);
15058+
for (const auto *P : Preds)
15059+
addPredicate(*P);
15060+
}
15061+
return *SmallConstantMaxTripCount;
15062+
}
15063+
1505415064
void PredicatedScalarEvolution::addPredicate(const SCEVPredicate &Pred) {
1505515065
if (Preds->implies(&Pred))
1505615066
return;

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 24 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -411,10 +411,10 @@ static bool hasIrregularType(Type *Ty, const DataLayout &DL) {
411411
/// 3) Returns upper bound estimate if known, and if \p CanUseConstantMax.
412412
/// 4) Returns std::nullopt if all of the above failed.
413413
static std::optional<unsigned>
414-
getSmallBestKnownTC(ScalarEvolution &SE, Loop *L,
414+
getSmallBestKnownTC(PredicatedScalarEvolution &PSE, Loop *L,
415415
bool CanUseConstantMax = true) {
416416
// Check if exact trip count is known.
417-
if (unsigned ExpectedTC = SE.getSmallConstantTripCount(L))
417+
if (unsigned ExpectedTC = PSE.getSE()->getSmallConstantTripCount(L))
418418
return ExpectedTC;
419419

420420
// Check if there is an expected trip count available from profile data.
@@ -426,7 +426,7 @@ getSmallBestKnownTC(ScalarEvolution &SE, Loop *L,
426426
return std::nullopt;
427427

428428
// Check if upper bound estimate is known.
429-
if (unsigned ExpectedTC = SE.getSmallConstantMaxTripCount(L))
429+
if (unsigned ExpectedTC = PSE.getSmallConstantMaxTripCount())
430430
return ExpectedTC;
431431

432432
return std::nullopt;
@@ -1787,12 +1787,15 @@ class GeneratedRTChecks {
17871787

17881788
Loop *OuterLoop = nullptr;
17891789

1790+
PredicatedScalarEvolution &PSE;
1791+
17901792
public:
1791-
GeneratedRTChecks(ScalarEvolution &SE, DominatorTree *DT, LoopInfo *LI,
1792-
TargetTransformInfo *TTI, const DataLayout &DL,
1793-
bool AddBranchWeights)
1794-
: DT(DT), LI(LI), TTI(TTI), SCEVExp(SE, DL, "scev.check"),
1795-
MemCheckExp(SE, DL, "scev.check"), AddBranchWeights(AddBranchWeights) {}
1793+
GeneratedRTChecks(PredicatedScalarEvolution &PSE, DominatorTree *DT,
1794+
LoopInfo *LI, TargetTransformInfo *TTI,
1795+
const DataLayout &DL, bool AddBranchWeights)
1796+
: DT(DT), LI(LI), TTI(TTI), SCEVExp(*PSE.getSE(), DL, "scev.check"),
1797+
MemCheckExp(*PSE.getSE(), DL, "scev.check"),
1798+
AddBranchWeights(AddBranchWeights), PSE(PSE) {}
17961799

17971800
/// Generate runtime checks in SCEVCheckBlock and MemCheckBlock, so we can
17981801
/// accurately estimate the cost of the runtime checks. The blocks are
@@ -1939,7 +1942,7 @@ class GeneratedRTChecks {
19391942

19401943
// Get the best known TC estimate.
19411944
if (auto EstimatedTC = getSmallBestKnownTC(
1942-
*SE, OuterLoop, /* CanUseConstantMax = */ false))
1945+
PSE, OuterLoop, /* CanUseConstantMax = */ false))
19431946
BestTripCount = *EstimatedTC;
19441947

19451948
BestTripCount = std::max(BestTripCount, 1U);
@@ -2270,8 +2273,7 @@ static bool isIndvarOverflowCheckKnownFalse(
22702273
// We know the runtime overflow check is known false iff the (max) trip-count
22712274
// is known and (max) trip-count + (VF * UF) does not overflow in the type of
22722275
// the vector loop induction variable.
2273-
if (unsigned TC =
2274-
Cost->PSE.getSE()->getSmallConstantMaxTripCount(Cost->TheLoop)) {
2276+
if (unsigned TC = Cost->PSE.getSmallConstantMaxTripCount()) {
22752277
uint64_t MaxVF = VF.getKnownMinValue();
22762278
if (VF.isScalable()) {
22772279
std::optional<unsigned> MaxVScale =
@@ -3956,7 +3958,7 @@ LoopVectorizationCostModel::computeMaxVF(ElementCount UserVF, unsigned UserIC) {
39563958
}
39573959

39583960
unsigned TC = PSE.getSE()->getSmallConstantTripCount(TheLoop);
3959-
unsigned MaxTC = PSE.getSE()->getSmallConstantMaxTripCount(TheLoop);
3961+
unsigned MaxTC = PSE.getSmallConstantMaxTripCount();
39603962
LLVM_DEBUG(dbgs() << "LV: Found trip count: " << TC << '\n');
39613963
if (TC != MaxTC)
39623964
LLVM_DEBUG(dbgs() << "LV: Found maximum trip count: " << MaxTC << '\n');
@@ -4253,7 +4255,7 @@ bool LoopVectorizationPlanner::isMoreProfitable(
42534255
InstructionCost CostA = A.Cost;
42544256
InstructionCost CostB = B.Cost;
42554257

4256-
unsigned MaxTripCount = PSE.getSE()->getSmallConstantMaxTripCount(OrigLoop);
4258+
unsigned MaxTripCount = PSE.getSmallConstantMaxTripCount();
42574259

42584260
// Improve estimate for the vector width if it is scalable.
42594261
unsigned EstimatedWidthA = A.Width.getKnownMinValue();
@@ -4841,7 +4843,7 @@ LoopVectorizationCostModel::selectInterleaveCount(ElementCount VF,
48414843
if (!Legal->isSafeForAnyVectorWidth())
48424844
return 1;
48434845

4844-
auto BestKnownTC = getSmallBestKnownTC(*PSE.getSE(), TheLoop);
4846+
auto BestKnownTC = getSmallBestKnownTC(PSE, TheLoop);
48454847
const bool HasReductions = !Legal->getReductionVars().empty();
48464848

48474849
// If we did not calculate the cost for VF (because the user selected the VF)
@@ -9585,8 +9587,8 @@ static bool processLoopInVPlanNativePath(
95859587
{
95869588
bool AddBranchWeights =
95879589
hasBranchWeightMD(*L->getLoopLatch()->getTerminator());
9588-
GeneratedRTChecks Checks(*PSE.getSE(), DT, LI, TTI,
9589-
F->getDataLayout(), AddBranchWeights);
9590+
GeneratedRTChecks Checks(PSE, DT, LI, TTI, F->getDataLayout(),
9591+
AddBranchWeights);
95909592
InnerLoopVectorizer LB(L, PSE, LI, DT, TLI, TTI, AC, ORE, VF.Width,
95919593
VF.Width, 1, LVL, &CM, BFI, PSI, Checks);
95929594
LLVM_DEBUG(dbgs() << "Vectorizing outer loop in \""
@@ -9650,7 +9652,7 @@ static void checkMixedPrecision(Loop *L, OptimizationRemarkEmitter *ORE) {
96509652
static bool areRuntimeChecksProfitable(GeneratedRTChecks &Checks,
96519653
VectorizationFactor &VF,
96529654
std::optional<unsigned> VScale, Loop *L,
9653-
ScalarEvolution &SE,
9655+
PredicatedScalarEvolution &PSE,
96549656
ScalarEpilogueLowering SEL) {
96559657
InstructionCost CheckCost = Checks.getCost();
96569658
if (!CheckCost.isValid())
@@ -9735,7 +9737,7 @@ static bool areRuntimeChecksProfitable(GeneratedRTChecks &Checks,
97359737

97369738
// Skip vectorization if the expected trip count is less than the minimum
97379739
// required trip count.
9738-
if (auto ExpectedTC = getSmallBestKnownTC(SE, L)) {
9740+
if (auto ExpectedTC = getSmallBestKnownTC(PSE, L)) {
97399741
if (ElementCount::isKnownLT(ElementCount::getFixed(*ExpectedTC),
97409742
VF.MinProfitableTripCount)) {
97419743
LLVM_DEBUG(dbgs() << "LV: Vectorization is not beneficial: expected "
@@ -9842,7 +9844,7 @@ bool LoopVectorizePass::processLoop(Loop *L) {
98429844

98439845
// Check the loop for a trip count threshold: vectorize loops with a tiny trip
98449846
// count by optimizing for size, to minimize overheads.
9845-
auto ExpectedTC = getSmallBestKnownTC(*SE, L);
9847+
auto ExpectedTC = getSmallBestKnownTC(PSE, L);
98469848
if (ExpectedTC && *ExpectedTC < TinyTripCountVectorThreshold) {
98479849
LLVM_DEBUG(dbgs() << "LV: Found a loop with a very small trip count. "
98489850
<< "This loop is worth vectorizing only if no scalar "
@@ -9940,8 +9942,8 @@ bool LoopVectorizePass::processLoop(Loop *L) {
99409942

99419943
bool AddBranchWeights =
99429944
hasBranchWeightMD(*L->getLoopLatch()->getTerminator());
9943-
GeneratedRTChecks Checks(*PSE.getSE(), DT, LI, TTI,
9944-
F->getDataLayout(), AddBranchWeights);
9945+
GeneratedRTChecks Checks(PSE, DT, LI, TTI, F->getDataLayout(),
9946+
AddBranchWeights);
99459947
if (LVP.hasPlanWithVF(VF.Width)) {
99469948
// Select the interleave count.
99479949
IC = CM.selectInterleaveCount(VF.Width, VF.Cost);
@@ -9957,7 +9959,7 @@ bool LoopVectorizePass::processLoop(Loop *L) {
99579959
Hints.getForce() == LoopVectorizeHints::FK_Enabled;
99589960
if (!ForceVectorization &&
99599961
!areRuntimeChecksProfitable(Checks, VF, getVScaleForTuning(L, *TTI), L,
9960-
*PSE.getSE(), SEL)) {
9962+
PSE, SEL)) {
99619963
ORE->emit([&]() {
99629964
return OptimizationRemarkAnalysisAliasing(
99639965
DEBUG_TYPE, "CantReorderMemOps", L->getStartLoc(),

0 commit comments

Comments
 (0)