
Commit 4d84e41

[SCEV] Infer loop max trip count from memory accesses
Data references in a loop are assumed not to access elements beyond the statically allocated size. We can therefore infer a loop max trip count from this undefined behavior. This patch is refined from the original one (https://reviews.llvm.org/D155049) authored by @Peakulorain.
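For intuition, here is a minimal source-level sketch (hypothetical, not taken from the patch or its tests) of the undefined behavior being exploited:

    int a[100];

    unsigned sum(unsigned n) {
      unsigned s = 0;
      for (unsigned i = 0; i < n; ++i) // n is unknown at compile time
        s += a[i];                     // reading a[i] with i >= 100 is UB
      return s;
    }

Because a well-defined execution can never read past a[99], ScalarEvolution may bound the loop's max trip count by 100 even though n is unknown.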

File tree: 4 files changed, +560 -6 lines changed

llvm/include/llvm/Analysis/ScalarEvolution.h

Lines changed: 7 additions & 0 deletions
@@ -1154,6 +1154,10 @@ class ScalarEvolution {
                              bool ExitIfTrue, bool ControlsOnlyExit,
                              bool AllowPredicates = false);
 
+  /// Compute the number of times the body of the specified loop will execute
+  /// via the memory accesses inside the loop body.
+  ExitLimit computeExitLimitFromMemAccess(const Loop *L);
+
   /// A predicate is said to be monotonically increasing if may go from being
   /// false to being true as the loop iterates, but never the other way
   /// around. A predicate is said to be monotonically decreasing if may go
@@ -1805,6 +1809,9 @@ class ScalarEvolution {
                                        Value *ExitCond, bool ExitIfTrue,
                                        bool ControlsOnlyExit,
                                        bool AllowPredicates);
+  ExitLimit computeExitLimitFromMemAccessCached(ExitLimitCacheTy &Cache,
+                                                const Loop *L);
+  ExitLimit computeExitLimitFromMemAccessImpl(const Loop *L);
   std::optional<ScalarEvolution::ExitLimit> computeExitLimitFromCondFromBinOp(
       ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
       bool ControlsOnlyExit, bool AllowPredicates);

llvm/lib/Analysis/ScalarEvolution.cpp

Lines changed: 252 additions & 6 deletions
@@ -249,6 +249,10 @@ static cl::opt<bool> UseContextForNoWrapFlagInference(
     cl::desc("Infer nuw/nsw flags using context where suitable"),
     cl::init(true));
 
+static cl::opt<bool> UseMemoryAccessUBForBEInference(
+    "scalar-evolution-infer-max-trip-count-from-memory-access", cl::Hidden,
+    cl::desc("Infer loop max trip count from memory access"), cl::init(false));
+
 //===----------------------------------------------------------------------===//
 // SCEV class definitions
 //===----------------------------------------------------------------------===//
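The inference is gated behind this hidden, off-by-default flag. Assuming the usual opt workflow, it can be exercised with something like:

    opt -passes='print<scalar-evolution>' \
        -scalar-evolution-infer-max-trip-count-from-memory-access \
        -disable-output input.ll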
@@ -8260,6 +8264,210 @@ ScalarEvolution::getSmallConstantTripMultiple(const Loop *L,
   return getSmallConstantTripMultiple(L, ExitCount);
 }
 
+/// Collect all load/store instructions that must be executed in every
+/// iteration of loop \p L.
+static void
+collectExecLoadStoreInsideLoop(const Loop *L, DominatorTree &DT,
+                               SmallVector<Instruction *, 4> &MemInsts) {
+  // It is difficult to tell if the load/store instruction is executed on
+  // every iteration inside an irregular loop.
+  if (!L->isLoopSimplifyForm() || !L->isInnermost())
+    return;
+
+  const BasicBlock *LoopLatch = L->getLoopLatch();
+  assert(LoopLatch && "normal form loop doesn't have a latch");
+  assert(L->getExitingBlock() == LoopLatch);
+
+  // We will not continue if a sanitizer is enabled.
+  const Function *F = LoopLatch->getParent();
+  if (F->hasFnAttribute(Attribute::SanitizeAddress) ||
+      F->hasFnAttribute(Attribute::SanitizeThread) ||
+      F->hasFnAttribute(Attribute::SanitizeMemory) ||
+      F->hasFnAttribute(Attribute::SanitizeHWAddress) ||
+      F->hasFnAttribute(Attribute::SanitizeMemTag))
+    return;
+
+  for (auto *BB : L->getBlocks()) {
+    // We need to make sure that the max execution time of MemAccessBB in the
+    // loop represents the latch's max execution time. A MemAccessBB like the
+    // one below does not dominate the latch and should be skipped:
+    //            Entry
+    //              │
+    //        ┌─────▼─────┐
+    //        │Loop Header◄─────┐
+    //        └──┬──────┬─┘     │
+    //           │      │       │
+    // ┌─────────▼─┐  ┌─▼─────┐ │
+    // │MemAccessBB│  │OtherBB│ │
+    // └─────────┬─┘  └─┬─────┘ │
+    //           │      │       │
+    //         ┌─▼──────▼─┐     │
+    //         │Loop Latch├─────┘
+    //         └────┬─────┘
+    //              ▼
+    //             Exit
+    if (!DT.dominates(BB, LoopLatch))
+      continue;
+
+    for (Instruction &I : *BB) {
+      if (isa<LoadInst>(&I) || isa<StoreInst>(&I))
+        MemInsts.push_back(&I);
+    }
+  }
+}
+
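As a hedged source-level illustration of the dominance filter above (names hypothetical): only accesses whose block dominates the latch, i.e. accesses that run on every iteration, are collected.

    void g(int *a, int *b, unsigned n) {
      for (unsigned i = 0; i < n; ++i) {
        int v = a[i];   // runs on every iteration: collected
        if (v > 0)
          b[i] = v;     // guarded block does not dominate the latch: skipped
      }
    }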
+/// Return a SCEV representing the memory size of pointer \p V.
+static const SCEV *getCertainSizeOfMem(const SCEV *V, Type *RTy,
+                                       const DataLayout &DL,
+                                       const TargetLibraryInfo &TLI,
+                                       ScalarEvolution *SE) {
+  const SCEVUnknown *PtrBase = dyn_cast<SCEVUnknown>(V);
+  if (!PtrBase)
+    return nullptr;
+  Value *Ptr = PtrBase->getValue();
+  uint64_t Size = 0;
+  if (!llvm::getObjectSize(Ptr, Size, DL, &TLI))
+    return nullptr;
+  return SE->getConstant(RTy, Size);
+}
+
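For example (an assumed case, not from the patch): if \p V is a SCEVUnknown wrapping a global declared as int A[100], llvm::getObjectSize reports 400 bytes and the helper returns the SCEV constant 400 in type RTy; for a pointer whose allocation size cannot be determined, such as an arbitrary function argument, it returns nullptr.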
+/// Get the range of the index represented by \p AddRec.
+static const SCEV *getIndexRange(const SCEVAddRecExpr *AddRec,
+                                 ScalarEvolution *SE) {
+  const SCEV *Range = SE->getConstant(SE->getUnsignedRangeMax(AddRec) -
+                                      SE->getUnsignedRangeMin(AddRec));
+  const SCEV *Step = AddRec->getStepRecurrence(*SE);
+  if (SE->isKnownNegative(Step))
+    Step = SE->getNegativeSCEV(Step);
+  return SE->getUDivCeilSCEV(Range, Step);
+}
+
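A worked example with assumed numbers: if the index add-recurrence has unsigned range [0, 96] and step 4, the helper returns ceil((96 - 0) / 4) = 24; a negative step such as -4 is first normalized to 4.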
+/// Get the underlying SCEVAddExpr from a cast expression if possible.
+static const SCEV *peelCastExpr(const SCEVCastExpr *S, ScalarEvolution *SE) {
+  const SCEV *Op = S->getOperand();
+  if (isa<SCEVCouldNotCompute>(Op))
+    return Op;
+  if (isa<SCEVAddExpr>(Op))
+    return Op;
+  if (isa<SCEVCastExpr>(Op))
+    return peelCastExpr(cast<SCEVCastExpr>(Op), SE);
+  return SE->getCouldNotCompute();
+}
+
+static Value *peelExt(Value *V) {
+  if (isa<ZExtInst>(V) || isa<SExtInst>(V))
+    return peelExt(cast<Instruction>(V)->getOperand(0));
+  return V;
+}
+
+static bool isIndexInductionVariable(PHINode *InductionVar, Value *Index) {
+  if (InductionVar == Index)
+    return true;
+  if (peelExt(Index) == InductionVar)
+    return true;
+  return false;
+}
+
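For example, a GEP index of the form zext i32 %iv to i64 peels back to %iv, so isIndexInductionVariable recognizes it as the loop's induction variable; peelCastExpr performs the analogous peeling at the SCEV level for zext/sext expressions.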
+/// Check whether the index can wrap, and whether we can still infer a max
+/// trip count given the max trip count inferred from the memory access.
+static const SCEV *checkIndexRange(Value *Ptr, PHINode *InductionVar,
+                                   ScalarEvolution *SE,
+                                   const SCEVConstant *MaxExecCount) {
+  SmallVector<const SCEV *> InferCountColl;
+  auto *PtrGEP = dyn_cast<GetElementPtrInst>(Ptr);
+  if (!PtrGEP)
+    return SE->getCouldNotCompute();
+  for (Value *Index : PtrGEP->indices()) {
+    Value *V = Index;
+    if (!isIndexInductionVariable(InductionVar, Index))
+      continue;
+    if (isa<ZExtInst>(V) || isa<SExtInst>(V))
+      V = cast<Instruction>(Index)->getOperand(0);
+    auto *SCEV = SE->getSCEV(V);
+    if (isa<SCEVCouldNotCompute>(SCEV))
+      return SE->getCouldNotCompute();
+    if (isa<SCEVCastExpr>(SCEV)) {
+      SCEV = peelCastExpr(cast<SCEVCastExpr>(SCEV), SE);
+      if (isa<SCEVCouldNotCompute>(SCEV))
+        return SE->getCouldNotCompute();
+    }
+    auto *AddRec = dyn_cast<SCEVAddRecExpr>(SCEV);
+    if (!AddRec)
+      return SE->getCouldNotCompute();
+    auto *IndexRange = getIndexRange(AddRec, SE);
+    auto *IndexRangeC = dyn_cast<SCEVConstant>(IndexRange);
+    if (!IndexRangeC)
+      return SE->getCouldNotCompute();
+    InferCountColl.push_back(IndexRange);
+    break;
+  }
+
+  if (InferCountColl.empty())
+    return SE->getCouldNotCompute();
+
+  InferCountColl.push_back(MaxExecCount);
+
+  return SE->getUMinFromMismatchedTypes(InferCountColl);
+}
+
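A hypothetical wrap scenario this guards against: with an i8 induction variable stepping by 1 into a 1024-byte buffer, the index's unsigned range spans at most 255 steps, so the returned bound is umin(255, MaxExecCount) rather than the 1024 suggested by the object size alone, since the index could wrap before running off the end of the object.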
+ScalarEvolution::ExitLimit
+ScalarEvolution::computeExitLimitFromMemAccessImpl(const Loop *L) {
+  SmallVector<Instruction *, 4> MemInsts;
+  collectExecLoadStoreInsideLoop(L, DT, MemInsts);
+
+  SmallVector<const SCEV *> InferCountColl;
+  const DataLayout &DL = getDataLayout();
+
+  for (Instruction *I : MemInsts) {
+    Value *Ptr = getLoadStorePointerOperand(I);
+    assert(Ptr && "empty pointer operand");
+    auto *AddRec = dyn_cast<SCEVAddRecExpr>(getSCEV(Ptr));
+    if (!AddRec || !AddRec->isAffine())
+      continue;
+    const SCEV *PtrBase = getPointerBase(AddRec);
+    const SCEV *Step = AddRec->getStepRecurrence(*this);
+    const SCEV *MemSize =
+        getCertainSizeOfMem(PtrBase, Step->getType(), DL, TLI, this);
+    if (!MemSize)
+      continue;
+    if (isKnownNegative(Step))
+      Step = getNegativeSCEV(Step);
+    // Now we can infer a max execution time by MemLength/StepLength.
+    auto *MaxExecCount = dyn_cast<SCEVConstant>(getUDivCeilSCEV(MemSize, Step));
+    if (!MaxExecCount || MaxExecCount->getAPInt().getActiveBits() > 32)
+      continue;
+    auto *Res = checkIndexRange(Ptr, L->getInductionVariable(*this), this,
+                                MaxExecCount);
+    if (isa<SCEVCouldNotCompute>(Res))
+      continue;
+    InferCountColl.push_back(Res);
+  }
+
+  if (InferCountColl.empty())
+    return getCouldNotCompute();
+
+  const SCEV *Count = getUMinFromMismatchedTypes(InferCountColl);
+
+  return {getCouldNotCompute(), Count, Count, /*MaxOrZero=*/false};
+}
+
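Putting the pieces together on assumed numbers: for an access whose pointer SCEV is {a,+,4} over a 400-byte object, MemSize = 400 and Step = 4, so MaxExecCount = ceil(400 / 4) = 100; checkIndexRange may clamp that further by the index's own range, and the final answer is the unsigned minimum over all collected accesses.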
+ScalarEvolution::ExitLimit
+ScalarEvolution::computeExitLimitFromMemAccessCached(ExitLimitCacheTy &Cache,
+                                                     const Loop *L) {
+  // We don't really need these values, but the cache does.
+  constexpr Value *ExitCond = nullptr;
+  constexpr bool ExitIfTrue = true;
+  constexpr bool ControlsOnlyExit = true;
+  constexpr bool AllowPredicates = true;
+
+  if (auto MaybeEL = Cache.find(L, ExitCond, ExitIfTrue, ControlsOnlyExit,
+                                AllowPredicates))
+    return *MaybeEL;
+
+  ExitLimit EL = computeExitLimitFromMemAccessImpl(L);
+  Cache.insert(L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates, EL);
+  return EL;
+}
+
 const SCEV *ScalarEvolution::getExitCount(const Loop *L,
                                           const BasicBlock *ExitingBlock,
                                           ExitCountKind Kind) {
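Design note: rather than introducing a dedicated cache, the result is stored in the existing ExitLimitCacheTy under a fixed synthetic key (null ExitCond, all flags true), so repeated queries for the same loop are answered from the cache.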
@@ -8842,6 +9050,16 @@ ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
   if (!Latch || !DT.dominates(ExitingBlock, Latch))
     return getCouldNotCompute();
 
+  // FIXME: To keep the case typical, we only analyze loops that have a single
+  // exiting block, and that block must be the latch. This makes it easier to
+  // capture loops whose memory accesses execute on every iteration.
+  const SCEV *PotentiallyBetterConstantMax = getCouldNotCompute();
+  if (UseMemoryAccessUBForBEInference && Latch == L->getExitingBlock()) {
+    assert(Latch == ExitingBlock);
+    auto EL = computeExitLimitFromMemAccess(L);
+    PotentiallyBetterConstantMax = EL.ConstantMaxNotTaken;
+  }
+
   bool IsOnlyExit = (L->getExitingBlock() != nullptr);
   Instruction *Term = ExitingBlock->getTerminator();
   if (BranchInst *BI = dyn_cast<BranchInst>(Term)) {
@@ -8850,9 +9068,14 @@ ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
     assert(ExitIfTrue == L->contains(BI->getSuccessor(1)) &&
            "It should have one successor in loop and one exit block!");
     // Proceed to the next level to examine the exit condition expression.
-    return computeExitLimitFromCond(L, BI->getCondition(), ExitIfTrue,
-                                    /*ControlsOnlyExit=*/IsOnlyExit,
-                                    AllowPredicates);
+    ExitLimit EL = computeExitLimitFromCond(L, BI->getCondition(), ExitIfTrue,
+                                            /*ControlsOnlyExit=*/IsOnlyExit,
+                                            AllowPredicates);
+    if (!isa<SCEVCouldNotCompute>(EL.ConstantMaxNotTaken) &&
+        !isa<SCEVCouldNotCompute>(PotentiallyBetterConstantMax))
+      EL.ConstantMaxNotTaken = getUMinFromMismatchedTypes(
+          EL.ConstantMaxNotTaken, PotentiallyBetterConstantMax);
+    return EL;
   }
 
   if (SwitchInst *SI = dyn_cast<SwitchInst>(Term)) {
@@ -8865,9 +9088,14 @@ ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
       Exit = SBB;
     }
     assert(Exit && "Exiting block must have at least one exit");
-    return computeExitLimitFromSingleExitSwitch(
-        L, SI, Exit,
-        /*ControlsOnlyExit=*/IsOnlyExit);
+    ExitLimit EL =
+        computeExitLimitFromSingleExitSwitch(L, SI, Exit,
+                                             /*ControlsOnlyExit=*/IsOnlyExit);
+    if (!isa<SCEVCouldNotCompute>(EL.ConstantMaxNotTaken) &&
+        !isa<SCEVCouldNotCompute>(PotentiallyBetterConstantMax))
+      EL.ConstantMaxNotTaken = getUMinFromMismatchedTypes(
+          EL.ConstantMaxNotTaken, PotentiallyBetterConstantMax);
+    return EL;
   }
 
   return getCouldNotCompute();
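In both the branch and the switch paths, the memory-derived bound only tightens an already-computable ConstantMaxNotTaken via an unsigned minimum; if either side is SCEVCouldNotCompute, the condition-derived result is returned unchanged, so the inference can never loosen an existing bound.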
@@ -8881,6 +9109,13 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCond(
                                   ControlsOnlyExit, AllowPredicates);
 }
 
+ScalarEvolution::ExitLimit
+ScalarEvolution::computeExitLimitFromMemAccess(const Loop *L) {
+  ScalarEvolution::ExitLimitCacheTy Cache(L, /*ExitIfTrue=*/true,
+                                          /*AllowPredicates=*/true);
+  return computeExitLimitFromMemAccessCached(Cache, L);
+}
+
 std::optional<ScalarEvolution::ExitLimit>
 ScalarEvolution::ExitLimitCache::find(const Loop *L, Value *ExitCond,
                                       bool ExitIfTrue, bool ControlsOnlyExit,
@@ -13579,6 +13814,17 @@ static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE,
     OS << ": ";
     OS << "Trip multiple is " << SE->getSmallConstantTripMultiple(L) << "\n";
   }
+
+  if (UseMemoryAccessUBForBEInference) {
+    unsigned SmallMaxTrip = SE->getSmallConstantMaxTripCount(L);
+    OS << "Loop ";
+    L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
+    OS << ": ";
+    if (SmallMaxTrip)
+      OS << "Small constant max trip is " << SmallMaxTrip << "\n";
+    else
+      OS << "Small constant max trip couldn't be computed.\n";
+  }
 }
 
 namespace llvm {
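With the flag enabled, the printer emits one extra line per loop. Hypothetical output for the earlier sum() example (the loop name is illustrative) would be:

    Loop %for.body: Small constant max trip is 100

or, when nothing can be inferred:

    Loop %for.body: Small constant max trip couldn't be computed.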
