4040#include  " llvm/Transforms/Utils/LoopUtils.h" 
4141#include  " llvm/Transforms/Utils/ScalarEvolutionExpander.h" 
4242#include  " llvm/Transforms/Utils/UnrollLoop.h" 
43+ #include  < cmath> 
4344
4445using  namespace  llvm ; 
4546
@@ -195,6 +196,21 @@ static void ConnectProlog(Loop *L, Value *BECount, unsigned Count,
195196  }
196197}
197198
199+ // / Assume, due to our position in the remainder loop or its guard, anywhere
200+ // / from 0 to \p N more iterations can possibly execute.  Among such cases in
201+ // / the original loop (with loop probability \p OriginalLoopProb), what is the
202+ // / probability of executing at least one more iteration?
203+ static  BranchProbability
204+ probOfNextInRemainder (BranchProbability OriginalLoopProb, unsigned  N) {
205+   //  Each of these variables holds the original loop's probability that the
206+   //  number of iterations it will execute is some m in the specified range.
207+   BranchProbability ProbOne = OriginalLoopProb;                //  1 <= m
208+   BranchProbability ProbTooMany = ProbOne.pow (N + 1 );          //  N + 1 <= m
209+   BranchProbability ProbNotTooMany = ProbTooMany.getCompl ();   //  0 <= m <= N
210+   BranchProbability ProbOneNotTooMany = ProbOne - ProbTooMany; //  1 <= m <= N
211+   return  ProbOneNotTooMany / ProbNotTooMany;
212+ }
213+ 
198214// / Connect the unrolling epilog code to the original loop.
199215// / The unrolling epilog code contains code to execute the
200216// / 'extra' iterations if the run-time trip count modulo the
@@ -221,7 +237,8 @@ static void ConnectEpilog(Loop *L, Value *ModVal, BasicBlock *NewExit,
221237                          BasicBlock *EpilogPreHeader, BasicBlock *NewPreHeader,
222238                          ValueToValueMapTy &VMap, DominatorTree *DT,
223239                          LoopInfo *LI, bool  PreserveLCSSA, ScalarEvolution &SE,
224-                           unsigned  Count, AssumptionCache &AC) {
240+                           unsigned  Count, AssumptionCache &AC,
241+                           BranchProbability OriginalLoopProb) {
225242  BasicBlock *Latch = L->getLoopLatch ();
226243  assert (Latch && " Loop must have a latch"  );
227244  BasicBlock *EpilogLatch = cast<BasicBlock>(VMap[Latch]);
@@ -332,12 +349,19 @@ static void ConnectEpilog(Loop *L, Value *ModVal, BasicBlock *NewExit,
332349                         PreserveLCSSA);
333350  //  Add the branch to the exit block (around the epilog loop)
334351  MDNode *BranchWeights = nullptr ;
335-   if  (hasBranchWeightMD (*Latch->getTerminator ())) {
352+   if  (OriginalLoopProb.isUnknown () &&
353+       hasBranchWeightMD (*Latch->getTerminator ())) {
336354    //  Assume equal distribution in interval [0, Count).
337355    MDBuilder MDB (B.getContext ());
338356    BranchWeights = MDB.createBranchWeights (1 , Count - 1 );
339357  }
340-   B.CreateCondBr (BrLoopExit, EpilogPreHeader, Exit, BranchWeights);
358+   BranchInst *RemainderLoopGuard =
359+       B.CreateCondBr (BrLoopExit, EpilogPreHeader, Exit, BranchWeights);
360+   if  (!OriginalLoopProb.isUnknown ()) {
361+     setBranchProbability (RemainderLoopGuard,
362+                          probOfNextInRemainder (OriginalLoopProb, Count - 1 ),
363+                          /* ForFirstTarget=*/ true );
364+   }
341365  InsertPt->eraseFromParent ();
342366  if  (DT) {
343367    auto  *NewDom = DT->findNearestCommonDominator (Exit, NewExit);
@@ -357,14 +381,15 @@ static void ConnectEpilog(Loop *L, Value *ModVal, BasicBlock *NewExit,
357381// / The cloned blocks should be inserted between InsertTop and InsertBot.
358382// / InsertTop should be new preheader, InsertBot new loop exit.
359383// / Returns the new cloned loop that is created.
360- static  Loop *
361- CloneLoopBlocks (Loop *L, Value *NewIter, const  bool  UseEpilogRemainder,
362-                 const  bool  UnrollRemainder,
363-                 BasicBlock *InsertTop,
364-                 BasicBlock *InsertBot, BasicBlock *Preheader,
384+ static  Loop *CloneLoopBlocks (Loop *L, Value *NewIter,
385+                              const  bool  UseEpilogRemainder,
386+                              const  bool  UnrollRemainder, BasicBlock *InsertTop,
387+                              BasicBlock *InsertBot, BasicBlock *Preheader,
365388                             std::vector<BasicBlock *> &NewBlocks,
366389                             LoopBlocksDFS &LoopBlocks, ValueToValueMapTy &VMap,
367-                              DominatorTree *DT, LoopInfo *LI, unsigned  Count) {
390+                              DominatorTree *DT, LoopInfo *LI, unsigned  Count,
391+                              std::optional<unsigned > OriginalTripCount,
392+                              BranchProbability OriginalLoopProb) {
368393  StringRef suffix = UseEpilogRemainder ? " epil"   : " prol"  ;
369394  BasicBlock *Header = L->getHeader ();
370395  BasicBlock *Latch = L->getLoopLatch ();
@@ -419,7 +444,8 @@ CloneLoopBlocks(Loop *L, Value *NewIter, const bool UseEpilogRemainder,
419444          Builder.CreateAdd (NewIdx, One, NewIdx->getName () + " .next"  );
420445      Value *IdxCmp = Builder.CreateICmpNE (IdxNext, NewIter, NewIdx->getName () + " .cmp"  );
421446      MDNode *BranchWeights = nullptr ;
422-       if  (hasBranchWeightMD (*LatchBR)) {
447+       if  ((OriginalLoopProb.isUnknown () || !UseEpilogRemainder) &&
448+           hasBranchWeightMD (*LatchBR)) {
423449        uint32_t  ExitWeight;
424450        uint32_t  BackEdgeWeight;
425451        if  (Count >= 3 ) {
@@ -437,7 +463,29 @@ CloneLoopBlocks(Loop *L, Value *NewIter, const bool UseEpilogRemainder,
437463        MDBuilder MDB (Builder.getContext ());
438464        BranchWeights = MDB.createBranchWeights (BackEdgeWeight, ExitWeight);
439465      }
440-       Builder.CreateCondBr (IdxCmp, FirstLoopBB, InsertBot, BranchWeights);
466+       BranchInst *RemainderLoopLatch =
467+           Builder.CreateCondBr (IdxCmp, FirstLoopBB, InsertBot, BranchWeights);
468+       if  (!OriginalLoopProb.isUnknown () && UseEpilogRemainder) {
469+         //  Compute the total frequency of the original loop body from the
470+         //  remainder iterations.  Once we've reached them, the first of them
471+         //  always executes, so its frequency and probability are 1.
472+         double  FreqRemIters = 1 ;
473+         if  (Count > 2 ) {
474+           BranchProbability ProbReaching = BranchProbability::getOne ();
475+           for  (unsigned  N = Count - 2 ; N >= 1 ; --N) {
476+             ProbReaching *= probOfNextInRemainder (OriginalLoopProb, N);
477+             FreqRemIters += double (ProbReaching.getNumerator ()) /
478+                             ProbReaching.getDenominator ();
479+           }
480+         }
481+         //  Solve for the loop probability that would produce that frequency.
482+         //  Sum(i=0..inf)(Prob^i) = 1/(1-Prob) = FreqRemIters.
483+         double  ProbDouble = 1  - 1  / FreqRemIters;
484+         BranchProbability Prob = BranchProbability::getBranchProbability (
485+             std::round (ProbDouble * BranchProbability::getDenominator ()),
486+             BranchProbability::getDenominator ());
487+         setBranchProbability (RemainderLoopLatch, Prob, /* ForFirstTarget=*/ true );
488+       }
441489      NewIdx->addIncoming (Zero, InsertTop);
442490      NewIdx->addIncoming (IdxNext, NewBB);
443491      LatchBR->eraseFromParent ();
@@ -461,6 +509,9 @@ CloneLoopBlocks(Loop *L, Value *NewIter, const bool UseEpilogRemainder,
461509  Loop *NewLoop = NewLoops[L];
462510  assert (NewLoop && " L should have been cloned"  );
463511
512+   if  (OriginalTripCount && UseEpilogRemainder)
513+     setLoopEstimatedTripCount (NewLoop, *OriginalTripCount % Count);
514+ 
464515  //  Add unroll disable metadata to disable future unrolling for this loop.
465516  if  (!UnrollRemainder)
466517    NewLoop->setLoopAlreadyUnrolled ();
@@ -588,7 +639,8 @@ bool llvm::UnrollRuntimeLoopRemainder(
588639    LoopInfo *LI, ScalarEvolution *SE, DominatorTree *DT, AssumptionCache *AC,
589640    const  TargetTransformInfo *TTI, bool  PreserveLCSSA,
590641    unsigned  SCEVExpansionBudget, bool  RuntimeUnrollMultiExit,
591-     Loop **ResultLoop) {
642+     Loop **ResultLoop, std::optional<unsigned > OriginalTripCount,
643+     BranchProbability OriginalLoopProb) {
592644  LLVM_DEBUG (dbgs () << " Trying runtime unrolling on Loop: \n "  );
593645  LLVM_DEBUG (L->dump ());
594646  LLVM_DEBUG (UseEpilogRemainder ? dbgs () << " Using epilog remainder.\n " 
@@ -808,12 +860,23 @@ bool llvm::UnrollRuntimeLoopRemainder(
808860  BasicBlock *UnrollingLoop = UseEpilogRemainder ? NewPreHeader : PrologExit;
809861  //  Branch to either remainder (extra iterations) loop or unrolling loop.
810862  MDNode *BranchWeights = nullptr ;
811-   if  (hasBranchWeightMD (*Latch->getTerminator ())) {
863+   if  ((OriginalLoopProb.isUnknown () || !UseEpilogRemainder) &&
864+       hasBranchWeightMD (*Latch->getTerminator ())) {
812865    //  Assume loop is nearly always entered.
813866    MDBuilder MDB (B.getContext ());
814867    BranchWeights = MDB.createBranchWeights (EpilogHeaderWeights);
815868  }
816-   B.CreateCondBr (BranchVal, RemainderLoop, UnrollingLoop, BranchWeights);
869+   BranchInst *UnrollingLoopGuard =
870+       B.CreateCondBr (BranchVal, RemainderLoop, UnrollingLoop, BranchWeights);
871+   if  (!OriginalLoopProb.isUnknown () && UseEpilogRemainder) {
872+     //  The original loop's first iteration always happens.  Compute the
873+     //  probability of the original loop executing Count-1 iterations after that
874+     //  to complete the first iteration of the unrolled loop.
875+     BranchProbability ProbOne = OriginalLoopProb;
876+     BranchProbability ProbRest = ProbOne.pow (Count - 1 );
877+     setBranchProbability (UnrollingLoopGuard, ProbRest,
878+                          /* ForFirstTarget=*/ false );
879+   }
817880  PreHeaderBR->eraseFromParent ();
818881  if  (DT) {
819882    if  (UseEpilogRemainder)
@@ -840,9 +903,10 @@ bool llvm::UnrollRuntimeLoopRemainder(
840903  //  iterations. This function adds the appropriate CFG connections.
841904  BasicBlock *InsertBot = UseEpilogRemainder ? LatchExit : PrologExit;
842905  BasicBlock *InsertTop = UseEpilogRemainder ? EpilogPreHeader : PrologPreHeader;
843-   Loop *remainderLoop = CloneLoopBlocks (
844-       L, ModVal, UseEpilogRemainder, UnrollRemainder, InsertTop, InsertBot,
845-       NewPreHeader, NewBlocks, LoopBlocks, VMap, DT, LI, Count);
906+   Loop *remainderLoop =
907+       CloneLoopBlocks (L, ModVal, UseEpilogRemainder, UnrollRemainder, InsertTop,
908+                       InsertBot, NewPreHeader, NewBlocks, LoopBlocks, VMap, DT,
909+                       LI, Count, OriginalTripCount, OriginalLoopProb);
846910
847911  //  Insert the cloned blocks into the function.
848912  F->splice (InsertBot->getIterator (), F, NewBlocks[0 ]->getIterator (), F->end ());
@@ -941,7 +1005,8 @@ bool llvm::UnrollRuntimeLoopRemainder(
9411005    //  Connect the epilog code to the original loop and update the
9421006    //  PHI functions.
9431007    ConnectEpilog (L, ModVal, NewExit, LatchExit, PreHeader, EpilogPreHeader,
944-                   NewPreHeader, VMap, DT, LI, PreserveLCSSA, *SE, Count, *AC);
1008+                   NewPreHeader, VMap, DT, LI, PreserveLCSSA, *SE, Count, *AC,
1009+                   OriginalLoopProb);
9451010
9461011    //  Update counter in loop for unrolling.
9471012    //  Use an incrementing IV.  Pre-incr/post-incr is backedge/trip count.
0 commit comments