Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions llvm/include/llvm/Analysis/ScalarEvolution.h
Original file line number Diff line number Diff line change
Expand Up @@ -1154,6 +1154,10 @@ class ScalarEvolution {
bool ExitIfTrue, bool ControlsOnlyExit,
bool AllowPredicates = false);

/// Compute the number of times the body of the specific loop will execute via
/// the memory access inside the loop body.
ExitLimit computeExitLimitFromMemAccess(const Loop *L);

/// A predicate is said to be monotonically increasing if may go from being
/// false to being true as the loop iterates, but never the other way
/// around. A predicate is said to be monotonically decreasing if may go
Expand Down Expand Up @@ -1805,6 +1809,9 @@ class ScalarEvolution {
Value *ExitCond, bool ExitIfTrue,
bool ControlsOnlyExit,
bool AllowPredicates);
ExitLimit computeExitLimitFromMemAccessCached(ExitLimitCacheTy &Cache,
const Loop *L);
ExitLimit computeExitLimitFromMemAccessImpl(const Loop *L);
std::optional<ScalarEvolution::ExitLimit> computeExitLimitFromCondFromBinOp(
ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
bool ControlsOnlyExit, bool AllowPredicates);
Expand Down
258 changes: 252 additions & 6 deletions llvm/lib/Analysis/ScalarEvolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,10 @@ static cl::opt<bool> UseContextForNoWrapFlagInference(
cl::desc("Infer nuw/nsw flags using context where suitable"),
cl::init(true));

static cl::opt<bool> UseMemoryAccessUBForBEInference(
"scalar-evolution-infer-max-trip-count-from-memory-access", cl::Hidden,
cl::desc("Infer loop max trip count from memory access"), cl::init(false));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you check the compile time impact (via @nikic tracker) when you enable this?


//===----------------------------------------------------------------------===//
// SCEV class definitions
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -8260,6 +8264,210 @@ ScalarEvolution::getSmallConstantTripMultiple(const Loop *L,
return getSmallConstantTripMultiple(L, ExitCount);
}

/// Collect all load/store instructions that must be executed in every iteration
/// of loop \p L .
static void
collectExecLoadStoreInsideLoop(const Loop *L, DominatorTree &DT,
SmallVector<Instruction *, 4> &MemInsts) {
// It is difficult to tell if the load/store instruction is executed on every
// iteration inside an irregular loop.
if (!L->isLoopSimplifyForm() || !L->isInnermost())
return;

const BasicBlock *LoopLatch = L->getLoopLatch();
assert(LoopLatch && "normal form loop doesn't have a latch");
assert(L->getExitingBlock() == LoopLatch);

// We will not continue if sanitizer is enabled.
const Function *F = LoopLatch->getParent();
if (F->hasFnAttribute(Attribute::SanitizeAddress) ||
F->hasFnAttribute(Attribute::SanitizeThread) ||
F->hasFnAttribute(Attribute::SanitizeMemory) ||
F->hasFnAttribute(Attribute::SanitizeHWAddress) ||
F->hasFnAttribute(Attribute::SanitizeMemTag))
return;

for (auto *BB : L->getBlocks()) {
// We need to make sure that max execution time of MemAccessBB in loop
// represents latch max excution time. The BB below should be skipped:
// Entry
// │
// ┌─────▼─────┐
// │Loop Header◄─────┐
// └──┬──────┬─┘ │
// │ │ │
// ┌────────▼──┐ ┌─▼─────┐ │
// │MemAccessBB│ │OtherBB│ │
// └────────┬──┘ └─┬─────┘ │
// │ │ │
// ┌─▼──────▼─┐ │
// │Loop Latch├─────┘
// └────┬─────┘
// ▼
// Exit
if (!DT.dominates(BB, LoopLatch))
continue;

for (Instruction &I : *BB) {
if (isa<LoadInst>(&I) || isa<StoreInst>(&I))
MemInsts.push_back(&I);
}
}
}

/// Return a constant SCEV of type \p RTy holding the exact object size of the
/// memory pointed to by \p V, or nullptr when the size cannot be determined.
static const SCEV *getCertainSizeOfMem(const SCEV *V, Type *RTy,
                                       const DataLayout &DL,
                                       const TargetLibraryInfo &TLI,
                                       ScalarEvolution *SE) {
  // Only a plain SCEVUnknown wraps a concrete IR pointer we can size.
  const auto *Base = dyn_cast<SCEVUnknown>(V);
  if (!Base)
    return nullptr;

  uint64_t ObjSize = 0;
  if (!llvm::getObjectSize(Base->getValue(), ObjSize, DL, &TLI))
    return nullptr;
  return SE->getConstant(RTy, ObjSize);
}

/// Return how many (absolute) steps of \p AddRec fit inside its unsigned
/// range, i.e. ceil((umax - umin) / |step|).
static const SCEV *getIndexRange(const SCEVAddRecExpr *AddRec,
                                 ScalarEvolution *SE) {
  const SCEV *AbsStep = AddRec->getStepRecurrence(*SE);
  if (SE->isKnownNegative(AbsStep))
    AbsStep = SE->getNegativeSCEV(AbsStep);

  const SCEV *Span = SE->getConstant(SE->getUnsignedRangeMax(AddRec) -
                                     SE->getUnsignedRangeMin(AddRec));
  return SE->getUDivCeilSCEV(Span, AbsStep);
}

/// Strip any number of nested cast expressions off \p S and return the
/// underlying SCEVAddExpr if there is one. A SCEVCouldNotCompute operand is
/// returned as-is; anything else yields CouldNotCompute.
///
/// Marked static for internal linkage, consistent with the neighboring
/// file-local helpers (and to avoid -Wmissing-prototypes).
static const SCEV *peelCastExpr(const SCEVCastExpr *S, ScalarEvolution *SE) {
  const SCEV *Op = S->getOperand();
  // Walk through the cast chain iteratively rather than recursing.
  while (true) {
    if (isa<SCEVCouldNotCompute>(Op) || isa<SCEVAddExpr>(Op))
      return Op;
    const auto *Cast = dyn_cast<SCEVCastExpr>(Op);
    if (!Cast)
      return SE->getCouldNotCompute();
    Op = Cast->getOperand();
  }
}

/// Look through any chain of zext/sext instructions and return the first
/// non-extension value underneath.
static Value *peelExt(Value *V) {
  while (isa<ZExtInst>(V) || isa<SExtInst>(V))
    V = cast<Instruction>(V)->getOperand(0);
  return V;
}

/// Return true if \p Index is the induction variable \p InductionVar, either
/// directly or hidden behind zext/sext casts.
/// NOTE(review): this is conservative — e.g. array[i * 2] is not recognized.
static bool isIndexInductionVariable(PHINode *InductionVar, Value *Index) {
  return Index == InductionVar || peelExt(Index) == InductionVar;
}

/// Check whether the GEP index driven by \p InductionVar can wrap. When a
/// constant range can be derived for the index, return the unsigned minimum of
/// that range and \p MaxExecCount; otherwise return CouldNotCompute.
static const SCEV *checkIndexRange(Value *Ptr, PHINode *InductionVar,
                                   ScalarEvolution *SE,
                                   const SCEVConstant *MaxExecCount) {
  auto *GEP = dyn_cast<GetElementPtrInst>(Ptr);
  if (!GEP)
    return SE->getCouldNotCompute();

  SmallVector<const SCEV *> Bounds;
  for (Value *Idx : GEP->indices()) {
    if (!isIndexInductionVariable(InductionVar, Idx))
      continue;

    // Strip a single zext/sext so we analyze the narrower expression.
    Value *Stripped = Idx;
    if (isa<ZExtInst>(Stripped) || isa<SExtInst>(Stripped))
      Stripped = cast<Instruction>(Idx)->getOperand(0);

    const SCEV *Expr = SE->getSCEV(Stripped);
    if (isa<SCEVCouldNotCompute>(Expr))
      return SE->getCouldNotCompute();
    if (auto *Cast = dyn_cast<SCEVCastExpr>(Expr)) {
      Expr = peelCastExpr(Cast, SE);
      if (isa<SCEVCouldNotCompute>(Expr))
        return SE->getCouldNotCompute();
    }

    const auto *Rec = dyn_cast<SCEVAddRecExpr>(Expr);
    if (!Rec)
      return SE->getCouldNotCompute();

    // Only a constant index range lets us reason about wrapping.
    const SCEV *Range = getIndexRange(Rec, SE);
    if (!isa<SCEVConstant>(Range))
      return SE->getCouldNotCompute();

    Bounds.push_back(Range);
    break;
  }

  if (Bounds.empty())
    return SE->getCouldNotCompute();

  Bounds.push_back(MaxExecCount);
  return SE->getUMinFromMismatchedTypes(Bounds);
}

ScalarEvolution::ExitLimit
ScalarEvolution::computeExitLimitFromMemAccessImpl(const Loop *L) {
  // Gather the loads/stores guaranteed to run on every iteration.
  SmallVector<Instruction *, 4> MemInsts;
  collectExecLoadStoreInsideLoop(L, DT, MemInsts);

  const DataLayout &DL = getDataLayout();
  SmallVector<const SCEV *> TripBounds;

  for (Instruction *MemI : MemInsts) {
    Value *Ptr = getLoadStorePointerOperand(MemI);
    assert(Ptr && "empty pointer operand");

    auto *Rec = dyn_cast<SCEVAddRecExpr>(getSCEV(Ptr));
    if (!Rec || !Rec->isAffine())
      continue;

    const SCEV *Step = Rec->getStepRecurrence(*this);
    const SCEV *MemSize =
        getCertainSizeOfMem(getPointerBase(Rec), Step->getType(), DL, TLI, this);
    if (!MemSize)
      continue;

    if (isKnownNegative(Step))
      Step = getNegativeSCEV(Step);

    // The access can stay in bounds for at most MemSize / |Step| iterations.
    auto *MaxExecCount = dyn_cast<SCEVConstant>(getUDivCeilSCEV(MemSize, Step));
    if (!MaxExecCount || MaxExecCount->getAPInt().getActiveBits() > 32)
      continue;

    const SCEV *Bound = checkIndexRange(Ptr, L->getInductionVariable(*this),
                                        this, MaxExecCount);
    if (!isa<SCEVCouldNotCompute>(Bound))
      TripBounds.push_back(Bound);
  }

  if (TripBounds.empty())
    return getCouldNotCompute();

  const SCEV *MaxBECount = getUMinFromMismatchedTypes(TripBounds);
  return {getCouldNotCompute(), MaxBECount, MaxBECount, /*MaxOrZero=*/false};
}

ScalarEvolution::ExitLimit
ScalarEvolution::computeExitLimitFromMemAccessCached(ExitLimitCacheTy &Cache,
                                                     const Loop *L) {
  // The cache key requires an exit condition and control flags even though
  // this analysis has none; use fixed placeholder values.
  // (Dropped the redundant `const` — `constexpr` already implies it.)
  constexpr Value *ExitCond = nullptr;
  constexpr bool ExitIfTrue = true;
  constexpr bool ControlsOnlyExit = true;
  constexpr bool AllowPredicates = true;

  if (auto MaybeEL = Cache.find(L, ExitCond, ExitIfTrue, ControlsOnlyExit,
                                AllowPredicates))
    return *MaybeEL;

  ExitLimit EL = computeExitLimitFromMemAccessImpl(L);
  Cache.insert(L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates, EL);
  return EL;
}

const SCEV *ScalarEvolution::getExitCount(const Loop *L,
const BasicBlock *ExitingBlock,
ExitCountKind Kind) {
Expand Down Expand Up @@ -8842,6 +9050,16 @@ ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
if (!Latch || !DT.dominates(ExitingBlock, Latch))
return getCouldNotCompute();

// FIXME: To make the case more typical, we only analyze loops that have one
// exiting block and the block must be the latch. It is easier to capture
// loops with memory access that will be executed in every iteration.
const SCEV *PotentiallyBetterConstantMax = getCouldNotCompute();
if (UseMemoryAccessUBForBEInference && Latch == L->getExitingBlock()) {
assert(Latch == ExitingBlock);
auto EL = computeExitLimitFromMemAccess(L);
PotentiallyBetterConstantMax = EL.ConstantMaxNotTaken;
}

bool IsOnlyExit = (L->getExitingBlock() != nullptr);
Instruction *Term = ExitingBlock->getTerminator();
if (BranchInst *BI = dyn_cast<BranchInst>(Term)) {
Expand All @@ -8850,9 +9068,14 @@ ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
assert(ExitIfTrue == L->contains(BI->getSuccessor(1)) &&
"It should have one successor in loop and one exit block!");
// Proceed to the next level to examine the exit condition expression.
return computeExitLimitFromCond(L, BI->getCondition(), ExitIfTrue,
/*ControlsOnlyExit=*/IsOnlyExit,
AllowPredicates);
ExitLimit EL = computeExitLimitFromCond(L, BI->getCondition(), ExitIfTrue,
/*ControlsOnlyExit=*/IsOnlyExit,
AllowPredicates);
if (!isa<SCEVCouldNotCompute>(EL.ConstantMaxNotTaken) &&
!isa<SCEVCouldNotCompute>(PotentiallyBetterConstantMax))
EL.ConstantMaxNotTaken = getUMinFromMismatchedTypes(
EL.ConstantMaxNotTaken, PotentiallyBetterConstantMax);
return EL;
}

if (SwitchInst *SI = dyn_cast<SwitchInst>(Term)) {
Expand All @@ -8865,9 +9088,14 @@ ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
Exit = SBB;
}
assert(Exit && "Exiting block must have at least one exit");
return computeExitLimitFromSingleExitSwitch(
L, SI, Exit,
/*ControlsOnlyExit=*/IsOnlyExit);
ExitLimit EL =
computeExitLimitFromSingleExitSwitch(L, SI, Exit,
/*ControlsOnlyExit=*/IsOnlyExit);
if (!isa<SCEVCouldNotCompute>(EL.ConstantMaxNotTaken) &&
!isa<SCEVCouldNotCompute>(PotentiallyBetterConstantMax))
EL.ConstantMaxNotTaken = getUMinFromMismatchedTypes(
EL.ConstantMaxNotTaken, PotentiallyBetterConstantMax);
return EL;
}

return getCouldNotCompute();
Expand All @@ -8881,6 +9109,13 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCond(
ControlsOnlyExit, AllowPredicates);
}

ScalarEvolution::ExitLimit
ScalarEvolution::computeExitLimitFromMemAccess(const Loop *L) {
  // The flag values here are placeholders required by the cache key; they are
  // not meaningful for memory-access-based inference. Argument comments use
  // the file's /*Name=*/ convention (bugprone-argument-comment).
  ScalarEvolution::ExitLimitCacheTy Cache(L, /*ExitIfTrue=*/true,
                                          /*AllowPredicates=*/true);
  return computeExitLimitFromMemAccessCached(Cache, L);
}

std::optional<ScalarEvolution::ExitLimit>
ScalarEvolution::ExitLimitCache::find(const Loop *L, Value *ExitCond,
bool ExitIfTrue, bool ControlsOnlyExit,
Expand Down Expand Up @@ -13579,6 +13814,17 @@ static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE,
OS << ": ";
OS << "Trip multiple is " << SE->getSmallConstantTripMultiple(L) << "\n";
}

if (UseMemoryAccessUBForBEInference) {
unsigned SmallMaxTrip = SE->getSmallConstantMaxTripCount(L);
OS << "Loop ";
L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
OS << ": ";
if (SmallMaxTrip)
OS << "Small constant max trip is " << SmallMaxTrip << "\n";
else
OS << "Small constant max trip couldn't be computed.\n";
}
}

namespace llvm {
Expand Down
Loading