Skip to content

Commit 7a0bf19

Browse files
committed
Rework to memoize loop guards across multiple exits
1 parent 9a64d8a commit 7a0bf19

File tree

5 files changed

+150
-119
lines changed

5 files changed

+150
-119
lines changed

llvm/include/llvm/Analysis/LoopAccessAnalysis.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,9 @@ class MemoryDepChecker {
334334
std::pair<const SCEV *, const SCEV *>>
335335
PointerBounds;
336336

337+
/// Cache for the loop guards of InnermostLoop.
338+
std::optional<ScalarEvolution::LoopGuards> LoopGuards;
339+
337340
/// Check whether there is a plausible dependence between the two
338341
/// accesses.
339342
///

llvm/include/llvm/Analysis/ScalarEvolution.h

Lines changed: 61 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -1112,6 +1112,46 @@ class ScalarEvolution {
11121112
bool isKnownOnEveryIteration(ICmpInst::Predicate Pred,
11131113
const SCEVAddRecExpr *LHS, const SCEV *RHS);
11141114

1115+
class LoopGuards {
1116+
DenseMap<const SCEV *, const SCEV *> RewriteMap;
1117+
bool PreserveNUW = false;
1118+
bool PreserveNSW = false;
1119+
ScalarEvolution &SE;
1120+
1121+
LoopGuards(ScalarEvolution &SE) : SE(SE) {}
1122+
1123+
/// Recursively collect loop guards in \p Guards, starting from
1124+
/// block \p Block with predecessor \p Pred. The intended starting point
1125+
/// is to collect from a loop header and its predecessor.
1126+
static void
1127+
collectFromBlock(ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards,
1128+
const BasicBlock *Block, const BasicBlock *Pred,
1129+
SmallPtrSetImpl<const BasicBlock *> &VisitedBlocks,
1130+
unsigned Depth = 0);
1131+
1132+
/// Collect loop guards in \p Guards, starting from PHINode \p
1133+
/// Phi, by calling \p collectFromBlock on the incoming blocks of
1134+
/// \p Phi and trying to merge the found constraints into a single
1135+
/// combined one for \p Phi.
1136+
static void collectFromPHI(
1137+
ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards,
1138+
const PHINode &Phi, SmallPtrSetImpl<const BasicBlock *> &VisitedBlocks,
1139+
SmallDenseMap<const BasicBlock *, LoopGuards> &IncomingGuards,
1140+
unsigned Depth);
1141+
1142+
public:
1143+
/// Collect rewrite map for loop guards for loop \p L, together with flags
1144+
/// indicating if NUW and NSW can be preserved during rewriting.
1145+
static LoopGuards collect(const Loop *L, ScalarEvolution &SE);
1146+
1147+
/// Try to apply the collected loop guards to \p Expr.
1148+
const SCEV *rewrite(const SCEV *Expr) const;
1149+
};
1150+
1151+
/// Try to apply information from loop guards for \p L to \p Expr.
1152+
const SCEV *applyLoopGuards(const SCEV *Expr, const Loop *L);
1153+
const SCEV *applyLoopGuards(const SCEV *Expr, const LoopGuards &Guards);
1154+
11151155
/// Information about the number of loop iterations for which a loop exit's
11161156
/// branch condition evaluates to the not-taken path. This is a temporary
11171157
/// pair of exact and max expressions that are eventually summarized in
@@ -1167,6 +1207,7 @@ class ScalarEvolution {
11671207
/// If \p AllowPredicates is set, this call will try to use a minimal set of
11681208
/// SCEV predicates in order to return an exact answer.
11691209
ExitLimit computeExitLimitFromCond(const Loop *L, Value *ExitCond,
1210+
std::function<LoopGuards()> GetLoopGuards,
11701211
bool ExitIfTrue, bool ControlsOnlyExit,
11711212
bool AllowPredicates = false);
11721213

@@ -1308,45 +1349,6 @@ class ScalarEvolution {
13081349
/// sharpen it.
13091350
void setNoWrapFlags(SCEVAddRecExpr *AddRec, SCEV::NoWrapFlags Flags);
13101351

1311-
class LoopGuards {
1312-
DenseMap<const SCEV *, const SCEV *> RewriteMap;
1313-
bool PreserveNUW = false;
1314-
bool PreserveNSW = false;
1315-
ScalarEvolution &SE;
1316-
1317-
LoopGuards(ScalarEvolution &SE) : SE(SE) {}
1318-
1319-
/// Recursively collect loop guards in \p Guards, starting from
1320-
/// block \p Block with predecessor \p Pred. The intended starting point
1321-
/// is to collect from a loop header and its predecessor.
1322-
static void
1323-
collectFromBlock(ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards,
1324-
const BasicBlock *Block, const BasicBlock *Pred,
1325-
SmallPtrSetImpl<const BasicBlock *> &VisitedBlocks,
1326-
unsigned Depth = 0);
1327-
1328-
/// Collect loop guards in \p Guards, starting from PHINode \p
1329-
/// Phi, by calling \p collectFromBlock on the incoming blocks of
1330-
/// \p Phi and trying to merge the found constraints into a single
1331-
/// combined one for \p Phi.
1332-
static void collectFromPHI(
1333-
ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards,
1334-
const PHINode &Phi, SmallPtrSetImpl<const BasicBlock *> &VisitedBlocks,
1335-
SmallDenseMap<const BasicBlock *, LoopGuards> &IncomingGuards,
1336-
unsigned Depth);
1337-
1338-
public:
1339-
/// Collect rewrite map for loop guards for loop \p L, together with flags
1340-
/// indicating if NUW and NSW can be preserved during rewriting.
1341-
static LoopGuards collect(const Loop *L, ScalarEvolution &SE);
1342-
1343-
/// Try to apply the collected loop guards to \p Expr.
1344-
const SCEV *rewrite(const SCEV *Expr) const;
1345-
};
1346-
1347-
/// Try to apply information from loop guards for \p L to \p Expr.
1348-
const SCEV *applyLoopGuards(const SCEV *Expr, const Loop *L);
1349-
13501352
/// Return true if the loop has no abnormal exits. That is, if the loop
13511353
/// is not infinite, it must exit through an explicit edge in the CFG.
13521354
/// (As opposed to either a) throwing out of the function or b) entering a
@@ -1650,10 +1652,6 @@ class ScalarEvolution {
16501652
/// function as they are computed.
16511653
DenseMap<const Loop *, BackedgeTakenInfo> PredicatedBackedgeTakenCounts;
16521654

1653-
/// Cache the collected loop guards of the loops of this function as they are
1654-
/// computed.
1655-
DenseMap<const Loop *, LoopGuards> LoopGuardsCache;
1656-
16571655
/// Loops whose backedge taken counts directly use this non-constant SCEV.
16581656
DenseMap<const SCEV *, SmallPtrSet<PointerIntPair<const Loop *, 1, bool>, 4>>
16591657
BECountUsers;
@@ -1843,6 +1841,7 @@ class ScalarEvolution {
18431841
/// this call will try to use a minimal set of SCEV predicates in order to
18441842
/// return an exact answer.
18451843
ExitLimit computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
1844+
std::function<LoopGuards()> GetLoopGuards,
18461845
bool IsOnlyExit, bool AllowPredicates = false);
18471846

18481847
// Helper functions for computeExitLimitFromCond to avoid exponential time
@@ -1875,17 +1874,17 @@ class ScalarEvolution {
18751874

18761875
using ExitLimitCacheTy = ExitLimitCache;
18771876

1878-
ExitLimit computeExitLimitFromCondCached(ExitLimitCacheTy &Cache,
1879-
const Loop *L, Value *ExitCond,
1880-
bool ExitIfTrue,
1881-
bool ControlsOnlyExit,
1882-
bool AllowPredicates);
1883-
ExitLimit computeExitLimitFromCondImpl(ExitLimitCacheTy &Cache, const Loop *L,
1884-
Value *ExitCond, bool ExitIfTrue,
1885-
bool ControlsOnlyExit,
1886-
bool AllowPredicates);
1877+
ExitLimit computeExitLimitFromCondCached(
1878+
ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond,
1879+
std::function<LoopGuards()> GetLoopGuards, bool ExitIfTrue,
1880+
bool ControlsOnlyExit, bool AllowPredicates);
1881+
ExitLimit computeExitLimitFromCondImpl(
1882+
ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond,
1883+
std::function<LoopGuards()> GetLoopGuards, bool ExitIfTrue,
1884+
bool ControlsOnlyExit, bool AllowPredicates);
18871885
std::optional<ScalarEvolution::ExitLimit> computeExitLimitFromCondFromBinOp(
1888-
ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
1886+
ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond,
1887+
std::function<LoopGuards()> GetLoopGuards, bool ExitIfTrue,
18891888
bool ControlsOnlyExit, bool AllowPredicates);
18901889

18911890
/// Compute the number of times the backedge of the specified loop will
@@ -1894,8 +1893,8 @@ class ScalarEvolution {
18941893
/// to use a minimal set of SCEV predicates in order to return an exact
18951894
/// answer.
18961895
ExitLimit computeExitLimitFromICmp(const Loop *L, ICmpInst *ExitCond,
1897-
bool ExitIfTrue,
1898-
bool IsSubExpr,
1896+
std::function<LoopGuards()> GetLoopGuards,
1897+
bool ExitIfTrue, bool IsSubExpr,
18991898
bool AllowPredicates = false);
19001899

19011900
/// Variant of previous which takes the components representing an ICmp
@@ -1904,16 +1903,16 @@ class ScalarEvolution {
19041903
/// has a materialized ICmp.
19051904
ExitLimit computeExitLimitFromICmp(const Loop *L, ICmpInst::Predicate Pred,
19061905
const SCEV *LHS, const SCEV *RHS,
1906+
std::function<LoopGuards()> GetLoopGuards,
19071907
bool IsSubExpr,
19081908
bool AllowPredicates = false);
19091909

19101910
/// Compute the number of times the backedge of the specified loop will
19111911
/// execute if its exit condition were a switch with a single exiting case
19121912
/// to ExitingBB.
1913-
ExitLimit computeExitLimitFromSingleExitSwitch(const Loop *L,
1914-
SwitchInst *Switch,
1915-
BasicBlock *ExitingBB,
1916-
bool IsSubExpr);
1913+
ExitLimit computeExitLimitFromSingleExitSwitch(
1914+
const Loop *L, SwitchInst *Switch, BasicBlock *ExitingBB,
1915+
std::function<LoopGuards()> GetLoopGuards, bool IsSubExpr);
19171916

19181917
/// Compute the exit limit of a loop that is controlled by a
19191918
/// "(IV >> 1) != 0" type comparison. We cannot compute the exact trip
@@ -1937,8 +1936,9 @@ class ScalarEvolution {
19371936
/// value to zero will execute. If not computable, return CouldNotCompute.
19381937
/// If AllowPredicates is set, this call will try to use a minimal set of
19391938
/// SCEV predicates in order to return an exact answer.
1940-
ExitLimit howFarToZero(const SCEV *V, const Loop *L, bool IsSubExpr,
1941-
bool AllowPredicates = false);
1939+
ExitLimit howFarToZero(const SCEV *V, const Loop *L,
1940+
std::function<LoopGuards()> GetLoopGuards,
1941+
bool IsSubExpr, bool AllowPredicates = false);
19421942

19431943
/// Return the number of times an exit condition checking the specified
19441944
/// value for nonzero will execute. If not computable, return

llvm/lib/Analysis/LoopAccessAnalysis.cpp

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1945,13 +1945,16 @@ MemoryDepChecker::getDependenceDistanceStrideAndSize(
19451945
!isa<SCEVCouldNotCompute>(SrcEnd_) &&
19461946
!isa<SCEVCouldNotCompute>(SinkStart_) &&
19471947
!isa<SCEVCouldNotCompute>(SinkEnd_)) {
1948-
auto SrcEnd = SE.applyLoopGuards(SrcEnd_, InnermostLoop);
1949-
auto SinkStart = SE.applyLoopGuards(SinkStart_, InnermostLoop);
1948+
if (!LoopGuards)
1949+
LoopGuards.emplace(
1950+
ScalarEvolution::LoopGuards::collect(InnermostLoop, SE));
1951+
auto SrcEnd = SE.applyLoopGuards(SrcEnd_, *LoopGuards);
1952+
auto SinkStart = SE.applyLoopGuards(SinkStart_, *LoopGuards);
19501953
if (SE.isKnownPredicate(CmpInst::ICMP_ULE, SrcEnd, SinkStart))
19511954
return MemoryDepChecker::Dependence::NoDep;
19521955

1953-
auto SinkEnd = SE.applyLoopGuards(SinkEnd_, InnermostLoop);
1954-
auto SrcStart = SE.applyLoopGuards(SrcStart_, InnermostLoop);
1956+
auto SinkEnd = SE.applyLoopGuards(SinkEnd_, *LoopGuards);
1957+
auto SrcStart = SE.applyLoopGuards(SrcStart_, *LoopGuards);
19551958
if (SE.isKnownPredicate(CmpInst::ICMP_ULE, SinkEnd, SrcStart))
19561959
return MemoryDepChecker::Dependence::NoDep;
19571960
}
@@ -2054,7 +2057,10 @@ MemoryDepChecker::isDependent(const MemAccessInfo &A, unsigned AIdx,
20542057
return Dependence::NoDep;
20552058
}
20562059
} else {
2057-
Dist = SE.applyLoopGuards(Dist, InnermostLoop);
2060+
if (!LoopGuards)
2061+
LoopGuards.emplace(
2062+
ScalarEvolution::LoopGuards::collect(InnermostLoop, SE));
2063+
Dist = SE.applyLoopGuards(Dist, *LoopGuards);
20582064
}
20592065

20602066
// Negative distances are not plausible dependencies.

0 commit comments

Comments
 (0)