4040#include " llvm/Transforms/Utils/LoopUtils.h"
4141#include " llvm/Transforms/Utils/ScalarEvolutionExpander.h"
4242#include " llvm/Transforms/Utils/UnrollLoop.h"
43+ #include < cmath>
4344
4445using namespace llvm ;
4546
@@ -195,6 +196,20 @@ static void ConnectProlog(Loop *L, Value *BECount, unsigned Count,
195196 }
196197}
197198
199+ // / Assume, due to our position in the remainder loop or its guard, anywhere
200+ // / from 0 to \p N more iterations can possibly execute. Among such cases in
201+ // / the original loop (with loop probability \p OriginalLoopProb), what is the
202+ // / probability of executing at least one more iteration?
203+ static double probOfNextInRemainder (double OriginalLoopProb, unsigned N) {
204+ // Each of these variables holds the original loop's probability that the
205+ // number of iterations it will execute is some m in the specified range.
206+ double ProbOne = OriginalLoopProb; // 1 <= m
207+ double ProbTooMany = pow (ProbOne, N + 1 ); // N + 1 <= m
208+ double ProbNotTooMany = 1 - ProbTooMany; // 0 <= m <= N
209+ double ProbOneNotTooMany = ProbOne - ProbTooMany; // 1 <= m <= N
210+ return ProbOneNotTooMany / ProbNotTooMany;
211+ }
212+
198213// / Connect the unrolling epilog code to the original loop.
199214// / The unrolling epilog code contains code to execute the
200215// / 'extra' iterations if the run-time trip count modulo the
@@ -221,7 +236,8 @@ static void ConnectEpilog(Loop *L, Value *ModVal, BasicBlock *NewExit,
221236 BasicBlock *EpilogPreHeader, BasicBlock *NewPreHeader,
222237 ValueToValueMapTy &VMap, DominatorTree *DT,
223238 LoopInfo *LI, bool PreserveLCSSA, ScalarEvolution &SE,
224- unsigned Count, AssumptionCache &AC) {
239+ unsigned Count, AssumptionCache &AC,
240+ std::optional<double > OriginalLoopProb) {
225241 BasicBlock *Latch = L->getLoopLatch ();
226242 assert (Latch && " Loop must have a latch" );
227243 BasicBlock *EpilogLatch = cast<BasicBlock>(VMap[Latch]);
@@ -332,12 +348,18 @@ static void ConnectEpilog(Loop *L, Value *ModVal, BasicBlock *NewExit,
332348 PreserveLCSSA);
333349 // Add the branch to the exit block (around the epilog loop)
334350 MDNode *BranchWeights = nullptr ;
335- if (hasBranchWeightMD (*Latch->getTerminator ())) {
351+ if (!OriginalLoopProb && hasBranchWeightMD (*Latch->getTerminator ())) {
336352 // Assume equal distribution in interval [0, Count).
337353 MDBuilder MDB (B.getContext ());
338354 BranchWeights = MDB.createBranchWeights (1 , Count - 1 );
339355 }
340- B.CreateCondBr (BrLoopExit, EpilogPreHeader, Exit, BranchWeights);
356+ BranchInst *RemainderLoopGuard =
357+ B.CreateCondBr (BrLoopExit, EpilogPreHeader, Exit, BranchWeights);
358+ if (OriginalLoopProb) {
359+ setBranchProbability (RemainderLoopGuard,
360+ probOfNextInRemainder (*OriginalLoopProb, Count - 1 ),
361+ /* ForFirstTarget=*/ true );
362+ }
341363 InsertPt->eraseFromParent ();
342364 if (DT) {
343365 auto *NewDom = DT->findNearestCommonDominator (Exit, NewExit);
@@ -357,14 +379,15 @@ static void ConnectEpilog(Loop *L, Value *ModVal, BasicBlock *NewExit,
357379// / The cloned blocks should be inserted between InsertTop and InsertBot.
358380// / InsertTop should be new preheader, InsertBot new loop exit.
359381// / Returns the new cloned loop that is created.
360- static Loop *
361- CloneLoopBlocks (Loop *L, Value *NewIter, const bool UseEpilogRemainder,
362- const bool UnrollRemainder,
363- BasicBlock *InsertTop,
364- BasicBlock *InsertBot, BasicBlock *Preheader,
382+ static Loop *CloneLoopBlocks (Loop *L, Value *NewIter,
383+ const bool UseEpilogRemainder,
384+ const bool UnrollRemainder, BasicBlock *InsertTop,
385+ BasicBlock *InsertBot, BasicBlock *Preheader,
365386 std::vector<BasicBlock *> &NewBlocks,
366387 LoopBlocksDFS &LoopBlocks, ValueToValueMapTy &VMap,
367- DominatorTree *DT, LoopInfo *LI, unsigned Count) {
388+ DominatorTree *DT, LoopInfo *LI, unsigned Count,
389+ std::optional<unsigned > OriginalTripCount,
390+ std::optional<double > OriginalLoopProb) {
368391 StringRef suffix = UseEpilogRemainder ? " epil" : " prol" ;
369392 BasicBlock *Header = L->getHeader ();
370393 BasicBlock *Latch = L->getLoopLatch ();
@@ -419,7 +442,8 @@ CloneLoopBlocks(Loop *L, Value *NewIter, const bool UseEpilogRemainder,
419442 Builder.CreateAdd (NewIdx, One, NewIdx->getName () + " .next" );
420443 Value *IdxCmp = Builder.CreateICmpNE (IdxNext, NewIter, NewIdx->getName () + " .cmp" );
421444 MDNode *BranchWeights = nullptr ;
422- if (hasBranchWeightMD (*LatchBR)) {
445+ if (!(OriginalLoopProb && UseEpilogRemainder) &&
446+ hasBranchWeightMD (*LatchBR)) {
423447 uint32_t ExitWeight;
424448 uint32_t BackEdgeWeight;
425449 if (Count >= 3 ) {
@@ -437,7 +461,25 @@ CloneLoopBlocks(Loop *L, Value *NewIter, const bool UseEpilogRemainder,
437461 MDBuilder MDB (Builder.getContext ());
438462 BranchWeights = MDB.createBranchWeights (BackEdgeWeight, ExitWeight);
439463 }
440- Builder.CreateCondBr (IdxCmp, FirstLoopBB, InsertBot, BranchWeights);
464+ BranchInst *RemainderLoopLatch =
465+ Builder.CreateCondBr (IdxCmp, FirstLoopBB, InsertBot, BranchWeights);
466+ if (OriginalLoopProb && UseEpilogRemainder) {
467+ // Compute the total frequency of the original loop body from the
468+ // remainder iterations. Once we've reached them, the first of them
469+ // always executes, so it's frequency and probability are 1.
470+ double FreqRemIters = 1 ;
471+ if (Count > 2 ) {
472+ double ProbReaching = 1 ;
473+ for (unsigned N = Count - 2 ; N >= 1 ; --N) {
474+ ProbReaching *= probOfNextInRemainder (*OriginalLoopProb, N);
475+ FreqRemIters += ProbReaching;
476+ }
477+ }
478+ // Solve for the loop probability that would produce that frequency.
479+ // Sum(i=0..inf)(Prob^i) = 1/(1-Prob) = FreqRemIters.
480+ double Prob = 1 - 1 / FreqRemIters;
481+ setBranchProbability (RemainderLoopLatch, Prob, /* ForFirstTarget=*/ true );
482+ }
441483 NewIdx->addIncoming (Zero, InsertTop);
442484 NewIdx->addIncoming (IdxNext, NewBB);
443485 LatchBR->eraseFromParent ();
@@ -469,6 +511,8 @@ CloneLoopBlocks(Loop *L, Value *NewIter, const bool UseEpilogRemainder,
469511
470512 std::optional<MDNode *> NewLoopID = makeFollowupLoopID (
471513 LoopID, {LLVMLoopUnrollFollowupAll, LLVMLoopUnrollFollowupRemainder});
514+ if (OriginalTripCount && UseEpilogRemainder)
515+ setLoopEstimatedTripCount (NewLoop, *OriginalTripCount % Count);
472516 if (NewLoopID) {
473517 NewLoop->setLoopID (*NewLoopID);
474518
@@ -603,7 +647,8 @@ bool llvm::UnrollRuntimeLoopRemainder(
603647 LoopInfo *LI, ScalarEvolution *SE, DominatorTree *DT, AssumptionCache *AC,
604648 const TargetTransformInfo *TTI, bool PreserveLCSSA,
605649 unsigned SCEVExpansionBudget, bool RuntimeUnrollMultiExit,
606- Loop **ResultLoop) {
650+ Loop **ResultLoop, std::optional<unsigned > OriginalTripCount,
651+ std::optional<double > OriginalLoopProb) {
607652 LLVM_DEBUG (dbgs () << " Trying runtime unrolling on Loop: \n " );
608653 LLVM_DEBUG (L->dump ());
609654 LLVM_DEBUG (UseEpilogRemainder ? dbgs () << " Using epilog remainder.\n "
@@ -823,12 +868,23 @@ bool llvm::UnrollRuntimeLoopRemainder(
823868 BasicBlock *UnrollingLoop = UseEpilogRemainder ? NewPreHeader : PrologExit;
824869 // Branch to either remainder (extra iterations) loop or unrolling loop.
825870 MDNode *BranchWeights = nullptr ;
826- if (hasBranchWeightMD (*Latch->getTerminator ())) {
871+ if (!(OriginalLoopProb && UseEpilogRemainder) &&
872+ hasBranchWeightMD (*Latch->getTerminator ())) {
827873 // Assume loop is nearly always entered.
828874 MDBuilder MDB (B.getContext ());
829875 BranchWeights = MDB.createBranchWeights (EpilogHeaderWeights);
830876 }
831- B.CreateCondBr (BranchVal, RemainderLoop, UnrollingLoop, BranchWeights);
877+ BranchInst *UnrollingLoopGuard =
878+ B.CreateCondBr (BranchVal, RemainderLoop, UnrollingLoop, BranchWeights);
879+ if (OriginalLoopProb && UseEpilogRemainder) {
880+ // The original loop's first iteration always happens. Compute the
881+ // probability of the original loop executing Count-1 iterations after that
882+ // to complete the first iteration of the unrolled loop.
883+ double ProbOne = *OriginalLoopProb;
884+ double ProbRest = pow (ProbOne, Count - 1 );
885+ setBranchProbability (UnrollingLoopGuard, ProbRest,
886+ /* ForFirstTarget=*/ false );
887+ }
832888 PreHeaderBR->eraseFromParent ();
833889 if (DT) {
834890 if (UseEpilogRemainder)
@@ -855,9 +911,10 @@ bool llvm::UnrollRuntimeLoopRemainder(
855911 // iterations. This function adds the appropriate CFG connections.
856912 BasicBlock *InsertBot = UseEpilogRemainder ? LatchExit : PrologExit;
857913 BasicBlock *InsertTop = UseEpilogRemainder ? EpilogPreHeader : PrologPreHeader;
858- Loop *remainderLoop = CloneLoopBlocks (
859- L, ModVal, UseEpilogRemainder, UnrollRemainder, InsertTop, InsertBot,
860- NewPreHeader, NewBlocks, LoopBlocks, VMap, DT, LI, Count);
914+ Loop *remainderLoop =
915+ CloneLoopBlocks (L, ModVal, UseEpilogRemainder, UnrollRemainder, InsertTop,
916+ InsertBot, NewPreHeader, NewBlocks, LoopBlocks, VMap, DT,
917+ LI, Count, OriginalTripCount, OriginalLoopProb);
861918
862919 // Insert the cloned blocks into the function.
863920 F->splice (InsertBot->getIterator (), F, NewBlocks[0 ]->getIterator (), F->end ());
@@ -956,7 +1013,8 @@ bool llvm::UnrollRuntimeLoopRemainder(
9561013 // Connect the epilog code to the original loop and update the
9571014 // PHI functions.
9581015 ConnectEpilog (L, ModVal, NewExit, LatchExit, PreHeader, EpilogPreHeader,
959- NewPreHeader, VMap, DT, LI, PreserveLCSSA, *SE, Count, *AC);
1016+ NewPreHeader, VMap, DT, LI, PreserveLCSSA, *SE, Count, *AC,
1017+ OriginalLoopProb);
9601018
9611019 // Update counter in loop for unrolling.
9621020 // Use an incrementing IV. Pre-incr/post-incr is backedge/trip count.
0 commit comments