Skip to content

Commit 5a99593

Browse files
committed
[LoopUnroll] Fix block frequencies for epilogue
As another step in issue #135812, this patch fixes block frequencies for partial loop unrolling with an epilogue remainder loop. It does not fully handle the case when the epilogue loop itself is unrolled. That will be handled in the next patch. For the guard and latch of each of the unrolled loop and epilogue loop, this patch sets branch weights derived directly from the original loop latch branch weights. The total frequency of the original loop body, summed across all its occurrences in the unrolled loop and epilogue loop, is the same as in the original loop. This patch also sets `llvm.loop.estimated_trip_count` for the epilogue loop instead of relying on the epilogue's latch branch weights to imply it. This patch removes the XFAIL directives that PR #157754 added to the test suite.
1 parent 2897e64 commit 5a99593

File tree

10 files changed

+448
-54
lines changed

10 files changed

+448
-54
lines changed

llvm/include/llvm/Transforms/Utils/LoopUtils.h

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,38 @@ LLVM_ABI bool setLoopEstimatedTripCount(
365365
Loop *L, unsigned EstimatedTripCount,
366366
std::optional<unsigned> EstimatedLoopInvocationWeight = std::nullopt);
367367

368+
/// Based on branch weight metadata, return either:
369+
/// - \c std::nullopt if the implementation is unable to handle the loop form
370+
/// of \p L (e.g., \p L must have a latch block that controls the loop exit).
371+
/// - Else, the estimated probability that, at the end of any iteration, the
372+
/// latch of \p L will start another iteration. The result \c P is such that
373+
/// `0 <= P <= 1`, and `1 - P` is the probability of exiting the loop.
374+
std::optional<double> getLoopProbability(Loop *L);
375+
376+
/// Set branch weight metadata for the latch of \p L to indicate that, at the
377+
/// end of any iteration, its estimated probability of starting another
378+
/// iteration is \p P. Return false if the implementation is unable to handle
379+
/// the loop form of \p L (e.g., \p L must have a latch block that controls the
380+
/// loop exit). Otherwise, return true.
381+
bool setLoopProbability(Loop *L, double P);
382+
383+
/// Based on branch weight metadata, return either:
384+
/// - \c std::nullopt if the implementation cannot extract the probability
385+
/// (e.g., \p B must have exactly two target labels, so it must be a
386+
/// conditional branch).
387+
/// - The probability \c P that control flows from \p B to its first target
388+
/// label such that `1 - P` is the probability of control flowing to its
389+
/// second target label, or vice-versa if \p ForFirstTarget is false.
390+
std::optional<double> getBranchProbability(BranchInst *B, bool ForFirstTarget);
391+
392+
/// Set branch weight metadata for \p B to indicate that \p P and `1 - P` are
393+
/// the probabilities of control flowing to its first and second target labels,
394+
/// respectively, or vice-versa if \p ForFirstTarget is false. Return false if
395+
/// the implementation cannot set the probability (e.g., \p B must have exactly
396+
/// two target labels, so it must be a conditional branch). Otherwise, return
397+
/// true.
398+
bool setBranchProbability(BranchInst *B, double P, bool ForFirstTarget);
399+
368400
/// Check inner loop (L) backedge count is known to be invariant on all
369401
/// iterations of its outer loop. If the loop has no parent, this is trivially
370402
/// true.

llvm/include/llvm/Transforms/Utils/UnrollLoop.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,9 @@ LLVM_ABI bool UnrollRuntimeLoopRemainder(
9797
LoopInfo *LI, ScalarEvolution *SE, DominatorTree *DT, AssumptionCache *AC,
9898
const TargetTransformInfo *TTI, bool PreserveLCSSA,
9999
unsigned SCEVExpansionBudget, bool RuntimeUnrollMultiExit,
100-
Loop **ResultLoop = nullptr);
100+
Loop **ResultLoop = nullptr,
101+
std::optional<unsigned> OriginalTripCount = std::nullopt,
102+
std::optional<double> OriginalLoopProb = std::nullopt);
101103

102104
LLVM_ABI LoopUnrollResult UnrollAndJamLoop(
103105
Loop *L, unsigned Count, unsigned TripCount, unsigned TripMultiple,

llvm/lib/Transforms/Utils/LoopUnroll.cpp

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@
6565
#include "llvm/Transforms/Utils/UnrollLoop.h"
6666
#include "llvm/Transforms/Utils/ValueMapper.h"
6767
#include <assert.h>
68+
#include <cmath>
6869
#include <numeric>
6970
#include <type_traits>
7071
#include <vector>
@@ -501,6 +502,7 @@ llvm::UnrollLoop(Loop *L, UnrollLoopOptions ULO, LoopInfo *LI,
501502
const bool MaxOrZero = SE->isBackedgeTakenCountMaxOrZero(L);
502503
std::optional<unsigned> OriginalTripCount =
503504
llvm::getLoopEstimatedTripCount(L);
505+
std::optional<double> OriginalLoopProb = llvm::getLoopProbability(L);
504506

505507
// Effectively "DCE" unrolled iterations that are beyond the max tripcount
506508
// and will never be executed.
@@ -591,11 +593,11 @@ llvm::UnrollLoop(Loop *L, UnrollLoopOptions ULO, LoopInfo *LI,
591593
: isEpilogProfitable(L);
592594

593595
if (ULO.Runtime &&
594-
!UnrollRuntimeLoopRemainder(L, ULO.Count, ULO.AllowExpensiveTripCount,
595-
EpilogProfitability, ULO.UnrollRemainder,
596-
ULO.ForgetAllSCEV, LI, SE, DT, AC, TTI,
597-
PreserveLCSSA, ULO.SCEVExpansionBudget,
598-
ULO.RuntimeUnrollMultiExit, RemainderLoop)) {
596+
!UnrollRuntimeLoopRemainder(
597+
L, ULO.Count, ULO.AllowExpensiveTripCount, EpilogProfitability,
598+
ULO.UnrollRemainder, ULO.ForgetAllSCEV, LI, SE, DT, AC, TTI,
599+
PreserveLCSSA, ULO.SCEVExpansionBudget, ULO.RuntimeUnrollMultiExit,
600+
RemainderLoop, OriginalTripCount, OriginalLoopProb)) {
599601
if (ULO.Force)
600602
ULO.Runtime = false;
601603
else {
@@ -1130,13 +1132,13 @@ llvm::UnrollLoop(Loop *L, UnrollLoopOptions ULO, LoopInfo *LI,
11301132
LI->erase(L);
11311133
// We shouldn't try to use `L` anymore.
11321134
L = nullptr;
1133-
} else if (OriginalTripCount) {
1135+
} else {
11341136
// Update metadata for the loop's branch weights and estimated trip count:
11351137
// - If ULO.Runtime, UnrollRuntimeLoopRemainder sets the guard branch
11361138
// weights, latch branch weights, and estimated trip count of the
11371139
// remainder loop it creates. It also sets the branch weights for the
11381140
// unrolled loop guard it creates. The branch weights for the unrolled
1139-
// loop latch are adjusted below. FIXME: Actually handle ULO.Runtime.
1141+
// loop latch are adjusted below. FIXME: Handle prologue loops.
11401142
// - Otherwise, if unrolled loop iteration latches become unconditional,
11411143
// branch weights are adjusted above. FIXME: Actually handle such
11421144
// unconditional latches.
@@ -1159,10 +1161,17 @@ llvm::UnrollLoop(Loop *L, UnrollLoopOptions ULO, LoopInfo *LI,
11591161
// the unrolled loop as a whole without considering the branch weights for
11601162
// each unrolled iteration's latch within it, we store the new trip count as
11611163
// separate metadata.
1162-
unsigned NewTripCount = *OriginalTripCount / ULO.Count;
1163-
if (!ULO.Runtime && *OriginalTripCount % ULO.Count)
1164-
NewTripCount += 1;
1165-
setLoopEstimatedTripCount(L, NewTripCount);
1164+
if (OriginalLoopProb && ULO.Runtime && EpilogProfitability) {
1165+
// Where p is always the probability of executing at least 1 more
1166+
// iteration, the probability for at least n more iterations is p^n.
1167+
setLoopProbability(L, pow(*OriginalLoopProb, ULO.Count));
1168+
}
1169+
if (OriginalTripCount) {
1170+
unsigned NewTripCount = *OriginalTripCount / ULO.Count;
1171+
if (!ULO.Runtime && *OriginalTripCount % ULO.Count)
1172+
NewTripCount += 1;
1173+
setLoopEstimatedTripCount(L, NewTripCount);
1174+
}
11661175
}
11671176

11681177
// LoopInfo should not be valid, confirm that.

llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp

Lines changed: 76 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
#include "llvm/Transforms/Utils/LoopUtils.h"
4141
#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
4242
#include "llvm/Transforms/Utils/UnrollLoop.h"
43+
#include <cmath>
4344

4445
using namespace llvm;
4546

@@ -195,6 +196,20 @@ static void ConnectProlog(Loop *L, Value *BECount, unsigned Count,
195196
}
196197
}
197198

199+
/// Assume, due to our position in the remainder loop or its guard, anywhere
/// from 0 to \p N more iterations can possibly execute.  Among such cases in
/// the original loop (with loop probability \p OriginalLoopProb), what is the
/// probability of executing at least one more iteration?
///
/// With p = \p OriginalLoopProb, the result is the conditional probability
///   P(1 <= m <= N) / P(0 <= m <= N) = (p - p^(N+1)) / (1 - p^(N+1)).
///
/// If p is 1 (e.g., the latch's exit weight is zero), the denominator is zero
/// and the ratio would be 0/0 = NaN.  In that case, return the limit of the
/// ratio as p approaches 1, which is N / (N + 1), so that the result is always
/// a usable probability and the function stays continuous in p.
static double probOfNextInRemainder(double OriginalLoopProb, unsigned N) {
  // Each of these variables holds the original loop's probability that the
  // number of iterations it will execute is some m in the specified range.
  const double ProbOne = OriginalLoopProb;                // 1 <= m
  const double ProbTooMany = std::pow(ProbOne, N + 1);    // N + 1 <= m
  const double ProbNotTooMany = 1 - ProbTooMany;          // 0 <= m <= N
  const double ProbOneNotTooMany = ProbOne - ProbTooMany; // 1 <= m <= N

  // Conditioning on a zero-probability event is undefined.  Avoid dividing by
  // zero and fall back to the limit of the ratio instead.
  if (ProbNotTooMany <= 0)
    return static_cast<double>(N) / (static_cast<double>(N) + 1);

  return ProbOneNotTooMany / ProbNotTooMany;
}
212+
198213
/// Connect the unrolling epilog code to the original loop.
199214
/// The unrolling epilog code contains code to execute the
200215
/// 'extra' iterations if the run-time trip count modulo the
@@ -221,7 +236,8 @@ static void ConnectEpilog(Loop *L, Value *ModVal, BasicBlock *NewExit,
221236
BasicBlock *EpilogPreHeader, BasicBlock *NewPreHeader,
222237
ValueToValueMapTy &VMap, DominatorTree *DT,
223238
LoopInfo *LI, bool PreserveLCSSA, ScalarEvolution &SE,
224-
unsigned Count, AssumptionCache &AC) {
239+
unsigned Count, AssumptionCache &AC,
240+
std::optional<double> OriginalLoopProb) {
225241
BasicBlock *Latch = L->getLoopLatch();
226242
assert(Latch && "Loop must have a latch");
227243
BasicBlock *EpilogLatch = cast<BasicBlock>(VMap[Latch]);
@@ -332,12 +348,18 @@ static void ConnectEpilog(Loop *L, Value *ModVal, BasicBlock *NewExit,
332348
PreserveLCSSA);
333349
// Add the branch to the exit block (around the epilog loop)
334350
MDNode *BranchWeights = nullptr;
335-
if (hasBranchWeightMD(*Latch->getTerminator())) {
351+
if (!OriginalLoopProb && hasBranchWeightMD(*Latch->getTerminator())) {
336352
// Assume equal distribution in interval [0, Count).
337353
MDBuilder MDB(B.getContext());
338354
BranchWeights = MDB.createBranchWeights(1, Count - 1);
339355
}
340-
B.CreateCondBr(BrLoopExit, EpilogPreHeader, Exit, BranchWeights);
356+
BranchInst *RemainderLoopGuard =
357+
B.CreateCondBr(BrLoopExit, EpilogPreHeader, Exit, BranchWeights);
358+
if (OriginalLoopProb) {
359+
setBranchProbability(RemainderLoopGuard,
360+
probOfNextInRemainder(*OriginalLoopProb, Count - 1),
361+
/*ForFirstTarget=*/true);
362+
}
341363
InsertPt->eraseFromParent();
342364
if (DT) {
343365
auto *NewDom = DT->findNearestCommonDominator(Exit, NewExit);
@@ -357,14 +379,15 @@ static void ConnectEpilog(Loop *L, Value *ModVal, BasicBlock *NewExit,
357379
/// The cloned blocks should be inserted between InsertTop and InsertBot.
358380
/// InsertTop should be new preheader, InsertBot new loop exit.
359381
/// Returns the new cloned loop that is created.
360-
static Loop *
361-
CloneLoopBlocks(Loop *L, Value *NewIter, const bool UseEpilogRemainder,
362-
const bool UnrollRemainder,
363-
BasicBlock *InsertTop,
364-
BasicBlock *InsertBot, BasicBlock *Preheader,
382+
static Loop *CloneLoopBlocks(Loop *L, Value *NewIter,
383+
const bool UseEpilogRemainder,
384+
const bool UnrollRemainder, BasicBlock *InsertTop,
385+
BasicBlock *InsertBot, BasicBlock *Preheader,
365386
std::vector<BasicBlock *> &NewBlocks,
366387
LoopBlocksDFS &LoopBlocks, ValueToValueMapTy &VMap,
367-
DominatorTree *DT, LoopInfo *LI, unsigned Count) {
388+
DominatorTree *DT, LoopInfo *LI, unsigned Count,
389+
std::optional<unsigned> OriginalTripCount,
390+
std::optional<double> OriginalLoopProb) {
368391
StringRef suffix = UseEpilogRemainder ? "epil" : "prol";
369392
BasicBlock *Header = L->getHeader();
370393
BasicBlock *Latch = L->getLoopLatch();
@@ -419,7 +442,8 @@ CloneLoopBlocks(Loop *L, Value *NewIter, const bool UseEpilogRemainder,
419442
Builder.CreateAdd(NewIdx, One, NewIdx->getName() + ".next");
420443
Value *IdxCmp = Builder.CreateICmpNE(IdxNext, NewIter, NewIdx->getName() + ".cmp");
421444
MDNode *BranchWeights = nullptr;
422-
if (hasBranchWeightMD(*LatchBR)) {
445+
if (!(OriginalLoopProb && UseEpilogRemainder) &&
446+
hasBranchWeightMD(*LatchBR)) {
423447
uint32_t ExitWeight;
424448
uint32_t BackEdgeWeight;
425449
if (Count >= 3) {
@@ -437,7 +461,25 @@ CloneLoopBlocks(Loop *L, Value *NewIter, const bool UseEpilogRemainder,
437461
MDBuilder MDB(Builder.getContext());
438462
BranchWeights = MDB.createBranchWeights(BackEdgeWeight, ExitWeight);
439463
}
440-
Builder.CreateCondBr(IdxCmp, FirstLoopBB, InsertBot, BranchWeights);
464+
BranchInst *RemainderLoopLatch =
465+
Builder.CreateCondBr(IdxCmp, FirstLoopBB, InsertBot, BranchWeights);
466+
if (OriginalLoopProb && UseEpilogRemainder) {
467+
// Compute the total frequency of the original loop body from the
468+
// remainder iterations. Once we've reached them, the first of them
469+
// always executes, so it's frequency and probability are 1.
470+
double FreqRemIters = 1;
471+
if (Count > 2) {
472+
double ProbReaching = 1;
473+
for (unsigned N = Count - 2; N >= 1; --N) {
474+
ProbReaching *= probOfNextInRemainder(*OriginalLoopProb, N);
475+
FreqRemIters += ProbReaching;
476+
}
477+
}
478+
// Solve for the loop probability that would produce that frequency.
479+
// Sum(i=0..inf)(Prob^i) = 1/(1-Prob) = FreqRemIters.
480+
double Prob = 1 - 1 / FreqRemIters;
481+
setBranchProbability(RemainderLoopLatch, Prob, /*ForFirstTarget=*/true);
482+
}
441483
NewIdx->addIncoming(Zero, InsertTop);
442484
NewIdx->addIncoming(IdxNext, NewBB);
443485
LatchBR->eraseFromParent();
@@ -469,6 +511,8 @@ CloneLoopBlocks(Loop *L, Value *NewIter, const bool UseEpilogRemainder,
469511

470512
std::optional<MDNode *> NewLoopID = makeFollowupLoopID(
471513
LoopID, {LLVMLoopUnrollFollowupAll, LLVMLoopUnrollFollowupRemainder});
514+
if (OriginalTripCount && UseEpilogRemainder)
515+
setLoopEstimatedTripCount(NewLoop, *OriginalTripCount % Count);
472516
if (NewLoopID) {
473517
NewLoop->setLoopID(*NewLoopID);
474518

@@ -603,7 +647,8 @@ bool llvm::UnrollRuntimeLoopRemainder(
603647
LoopInfo *LI, ScalarEvolution *SE, DominatorTree *DT, AssumptionCache *AC,
604648
const TargetTransformInfo *TTI, bool PreserveLCSSA,
605649
unsigned SCEVExpansionBudget, bool RuntimeUnrollMultiExit,
606-
Loop **ResultLoop) {
650+
Loop **ResultLoop, std::optional<unsigned> OriginalTripCount,
651+
std::optional<double> OriginalLoopProb) {
607652
LLVM_DEBUG(dbgs() << "Trying runtime unrolling on Loop: \n");
608653
LLVM_DEBUG(L->dump());
609654
LLVM_DEBUG(UseEpilogRemainder ? dbgs() << "Using epilog remainder.\n"
@@ -823,12 +868,23 @@ bool llvm::UnrollRuntimeLoopRemainder(
823868
BasicBlock *UnrollingLoop = UseEpilogRemainder ? NewPreHeader : PrologExit;
824869
// Branch to either remainder (extra iterations) loop or unrolling loop.
825870
MDNode *BranchWeights = nullptr;
826-
if (hasBranchWeightMD(*Latch->getTerminator())) {
871+
if (!(OriginalLoopProb && UseEpilogRemainder) &&
872+
hasBranchWeightMD(*Latch->getTerminator())) {
827873
// Assume loop is nearly always entered.
828874
MDBuilder MDB(B.getContext());
829875
BranchWeights = MDB.createBranchWeights(EpilogHeaderWeights);
830876
}
831-
B.CreateCondBr(BranchVal, RemainderLoop, UnrollingLoop, BranchWeights);
877+
BranchInst *UnrollingLoopGuard =
878+
B.CreateCondBr(BranchVal, RemainderLoop, UnrollingLoop, BranchWeights);
879+
if (OriginalLoopProb && UseEpilogRemainder) {
880+
// The original loop's first iteration always happens. Compute the
881+
// probability of the original loop executing Count-1 iterations after that
882+
// to complete the first iteration of the unrolled loop.
883+
double ProbOne = *OriginalLoopProb;
884+
double ProbRest = pow(ProbOne, Count - 1);
885+
setBranchProbability(UnrollingLoopGuard, ProbRest,
886+
/*ForFirstTarget=*/false);
887+
}
832888
PreHeaderBR->eraseFromParent();
833889
if (DT) {
834890
if (UseEpilogRemainder)
@@ -855,9 +911,10 @@ bool llvm::UnrollRuntimeLoopRemainder(
855911
// iterations. This function adds the appropriate CFG connections.
856912
BasicBlock *InsertBot = UseEpilogRemainder ? LatchExit : PrologExit;
857913
BasicBlock *InsertTop = UseEpilogRemainder ? EpilogPreHeader : PrologPreHeader;
858-
Loop *remainderLoop = CloneLoopBlocks(
859-
L, ModVal, UseEpilogRemainder, UnrollRemainder, InsertTop, InsertBot,
860-
NewPreHeader, NewBlocks, LoopBlocks, VMap, DT, LI, Count);
914+
Loop *remainderLoop =
915+
CloneLoopBlocks(L, ModVal, UseEpilogRemainder, UnrollRemainder, InsertTop,
916+
InsertBot, NewPreHeader, NewBlocks, LoopBlocks, VMap, DT,
917+
LI, Count, OriginalTripCount, OriginalLoopProb);
861918

862919
// Insert the cloned blocks into the function.
863920
F->splice(InsertBot->getIterator(), F, NewBlocks[0]->getIterator(), F->end());
@@ -956,7 +1013,8 @@ bool llvm::UnrollRuntimeLoopRemainder(
9561013
// Connect the epilog code to the original loop and update the
9571014
// PHI functions.
9581015
ConnectEpilog(L, ModVal, NewExit, LatchExit, PreHeader, EpilogPreHeader,
959-
NewPreHeader, VMap, DT, LI, PreserveLCSSA, *SE, Count, *AC);
1016+
NewPreHeader, VMap, DT, LI, PreserveLCSSA, *SE, Count, *AC,
1017+
OriginalLoopProb);
9601018

9611019
// Update counter in loop for unrolling.
9621020
// Use an incrementing IV. Pre-incr/post-incr is backedge/trip count.

llvm/lib/Transforms/Utils/LoopUtils.cpp

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
4747
#include "llvm/Transforms/Utils/Local.h"
4848
#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
49+
#include <cmath>
4950

5051
using namespace llvm;
5152
using namespace llvm::PatternMatch;
@@ -972,6 +973,53 @@ bool llvm::setLoopEstimatedTripCount(
972973
return true;
973974
}
974975

976+
/// Return the probability, taken from the latch's branch weights, that the
/// latch of \p L branches back to the loop header.  Return \c std::nullopt if
/// no suitable exiting latch branch is found or if the branch probability
/// cannot be extracted.
std::optional<double> llvm::getLoopProbability(Loop *L) {
  if (BranchInst *ExitingLatchBr = getExpectedExitLoopLatchBranch(L)) {
    // The backedge may be either successor, so tell the helper which one
    // targets the header.
    const bool BackedgeIsFirst =
        ExitingLatchBr->getSuccessor(0) == L->getHeader();
    return getBranchProbability(ExitingLatchBr, BackedgeIsFirst);
  }
  return std::nullopt;
}
983+
984+
/// Set the latch branch weights of \p L so that \p P is the probability of
/// branching back to the loop header.  Return false if no suitable exiting
/// latch branch is found or if the weights cannot be set; otherwise, return
/// true.
bool llvm::setLoopProbability(Loop *L, double P) {
  if (BranchInst *ExitingLatchBr = getExpectedExitLoopLatchBranch(L)) {
    // The backedge may be either successor, so tell the helper which one
    // targets the header.
    const bool BackedgeIsFirst =
        ExitingLatchBr->getSuccessor(0) == L->getHeader();
    return setBranchProbability(ExitingLatchBr, P, BackedgeIsFirst);
  }
  return false;
}
991+
992+
/// Return the probability, according to the branch weights of \p B, that
/// control flows to its first target label (or to its second target label if
/// \p ForFirstTarget is false).
///
/// Return \c std::nullopt if \p B does not have exactly two successors, if its
/// branch weights cannot be extracted, or if both weights are zero.  In the
/// last case no probability is encoded, and dividing would yield NaN.
std::optional<double> llvm::getBranchProbability(BranchInst *B,
                                                 bool ForFirstTarget) {
  if (B->getNumSuccessors() != 2)
    return std::nullopt;
  uint64_t Weight0, Weight1;
  if (!extractBranchWeights(*B, Weight0, Weight1))
    return std::nullopt;

  // Add in floating point so that two large weights cannot wrap around.
  const double Total = double(Weight0) + double(Weight1);
  // Weights of (0, 0) would give 0/0 = NaN, which is not a probability.
  if (Total == 0)
    return std::nullopt;

  return double(ForFirstTarget ? Weight0 : Weight1) / Total;
}
1003+
1004+
/// Set branch weight metadata on \p B so that \p P and `1 - P` are the
/// probabilities of control flowing to its first and second target labels,
/// respectively, or vice-versa if \p ForFirstTarget is false.
///
/// Return false, leaving \p B unmodified, if \p B does not have exactly two
/// successors or if \p P is not a valid probability (NaN or outside
/// `0 <= P <= 1`).  Otherwise, return true.
bool llvm::setBranchProbability(BranchInst *B, double P, bool ForFirstTarget) {
  if (B->getNumSuccessors() != 2)
    return false;

  // Converting NaN or a negative value to an unsigned integer is undefined
  // behavior, so refuse such inputs.  The negated form also catches NaN.
  if (!(P >= 0.0 && P <= 1.0))
    return false;

  // Sum should be some large number so that the weights accurately encode P,
  // but it should not be so large that some branch weights will print as
  // negative in LLVM IR as that makes LLVM tests harder to maintain.
  const uint64_t Sum = 1000000000;
  const uint64_t WeightP = static_cast<uint64_t>(std::round(P * Sum));
  // Derive the complement from Sum so that the two weights always add up to
  // exactly Sum.  Rounding each side independently can produce Sum + 1.
  const uint64_t WeightNotP = Sum - WeightP;

  const uint64_t Weight0 = ForFirstTarget ? WeightP : WeightNotP;
  const uint64_t Weight1 = ForFirstTarget ? WeightNotP : WeightP;

  MDBuilder MDB(B->getContext());
  B->setMetadata(LLVMContext::MD_prof,
                 MDB.createBranchWeights(Weight0, Weight1));
  return true;
}
1022+
9751023
bool llvm::hasIterationCountInvariantInParent(Loop *InnerLoop,
9761024
ScalarEvolution &SE) {
9771025
Loop *OuterL = InnerLoop->getParentLoop();

0 commit comments

Comments
 (0)