Commit 1a29870
[SCEV] Infer loop max trip count from memory accesses
Data references in a loop are assumed not to access elements beyond the statically allocated size. We can therefore infer a loop max trip count from this undefined behavior. This patch is refined from the original one (https://reviews.llvm.org/D155049) authored by @Peakulorain.
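As an illustration of the idea (a hypothetical example, not taken from the patch or its tests): in the loop below the exit count depends on an unknown n, yet any store to a[i] with i >= 1024 would be out of bounds, so the max trip count can be bounded at 1024.

    // Hypothetical example: out-of-bounds UB bounds the trip count.
    int a[1024];
    void foo(int n) {
      for (int i = 0; i != n; ++i)
        a[i] = i; // UB once i reaches 1024, so at most 1024 iterations
    }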
1 parent db8b76d commit 1a29870

File tree

4 files changed: +544, -6 lines changed


llvm/include/llvm/Analysis/ScalarEvolution.h

Lines changed: 7 additions & 0 deletions
@@ -1153,6 +1153,10 @@ class ScalarEvolution {
                             bool ExitIfTrue, bool ControlsOnlyExit,
                             bool AllowPredicates = false);
 
+  /// Compute the number of times the body of the specified loop will execute
+  /// via the memory accesses inside the loop body.
+  ExitLimit computeExitLimitFromMemAccess(const Loop *L);
+
   /// A predicate is said to be monotonically increasing if may go from being
   /// false to being true as the loop iterates, but never the other way
   /// around. A predicate is said to be monotonically decreasing if may go
@@ -1804,6 +1808,9 @@ class ScalarEvolution {
                                    Value *ExitCond, bool ExitIfTrue,
                                    bool ControlsOnlyExit,
                                    bool AllowPredicates);
+  ExitLimit computeExitLimitFromMemAccessCached(ExitLimitCacheTy &Cache,
+                                                const Loop *L);
+  ExitLimit computeExitLimitFromMemAccessImpl(const Loop *L);
   std::optional<ScalarEvolution::ExitLimit> computeExitLimitFromCondFromBinOp(
       ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
       bool ControlsOnlyExit, bool AllowPredicates);
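(The three declarations mirror ScalarEvolution's existing computeExitLimitFromCond / computeExitLimitFromCondCached / computeExitLimitFromCondImpl layering: computeExitLimitFromMemAccess is the entry point, the Cached variant memoizes per loop, and the Impl variant performs the analysis.)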

llvm/lib/Analysis/ScalarEvolution.cpp

Lines changed: 236 additions & 6 deletions
@@ -249,6 +249,10 @@ static cl::opt<bool> UseContextForNoWrapFlagInference(
     cl::desc("Infer nuw/nsw flags using context where suitable"),
     cl::init(true));
 
+static cl::opt<bool> UseMemoryAccessUBForBEInference(
+    "scalar-evolution-infer-max-trip-count-from-memory-access", cl::Hidden,
+    cl::desc("Infer loop max trip count from memory access"), cl::init(false));
+
 //===----------------------------------------------------------------------===//
 // SCEV class definitions
 //===----------------------------------------------------------------------===//
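The new inference is off by default and sits behind this hidden flag. Assuming a standard opt setup, a run along the lines of `opt -passes='print<scalar-evolution>' -scalar-evolution-infer-max-trip-count-from-memory-access -disable-output test.ll` should exercise it and print the extra max-trip-count lines added to PrintLoopInfo below.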
@@ -8253,6 +8257,194 @@ ScalarEvolution::getSmallConstantTripMultiple(const Loop *L,
   return getSmallConstantTripMultiple(L, ExitCount);
 }
 
+/// Collect all load/store instructions that must be executed in every
+/// iteration of loop \p L.
+static void
+collectExecLoadStoreInsideLoop(const Loop *L, DominatorTree &DT,
+                               SmallVector<Instruction *, 4> &MemInsts) {
+  // It is difficult to tell if the load/store instruction is executed on
+  // every iteration inside an irregular loop.
+  if (!L->isLoopSimplifyForm() || !L->isInnermost())
+    return;
+
+  const BasicBlock *LoopLatch = L->getLoopLatch();
+  assert(LoopLatch && "normal form loop doesn't have a latch");
+  assert(L->getExitingBlock() == LoopLatch);
+
+  // We will not continue if a sanitizer is enabled.
+  const Function *F = LoopLatch->getParent();
+  if (F->hasFnAttribute(Attribute::SanitizeAddress) ||
+      F->hasFnAttribute(Attribute::SanitizeThread) ||
+      F->hasFnAttribute(Attribute::SanitizeMemory) ||
+      F->hasFnAttribute(Attribute::SanitizeHWAddress) ||
+      F->hasFnAttribute(Attribute::SanitizeMemTag))
+    return;
+
+  for (auto *BB : L->getBlocks()) {
+    // We need to make sure that the max execution time of MemAccessBB in the
+    // loop represents the latch max execution time. The BB below should be
+    // skipped:
+    //            Entry
+    //              │
+    //        ┌─────▼─────┐
+    //        │Loop Header◄─────┐
+    //        └──┬──────┬─┘     │
+    //           │      │       │
+    //  ┌────────▼──┐ ┌─▼─────┐ │
+    //  │MemAccessBB│ │OtherBB│ │
+    //  └────────┬──┘ └─┬─────┘ │
+    //           │      │       │
+    //         ┌─▼──────▼─┐     │
+    //         │Loop Latch├─────┘
+    //         └────┬─────┘
+    //              ▼
+    //             Exit
+    if (!DT.dominates(BB, LoopLatch))
+      continue;
+
+    for (Instruction &I : *BB) {
+      if (isa<LoadInst>(&I) || isa<StoreInst>(&I))
+        MemInsts.push_back(&I);
+    }
+  }
+}
+
+/// Return a SCEV representing the memory size of pointer \p V.
+static const SCEV *getCertainSizeOfMem(const SCEV *V, Type *RTy,
+                                       const DataLayout &DL,
+                                       const TargetLibraryInfo &TLI,
+                                       ScalarEvolution *SE) {
+  const SCEVUnknown *PtrBase = dyn_cast<SCEVUnknown>(V);
+  if (!PtrBase)
+    return nullptr;
+  Value *Ptr = PtrBase->getValue();
+  uint64_t Size = 0;
+  if (!llvm::getObjectSize(Ptr, Size, DL, &TLI))
+    return nullptr;
+  return SE->getConstant(RTy, Size);
+}
+
+/// Get the range of the index represented by \p AddRec.
+static const SCEV *getIndexRange(const SCEVAddRecExpr *AddRec,
+                                 ScalarEvolution *SE) {
+  const SCEV *Range = SE->getConstant(SE->getUnsignedRangeMax(AddRec) -
+                                      SE->getUnsignedRangeMin(AddRec));
+  const SCEV *Step = AddRec->getStepRecurrence(*SE);
+  if (SE->isKnownNegative(Step))
+    Step = SE->getNegativeSCEV(Step);
+  return SE->getUDivCeilSCEV(Range, Step);
+}
+
+/// Get the underlying SCEVAddExpr from a cast expression if possible.
+const SCEV *peelCastExpr(const SCEVCastExpr *S, ScalarEvolution *SE) {
+  const SCEV *Op = S->getOperand();
+  if (isa<SCEVCouldNotCompute>(Op))
+    return Op;
+  if (isa<SCEVAddExpr>(Op))
+    return Op;
+  if (isa<SCEVCastExpr>(Op))
+    return peelCastExpr(cast<SCEVCastExpr>(Op), SE);
+  return SE->getCouldNotCompute();
+}
+
+/// Check whether the index can wrap and, if it can, whether we can still
+/// infer a max trip count given the max trip count inferred from memory
+/// accesses.
+static const SCEV *checkIndexRange(Value *Ptr, ScalarEvolution *SE,
+                                   const SCEVConstant *MaxExecCount) {
+  SmallVector<const SCEV *> InferCountColl;
+  auto *PtrGEP = dyn_cast<GetElementPtrInst>(Ptr);
+  if (!PtrGEP)
+    return SE->getCouldNotCompute();
+  for (Value *Index : PtrGEP->indices()) {
+    Value *V = Index;
+    if (isa<ZExtInst>(V) || isa<SExtInst>(V))
+      V = cast<Instruction>(Index)->getOperand(0);
+    auto *SCEV = SE->getSCEV(V);
+    if (isa<SCEVCouldNotCompute>(SCEV))
+      return SE->getCouldNotCompute();
+    if (isa<SCEVConstant>(SCEV))
+      continue;
+    if (isa<SCEVCastExpr>(SCEV)) {
+      SCEV = peelCastExpr(cast<SCEVCastExpr>(SCEV), SE);
+      if (isa<SCEVCouldNotCompute>(SCEV))
+        return SE->getCouldNotCompute();
+    }
+    auto *AddRec = dyn_cast<SCEVAddRecExpr>(SCEV);
+    if (!AddRec)
+      return SE->getCouldNotCompute();
+    auto *IndexRange = getIndexRange(AddRec, SE);
+    auto *IndexRangeC = dyn_cast<SCEVConstant>(IndexRange);
+    if (!IndexRangeC)
+      return SE->getCouldNotCompute();
+    InferCountColl.push_back(IndexRange);
+    break;
+  }
+
+  if (InferCountColl.empty())
+    return SE->getCouldNotCompute();
+
+  InferCountColl.push_back(MaxExecCount);
+
+  return SE->getUMinFromMismatchedTypes(InferCountColl);
+}
+
+ScalarEvolution::ExitLimit
+ScalarEvolution::computeExitLimitFromMemAccessImpl(const Loop *L) {
+  SmallVector<Instruction *, 4> MemInsts;
+  collectExecLoadStoreInsideLoop(L, DT, MemInsts);
+
+  SmallVector<const SCEV *> InferCountColl;
+  const DataLayout &DL = getDataLayout();
+
+  for (Instruction *I : MemInsts) {
+    Value *Ptr = getLoadStorePointerOperand(I);
+    assert(Ptr && "empty pointer operand");
+    auto *AddRec = dyn_cast<SCEVAddRecExpr>(getSCEV(Ptr));
+    if (!AddRec || !AddRec->isAffine())
+      continue;
+    const SCEV *PtrBase = getPointerBase(AddRec);
+    const SCEV *Step = AddRec->getStepRecurrence(*this);
+    const SCEV *MemSize =
+        getCertainSizeOfMem(PtrBase, Step->getType(), DL, TLI, this);
+    if (!MemSize)
+      continue;
+    if (isKnownNegative(Step))
+      Step = getNegativeSCEV(Step);
+    // Now we can infer a max execution time by MemLength/StepLength.
+    auto *MaxExecCount = dyn_cast<SCEVConstant>(getUDivCeilSCEV(MemSize, Step));
+    if (!MaxExecCount || MaxExecCount->getAPInt().getActiveBits() > 32)
+      continue;
+    auto *Res = checkIndexRange(Ptr, this, MaxExecCount);
+    if (isa<SCEVCouldNotCompute>(Res))
+      continue;
+    InferCountColl.push_back(Res);
+  }
+
+  if (InferCountColl.empty())
+    return getCouldNotCompute();
+
+  const SCEV *Count = getUMinFromMismatchedTypes(InferCountColl);
+
+  return {getCouldNotCompute(), Count, Count, /*MaxOrZero=*/false};
+}
+
+ScalarEvolution::ExitLimit
+ScalarEvolution::computeExitLimitFromMemAccessCached(ExitLimitCacheTy &Cache,
+                                                     const Loop *L) {
+  // We don't really need these values, but the cache does.
+  constexpr Value *ExitCond = nullptr;
+  constexpr bool ExitIfTrue = true;
+  constexpr bool ControlsOnlyExit = true;
+  constexpr bool AllowPredicates = true;
+
+  if (auto MaybeEL = Cache.find(L, ExitCond, ExitIfTrue, ControlsOnlyExit,
+                                AllowPredicates))
+    return *MaybeEL;
+
+  ExitLimit EL = computeExitLimitFromMemAccessImpl(L);
+  Cache.insert(L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates, EL);
+  return EL;
+}
+
 const SCEV *ScalarEvolution::getExitCount(const Loop *L,
                                           const BasicBlock *ExitingBlock,
                                           ExitCountKind Kind) {
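A minimal arithmetic sketch of the bound computed in computeExitLimitFromMemAccessImpl above (a standalone illustration, not code from the patch; the helper name is made up): the candidate max execution count is the ceiling of the object size over the absolute stride, and the final ConstantMaxNotTaken is the umin across all qualifying accesses.

    #include <cstdint>

    // Hypothetical stand-in for getUDivCeilSCEV on constant operands:
    // ceil(MemSize / Step), with Step already made non-negative.
    constexpr std::uint64_t maxExecCountFromMemSize(std::uint64_t MemSize,
                                                    std::uint64_t Step) {
      return (MemSize + Step - 1) / Step;
    }

    // A 4096-byte object walked in 4-byte steps can be accessed at most
    // 1024 times before the access runs past the allocation.
    static_assert(maxExecCountFromMemSize(4096, 4) == 1024, "example bound");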
@@ -8835,6 +9027,16 @@ ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
   if (!Latch || !DT.dominates(ExitingBlock, Latch))
     return getCouldNotCompute();
 
+  // FIXME: To make the case more typical, we only analyze loops that have one
+  // exiting block and the block must be the latch. It is easier to capture
+  // loops with memory access that will be executed in every iteration.
+  const SCEV *PotentiallyBetterConstantMax = getCouldNotCompute();
+  if (UseMemoryAccessUBForBEInference && Latch == L->getExitingBlock()) {
+    assert(Latch == ExitingBlock);
+    auto EL = computeExitLimitFromMemAccess(L);
+    PotentiallyBetterConstantMax = EL.ConstantMaxNotTaken;
+  }
+
   bool IsOnlyExit = (L->getExitingBlock() != nullptr);
   Instruction *Term = ExitingBlock->getTerminator();
   if (BranchInst *BI = dyn_cast<BranchInst>(Term)) {
@@ -8843,9 +9045,14 @@ ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
     assert(ExitIfTrue == L->contains(BI->getSuccessor(1)) &&
            "It should have one successor in loop and one exit block!");
     // Proceed to the next level to examine the exit condition expression.
-    return computeExitLimitFromCond(L, BI->getCondition(), ExitIfTrue,
-                                    /*ControlsOnlyExit=*/IsOnlyExit,
-                                    AllowPredicates);
+    ExitLimit EL = computeExitLimitFromCond(L, BI->getCondition(), ExitIfTrue,
+                                            /*ControlsOnlyExit=*/IsOnlyExit,
+                                            AllowPredicates);
+    if (!isa<SCEVCouldNotCompute>(EL.ConstantMaxNotTaken) &&
+        !isa<SCEVCouldNotCompute>(PotentiallyBetterConstantMax))
+      EL.ConstantMaxNotTaken = getUMinFromMismatchedTypes(
+          EL.ConstantMaxNotTaken, PotentiallyBetterConstantMax);
+    return EL;
   }
 
   if (SwitchInst *SI = dyn_cast<SwitchInst>(Term)) {
@@ -8858,9 +9065,14 @@ ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
       Exit = SBB;
     }
     assert(Exit && "Exiting block must have at least one exit");
-    return computeExitLimitFromSingleExitSwitch(
-        L, SI, Exit,
-        /*ControlsOnlyExit=*/IsOnlyExit);
+    ExitLimit EL =
+        computeExitLimitFromSingleExitSwitch(L, SI, Exit,
+                                             /*ControlsOnlyExit=*/IsOnlyExit);
+    if (!isa<SCEVCouldNotCompute>(EL.ConstantMaxNotTaken) &&
+        !isa<SCEVCouldNotCompute>(PotentiallyBetterConstantMax))
+      EL.ConstantMaxNotTaken = getUMinFromMismatchedTypes(
+          EL.ConstantMaxNotTaken, PotentiallyBetterConstantMax);
+    return EL;
   }
 
   return getCouldNotCompute();
@@ -8874,6 +9086,13 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCond(
                                 ControlsOnlyExit, AllowPredicates);
 }
 
+ScalarEvolution::ExitLimit
+ScalarEvolution::computeExitLimitFromMemAccess(const Loop *L) {
+  ScalarEvolution::ExitLimitCacheTy Cache(L, /*ExitIfTrue=*/true,
+                                          /*AllowPredicates=*/true);
+  return computeExitLimitFromMemAccessCached(Cache, L);
+}
+
 std::optional<ScalarEvolution::ExitLimit>
 ScalarEvolution::ExitLimitCache::find(const Loop *L, Value *ExitCond,
                                       bool ExitIfTrue, bool ControlsOnlyExit,
@@ -13544,6 +13763,17 @@ static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE,
     OS << ": ";
     OS << "Trip multiple is " << SE->getSmallConstantTripMultiple(L) << "\n";
   }
+
+  if (UseMemoryAccessUBForBEInference) {
+    unsigned SmallMaxTrip = SE->getSmallConstantMaxTripCount(L);
+    OS << "Loop ";
+    L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
+    OS << ": ";
+    if (SmallMaxTrip)
+      OS << "Small constant max trip is " << SmallMaxTrip << "\n";
+    else
+      OS << "Small constant max trip couldn't be computed.\n";
+  }
 }
 
 namespace llvm {
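With the flag enabled, PrintLoopInfo thus emits one extra diagnostic line per loop: "Loop %for.body: Small constant max trip is 1024" when a bound is found, or "Small constant max trip couldn't be computed." otherwise (the loop name %for.body and the value 1024 here are hypothetical).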
