Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions llvm/include/llvm/Support/BranchProbability.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,9 @@ class BranchProbability {
/// \return \c Num divided by \c this.
LLVM_ABI uint64_t scaleByInverse(uint64_t Num) const;

/// Compute pow(Probability, N).
BranchProbability pow(unsigned N) const;

BranchProbability &operator+=(BranchProbability RHS) {
assert(N != UnknownN && RHS.N != UnknownN &&
"Unknown probability cannot participate in arithmetics.");
Expand Down
34 changes: 34 additions & 0 deletions llvm/include/llvm/Transforms/Utils/LoopUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,40 @@ LLVM_ABI bool setLoopEstimatedTripCount(
Loop *L, unsigned EstimatedTripCount,
std::optional<unsigned> EstimatedLoopInvocationWeight = std::nullopt);

/// Based on branch weight metadata, return either:
/// - An unknown probability if the implementation is unable to handle the loop
/// form of \p L (e.g., \p L must have a latch block that controls the loop
/// exit).
/// - The probability \c P that, at the end of any iteration, the latch of \p L
/// will start another iteration such that `1 - P` is the probability of
/// exiting the loop.
BranchProbability getLoopProbability(Loop *L);

/// Set branch weight metadata for the latch of \p L to indicate that, at the
/// end of any iteration, \p P and `1 - P` are the probabilities of starting
/// another iteration and exiting the loop, respectively. Return false if the
/// implementation is unable to handle the loop form of \p L (e.g., \p L must
/// have a latch block that controls the loop exit). Otherwise, return true.
bool setLoopProbability(Loop *L, BranchProbability P);

/// Based on branch weight metadata, return either:
/// - An unknown probability if the implementation cannot extract the
/// probability (e.g., \p B must have exactly two target labels, so it must be
/// a conditional branch).
/// - The probability \c P that control flows from \p B to its first target
/// label such that `1 - P` is the probability of control flowing to its
/// second target label, or vice-versa if \p ForFirstTarget is false.
BranchProbability getBranchProbability(BranchInst *B, bool ForFirstTarget);

/// Set branch weight metadata for \p B to indicate that \p P and `1 - P` are
/// the probabilities of control flowing to its first and second target labels,
/// respectively, or vice-versa if \p ForFirstTarget is false. Return false if
/// the implementation cannot set the probability (e.g., \p B must have exactly
/// two target labels, so it must be a conditional branch). Otherwise, return
/// true.
bool setBranchProbability(BranchInst *B, BranchProbability P,
bool ForFirstTarget);

/// Check inner loop (L) backedge count is known to be invariant on all
/// iterations of its outer loop. If the loop has no parent, this is trivially
/// true.
Expand Down
4 changes: 3 additions & 1 deletion llvm/include/llvm/Transforms/Utils/UnrollLoop.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,9 @@ LLVM_ABI bool UnrollRuntimeLoopRemainder(
LoopInfo *LI, ScalarEvolution *SE, DominatorTree *DT, AssumptionCache *AC,
const TargetTransformInfo *TTI, bool PreserveLCSSA,
unsigned SCEVExpansionBudget, bool RuntimeUnrollMultiExit,
Loop **ResultLoop = nullptr);
Loop **ResultLoop = nullptr,
std::optional<unsigned> OriginalTripCount = std::nullopt,
BranchProbability OriginalLoopProb = BranchProbability::getUnknown());

LLVM_ABI LoopUnrollResult UnrollAndJamLoop(
Loop *L, unsigned Count, unsigned TripCount, unsigned TripMultiple,
Expand Down
7 changes: 7 additions & 0 deletions llvm/lib/Support/BranchProbability.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,3 +111,10 @@ uint64_t BranchProbability::scale(uint64_t Num) const {
// Return \p Num divided by this probability (as documented on the
// declaration). The work is delegated to the global-namespace scale<> helper,
// which is handed the members in (D, N) order.
uint64_t BranchProbability::scaleByInverse(uint64_t Num) const {
return ::scale<0>(Num, D, N);
}

// Raise this probability to the power \p N by repeated multiplication.
// The accumulator starts at one, so an exponent of zero returns one.
BranchProbability BranchProbability::pow(unsigned N) const {
  BranchProbability Result = BranchProbability::getOne();
  // Fold in one factor of *this for each requested power.
  for (unsigned Remaining = N; Remaining != 0; --Remaining)
    Result *= *this;
  return Result;
}
30 changes: 19 additions & 11 deletions llvm/lib/Transforms/Utils/LoopUnroll.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,7 @@ llvm::UnrollLoop(Loop *L, UnrollLoopOptions ULO, LoopInfo *LI,
const bool MaxOrZero = SE->isBackedgeTakenCountMaxOrZero(L);
std::optional<unsigned> OriginalTripCount =
llvm::getLoopEstimatedTripCount(L);
BranchProbability OriginalLoopProb = llvm::getLoopProbability(L);

// Effectively "DCE" unrolled iterations that are beyond the max tripcount
// and will never be executed.
Expand Down Expand Up @@ -591,11 +592,11 @@ llvm::UnrollLoop(Loop *L, UnrollLoopOptions ULO, LoopInfo *LI,
: isEpilogProfitable(L);

if (ULO.Runtime &&
!UnrollRuntimeLoopRemainder(L, ULO.Count, ULO.AllowExpensiveTripCount,
EpilogProfitability, ULO.UnrollRemainder,
ULO.ForgetAllSCEV, LI, SE, DT, AC, TTI,
PreserveLCSSA, ULO.SCEVExpansionBudget,
ULO.RuntimeUnrollMultiExit, RemainderLoop)) {
!UnrollRuntimeLoopRemainder(
L, ULO.Count, ULO.AllowExpensiveTripCount, EpilogProfitability,
ULO.UnrollRemainder, ULO.ForgetAllSCEV, LI, SE, DT, AC, TTI,
PreserveLCSSA, ULO.SCEVExpansionBudget, ULO.RuntimeUnrollMultiExit,
RemainderLoop, OriginalTripCount, OriginalLoopProb)) {
if (ULO.Force)
ULO.Runtime = false;
else {
Expand Down Expand Up @@ -1130,13 +1131,13 @@ llvm::UnrollLoop(Loop *L, UnrollLoopOptions ULO, LoopInfo *LI,
LI->erase(L);
// We shouldn't try to use `L` anymore.
L = nullptr;
} else if (OriginalTripCount) {
} else {
// Update metadata for the loop's branch weights and estimated trip count:
// - If ULO.Runtime, UnrollRuntimeLoopRemainder sets the guard branch
// weights, latch branch weights, and estimated trip count of the
// remainder loop it creates. It also sets the branch weights for the
// unrolled loop guard it creates. The branch weights for the unrolled
// loop latch are adjusted below. FIXME: Actually handle ULO.Runtime.
// loop latch are adjusted below. FIXME: Handle prologue loops.
// - Otherwise, if unrolled loop iteration latches become unconditional,
// branch weights are adjusted above. FIXME: Actually handle such
// unconditional latches.
Expand All @@ -1159,10 +1160,17 @@ llvm::UnrollLoop(Loop *L, UnrollLoopOptions ULO, LoopInfo *LI,
// the unrolled loop as a whole without considering the branch weights for
// each unrolled iteration's latch within it, we store the new trip count as
// separate metadata.
unsigned NewTripCount = *OriginalTripCount / ULO.Count;
if (!ULO.Runtime && *OriginalTripCount % ULO.Count)
NewTripCount += 1;
setLoopEstimatedTripCount(L, NewTripCount);
if (!OriginalLoopProb.isUnknown() && ULO.Runtime && EpilogProfitability) {
// Where p is always the probability of executing at least 1 more
// iteration, the probability for at least n more iterations is p^n.
setLoopProbability(L, OriginalLoopProb.pow(ULO.Count));
}
if (OriginalTripCount) {
unsigned NewTripCount = *OriginalTripCount / ULO.Count;
if (!ULO.Runtime && *OriginalTripCount % ULO.Count)
NewTripCount += 1;
setLoopEstimatedTripCount(L, NewTripCount);
}
}

// LoopInfo should not be valid, confirm that.
Expand Down
100 changes: 82 additions & 18 deletions llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
#include "llvm/Transforms/Utils/LoopUtils.h"
#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
#include "llvm/Transforms/Utils/UnrollLoop.h"
#include <cmath>

using namespace llvm;

Expand Down Expand Up @@ -195,6 +196,21 @@ static void ConnectProlog(Loop *L, Value *BECount, unsigned Count,
}
}

/// Assume, due to our position in the remainder loop or its guard, anywhere
/// from 0 to \p N more iterations can possibly execute.  Among such cases in
/// the original loop (with loop probability \p OriginalLoopProb), what is the
/// probability of executing at least one more iteration?
///
/// \p OriginalLoopProb is used in arithmetic here, so callers are expected to
/// pass a known probability.  If the original loop never exits (probability
/// one), the conditional probability is undefined (0/0); in that case a
/// probability of one is returned, i.e. every remaining iteration is assumed
/// to execute.
static BranchProbability
probOfNextInRemainder(BranchProbability OriginalLoopProb, unsigned N) {
  // Each of these variables holds the original loop's probability that the
  // number of iterations it will execute is some m in the specified range.
  BranchProbability ProbOne = OriginalLoopProb;                // 1 <= m
  BranchProbability ProbTooMany = ProbOne.pow(N + 1);          // N + 1 <= m
  BranchProbability ProbNotTooMany = ProbTooMany.getCompl();   // 0 <= m <= N
  // A zero divisor means the original loop is modelled as always continuing.
  // Dividing by it would be a division by zero, so pick the limiting answer
  // instead of computing the quotient.
  if (ProbNotTooMany.isZero())
    return BranchProbability::getOne();
  BranchProbability ProbOneNotTooMany = ProbOne - ProbTooMany; // 1 <= m <= N
  return ProbOneNotTooMany / ProbNotTooMany;
}

/// Connect the unrolling epilog code to the original loop.
/// The unrolling epilog code contains code to execute the
/// 'extra' iterations if the run-time trip count modulo the
Expand All @@ -221,7 +237,8 @@ static void ConnectEpilog(Loop *L, Value *ModVal, BasicBlock *NewExit,
BasicBlock *EpilogPreHeader, BasicBlock *NewPreHeader,
ValueToValueMapTy &VMap, DominatorTree *DT,
LoopInfo *LI, bool PreserveLCSSA, ScalarEvolution &SE,
unsigned Count, AssumptionCache &AC) {
unsigned Count, AssumptionCache &AC,
BranchProbability OriginalLoopProb) {
BasicBlock *Latch = L->getLoopLatch();
assert(Latch && "Loop must have a latch");
BasicBlock *EpilogLatch = cast<BasicBlock>(VMap[Latch]);
Expand Down Expand Up @@ -332,12 +349,19 @@ static void ConnectEpilog(Loop *L, Value *ModVal, BasicBlock *NewExit,
PreserveLCSSA);
// Add the branch to the exit block (around the epilog loop)
MDNode *BranchWeights = nullptr;
if (hasBranchWeightMD(*Latch->getTerminator())) {
if (OriginalLoopProb.isUnknown() &&
hasBranchWeightMD(*Latch->getTerminator())) {
// Assume equal distribution in interval [0, Count).
MDBuilder MDB(B.getContext());
BranchWeights = MDB.createBranchWeights(1, Count - 1);
}
B.CreateCondBr(BrLoopExit, EpilogPreHeader, Exit, BranchWeights);
BranchInst *RemainderLoopGuard =
B.CreateCondBr(BrLoopExit, EpilogPreHeader, Exit, BranchWeights);
if (!OriginalLoopProb.isUnknown()) {
setBranchProbability(RemainderLoopGuard,
probOfNextInRemainder(OriginalLoopProb, Count - 1),
/*ForFirstTarget=*/true);
}
InsertPt->eraseFromParent();
if (DT) {
auto *NewDom = DT->findNearestCommonDominator(Exit, NewExit);
Expand All @@ -357,14 +381,15 @@ static void ConnectEpilog(Loop *L, Value *ModVal, BasicBlock *NewExit,
/// The cloned blocks should be inserted between InsertTop and InsertBot.
/// InsertTop should be new preheader, InsertBot new loop exit.
/// Returns the new cloned loop that is created.
static Loop *
CloneLoopBlocks(Loop *L, Value *NewIter, const bool UseEpilogRemainder,
const bool UnrollRemainder,
BasicBlock *InsertTop,
BasicBlock *InsertBot, BasicBlock *Preheader,
static Loop *CloneLoopBlocks(Loop *L, Value *NewIter,
const bool UseEpilogRemainder,
const bool UnrollRemainder, BasicBlock *InsertTop,
BasicBlock *InsertBot, BasicBlock *Preheader,
std::vector<BasicBlock *> &NewBlocks,
LoopBlocksDFS &LoopBlocks, ValueToValueMapTy &VMap,
DominatorTree *DT, LoopInfo *LI, unsigned Count) {
DominatorTree *DT, LoopInfo *LI, unsigned Count,
std::optional<unsigned> OriginalTripCount,
BranchProbability OriginalLoopProb) {
StringRef suffix = UseEpilogRemainder ? "epil" : "prol";
BasicBlock *Header = L->getHeader();
BasicBlock *Latch = L->getLoopLatch();
Expand Down Expand Up @@ -419,7 +444,8 @@ CloneLoopBlocks(Loop *L, Value *NewIter, const bool UseEpilogRemainder,
Builder.CreateAdd(NewIdx, One, NewIdx->getName() + ".next");
Value *IdxCmp = Builder.CreateICmpNE(IdxNext, NewIter, NewIdx->getName() + ".cmp");
MDNode *BranchWeights = nullptr;
if (hasBranchWeightMD(*LatchBR)) {
if ((OriginalLoopProb.isUnknown() || !UseEpilogRemainder) &&
hasBranchWeightMD(*LatchBR)) {
uint32_t ExitWeight;
uint32_t BackEdgeWeight;
if (Count >= 3) {
Expand All @@ -437,7 +463,29 @@ CloneLoopBlocks(Loop *L, Value *NewIter, const bool UseEpilogRemainder,
MDBuilder MDB(Builder.getContext());
BranchWeights = MDB.createBranchWeights(BackEdgeWeight, ExitWeight);
}
Builder.CreateCondBr(IdxCmp, FirstLoopBB, InsertBot, BranchWeights);
BranchInst *RemainderLoopLatch =
Builder.CreateCondBr(IdxCmp, FirstLoopBB, InsertBot, BranchWeights);
if (!OriginalLoopProb.isUnknown() && UseEpilogRemainder) {
// Compute the total frequency of the original loop body from the
// remainder iterations. Once we've reached them, the first of them
// always executes, so its frequency and probability are 1.
double FreqRemIters = 1;
if (Count > 2) {
BranchProbability ProbReaching = BranchProbability::getOne();
for (unsigned N = Count - 2; N >= 1; --N) {
ProbReaching *= probOfNextInRemainder(OriginalLoopProb, N);
FreqRemIters += double(ProbReaching.getNumerator()) /
ProbReaching.getDenominator();
}
}
// Solve for the loop probability that would produce that frequency.
// Sum(i=0..inf)(Prob^i) = 1/(1-Prob) = FreqRemIters.
double ProbDouble = 1 - 1 / FreqRemIters;
BranchProbability Prob = BranchProbability::getBranchProbability(
std::round(ProbDouble * BranchProbability::getDenominator()),
BranchProbability::getDenominator());
setBranchProbability(RemainderLoopLatch, Prob, /*ForFirstTarget=*/true);
}
NewIdx->addIncoming(Zero, InsertTop);
NewIdx->addIncoming(IdxNext, NewBB);
LatchBR->eraseFromParent();
Expand Down Expand Up @@ -469,6 +517,8 @@ CloneLoopBlocks(Loop *L, Value *NewIter, const bool UseEpilogRemainder,

std::optional<MDNode *> NewLoopID = makeFollowupLoopID(
LoopID, {LLVMLoopUnrollFollowupAll, LLVMLoopUnrollFollowupRemainder});
if (OriginalTripCount && UseEpilogRemainder)
setLoopEstimatedTripCount(NewLoop, *OriginalTripCount % Count);
if (NewLoopID) {
NewLoop->setLoopID(*NewLoopID);

Expand Down Expand Up @@ -603,7 +653,8 @@ bool llvm::UnrollRuntimeLoopRemainder(
LoopInfo *LI, ScalarEvolution *SE, DominatorTree *DT, AssumptionCache *AC,
const TargetTransformInfo *TTI, bool PreserveLCSSA,
unsigned SCEVExpansionBudget, bool RuntimeUnrollMultiExit,
Loop **ResultLoop) {
Loop **ResultLoop, std::optional<unsigned> OriginalTripCount,
BranchProbability OriginalLoopProb) {
LLVM_DEBUG(dbgs() << "Trying runtime unrolling on Loop: \n");
LLVM_DEBUG(L->dump());
LLVM_DEBUG(UseEpilogRemainder ? dbgs() << "Using epilog remainder.\n"
Expand Down Expand Up @@ -823,12 +874,23 @@ bool llvm::UnrollRuntimeLoopRemainder(
BasicBlock *UnrollingLoop = UseEpilogRemainder ? NewPreHeader : PrologExit;
// Branch to either remainder (extra iterations) loop or unrolling loop.
MDNode *BranchWeights = nullptr;
if (hasBranchWeightMD(*Latch->getTerminator())) {
if ((OriginalLoopProb.isUnknown() || !UseEpilogRemainder) &&
hasBranchWeightMD(*Latch->getTerminator())) {
// Assume loop is nearly always entered.
MDBuilder MDB(B.getContext());
BranchWeights = MDB.createBranchWeights(EpilogHeaderWeights);
}
B.CreateCondBr(BranchVal, RemainderLoop, UnrollingLoop, BranchWeights);
BranchInst *UnrollingLoopGuard =
B.CreateCondBr(BranchVal, RemainderLoop, UnrollingLoop, BranchWeights);
if (!OriginalLoopProb.isUnknown() && UseEpilogRemainder) {
// The original loop's first iteration always happens. Compute the
// probability of the original loop executing Count-1 iterations after that
// to complete the first iteration of the unrolled loop.
BranchProbability ProbOne = OriginalLoopProb;
BranchProbability ProbRest = ProbOne.pow(Count - 1);
setBranchProbability(UnrollingLoopGuard, ProbRest,
/*ForFirstTarget=*/false);
}
PreHeaderBR->eraseFromParent();
if (DT) {
if (UseEpilogRemainder)
Expand All @@ -855,9 +917,10 @@ bool llvm::UnrollRuntimeLoopRemainder(
// iterations. This function adds the appropriate CFG connections.
BasicBlock *InsertBot = UseEpilogRemainder ? LatchExit : PrologExit;
BasicBlock *InsertTop = UseEpilogRemainder ? EpilogPreHeader : PrologPreHeader;
Loop *remainderLoop = CloneLoopBlocks(
L, ModVal, UseEpilogRemainder, UnrollRemainder, InsertTop, InsertBot,
NewPreHeader, NewBlocks, LoopBlocks, VMap, DT, LI, Count);
Loop *remainderLoop =
CloneLoopBlocks(L, ModVal, UseEpilogRemainder, UnrollRemainder, InsertTop,
InsertBot, NewPreHeader, NewBlocks, LoopBlocks, VMap, DT,
LI, Count, OriginalTripCount, OriginalLoopProb);

// Insert the cloned blocks into the function.
F->splice(InsertBot->getIterator(), F, NewBlocks[0]->getIterator(), F->end());
Expand Down Expand Up @@ -956,7 +1019,8 @@ bool llvm::UnrollRuntimeLoopRemainder(
// Connect the epilog code to the original loop and update the
// PHI functions.
ConnectEpilog(L, ModVal, NewExit, LatchExit, PreHeader, EpilogPreHeader,
NewPreHeader, VMap, DT, LI, PreserveLCSSA, *SE, Count, *AC);
NewPreHeader, VMap, DT, LI, PreserveLCSSA, *SE, Count, *AC,
OriginalLoopProb);

// Update counter in loop for unrolling.
// Use an incrementing IV. Pre-incr/post-incr is backedge/trip count.
Expand Down
43 changes: 43 additions & 0 deletions llvm/lib/Transforms/Utils/LoopUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -972,6 +972,49 @@ bool llvm::setLoopEstimatedTripCount(
return true;
}

// Read the probability of taking the backedge of \p L from the weights on the
// latch branch found by getExpectedExitLoopLatchBranch. If no such branch is
// found, the result is the unknown probability.
BranchProbability llvm::getLoopProbability(Loop *L) {
  if (BranchInst *ExitingLatchBr = getExpectedExitLoopLatchBranch(L)) {
    // The backedge is whichever successor leads back to the loop header.
    const bool BackedgeIsFirst =
        ExitingLatchBr->getSuccessor(0) == L->getHeader();
    return getBranchProbability(ExitingLatchBr, BackedgeIsFirst);
  }
  return BranchProbability::getUnknown();
}

// Record \p P as the probability of taking the backedge of \p L by setting
// the weights on the latch branch found by getExpectedExitLoopLatchBranch.
// Returns false if no such branch is found; otherwise forwards the result of
// setBranchProbability.
bool llvm::setLoopProbability(Loop *L, BranchProbability P) {
  if (BranchInst *ExitingLatchBr = getExpectedExitLoopLatchBranch(L)) {
    // The backedge is whichever successor leads back to the loop header.
    const bool BackedgeIsFirst =
        ExitingLatchBr->getSuccessor(0) == L->getHeader();
    return setBranchProbability(ExitingLatchBr, P, BackedgeIsFirst);
  }
  return false;
}

// Derive, from the branch weights on \p B, the probability of control flowing
// to its first successor (or to its second successor if \p ForFirstTarget is
// false).
//
// The unknown probability is returned when:
// - \p B does not have exactly two successors,
// - no pair of branch weights can be extracted from \p B, or
// - both weights are zero, so no ratio can be formed from them.
BranchProbability llvm::getBranchProbability(BranchInst *B,
                                             bool ForFirstTarget) {
  if (B->getNumSuccessors() != 2)
    return BranchProbability::getUnknown();
  uint64_t Weight0, Weight1;
  if (!extractBranchWeights(*B, Weight0, Weight1))
    return BranchProbability::getUnknown();
  // All-zero weights would give a zero denominator; treat them as carrying no
  // usable profile information rather than forming 0/0.
  const uint64_t Denominator = Weight0 + Weight1;
  if (Denominator == 0)
    return BranchProbability::getUnknown();
  if (!ForFirstTarget)
    std::swap(Weight0, Weight1);
  return BranchProbability::getBranchProbability(Weight0, Denominator);
}

// Attach branch weight metadata to \p B so that \p P is the weight of its
// first successor and the complement of \p P is the weight of its second
// successor, or the other way around if \p ForFirstTarget is false. Returns
// false, leaving \p B untouched, unless \p B has exactly two successors.
bool llvm::setBranchProbability(BranchInst *B, BranchProbability P,
                                bool ForFirstTarget) {
  if (B->getNumSuccessors() != 2)
    return false;
  const BranchProbability Complement = P.getCompl();
  // Decide which successor receives P and which receives its complement.
  const BranchProbability &FirstProb = ForFirstTarget ? P : Complement;
  const BranchProbability &SecondProb = ForFirstTarget ? Complement : P;
  MDBuilder MDB(B->getContext());
  auto *Weights = MDB.createBranchWeights(FirstProb.getNumerator(),
                                          SecondProb.getNumerator());
  B->setMetadata(LLVMContext::MD_prof, Weights);
  return true;
}

bool llvm::hasIterationCountInvariantInParent(Loop *InnerLoop,
ScalarEvolution &SE) {
Loop *OuterL = InnerLoop->getParentLoop();
Expand Down
Loading