40
40
#include " llvm/Transforms/Utils/LoopUtils.h"
41
41
#include " llvm/Transforms/Utils/ScalarEvolutionExpander.h"
42
42
#include " llvm/Transforms/Utils/UnrollLoop.h"
43
+ #include < cmath>
43
44
44
45
using namespace llvm ;
45
46
@@ -195,6 +196,20 @@ static void ConnectProlog(Loop *L, Value *BECount, unsigned Count,
195
196
}
196
197
}
197
198
199
+ // / Assume, due to our position in the remainder loop or its guard, anywhere
200
+ // / from 0 to \p N more iterations can possibly execute. Among such cases in
201
+ // / the original loop (with loop probability \p OriginalLoopProb), what is the
202
+ // / probability of executing at least one more iteration?
203
+ static double probOfNextInRemainder (double OriginalLoopProb, unsigned N) {
204
+ // Each of these variables holds the original loop's probability that the
205
+ // number of iterations it will execute is some m in the specified range.
206
+ double ProbOne = OriginalLoopProb; // 1 <= m
207
+ double ProbTooMany = pow (ProbOne, N + 1 ); // N + 1 <= m
208
+ double ProbNotTooMany = 1 - ProbTooMany; // 0 <= m <= N
209
+ double ProbOneNotTooMany = ProbOne - ProbTooMany; // 1 <= m <= N
210
+ return ProbOneNotTooMany / ProbNotTooMany;
211
+ }
212
+
198
213
// / Connect the unrolling epilog code to the original loop.
199
214
// / The unrolling epilog code contains code to execute the
200
215
// / 'extra' iterations if the run-time trip count modulo the
@@ -221,7 +236,8 @@ static void ConnectEpilog(Loop *L, Value *ModVal, BasicBlock *NewExit,
221
236
BasicBlock *EpilogPreHeader, BasicBlock *NewPreHeader,
222
237
ValueToValueMapTy &VMap, DominatorTree *DT,
223
238
LoopInfo *LI, bool PreserveLCSSA, ScalarEvolution &SE,
224
- unsigned Count, AssumptionCache &AC) {
239
+ unsigned Count, AssumptionCache &AC,
240
+ std::optional<double > OriginalLoopProb) {
225
241
BasicBlock *Latch = L->getLoopLatch ();
226
242
assert (Latch && " Loop must have a latch" );
227
243
BasicBlock *EpilogLatch = cast<BasicBlock>(VMap[Latch]);
@@ -332,12 +348,18 @@ static void ConnectEpilog(Loop *L, Value *ModVal, BasicBlock *NewExit,
332
348
PreserveLCSSA);
333
349
// Add the branch to the exit block (around the epilog loop)
334
350
MDNode *BranchWeights = nullptr ;
335
- if (hasBranchWeightMD (*Latch->getTerminator ())) {
351
+ if (!OriginalLoopProb && hasBranchWeightMD (*Latch->getTerminator ())) {
336
352
// Assume equal distribution in interval [0, Count).
337
353
MDBuilder MDB (B.getContext ());
338
354
BranchWeights = MDB.createBranchWeights (1 , Count - 1 );
339
355
}
340
- B.CreateCondBr (BrLoopExit, EpilogPreHeader, Exit, BranchWeights);
356
+ BranchInst *RemainderLoopGuard =
357
+ B.CreateCondBr (BrLoopExit, EpilogPreHeader, Exit, BranchWeights);
358
+ if (OriginalLoopProb) {
359
+ setBranchProbability (RemainderLoopGuard,
360
+ probOfNextInRemainder (*OriginalLoopProb, Count - 1 ),
361
+ /* ForFirstTarget=*/ true );
362
+ }
341
363
InsertPt->eraseFromParent ();
342
364
if (DT) {
343
365
auto *NewDom = DT->findNearestCommonDominator (Exit, NewExit);
@@ -357,14 +379,15 @@ static void ConnectEpilog(Loop *L, Value *ModVal, BasicBlock *NewExit,
357
379
// / The cloned blocks should be inserted between InsertTop and InsertBot.
358
380
// / InsertTop should be new preheader, InsertBot new loop exit.
359
381
// / Returns the new cloned loop that is created.
360
- static Loop *
361
- CloneLoopBlocks (Loop *L, Value *NewIter, const bool UseEpilogRemainder,
362
- const bool UnrollRemainder,
363
- BasicBlock *InsertTop,
364
- BasicBlock *InsertBot, BasicBlock *Preheader,
382
+ static Loop *CloneLoopBlocks (Loop *L, Value *NewIter,
383
+ const bool UseEpilogRemainder,
384
+ const bool UnrollRemainder, BasicBlock *InsertTop,
385
+ BasicBlock *InsertBot, BasicBlock *Preheader,
365
386
std::vector<BasicBlock *> &NewBlocks,
366
387
LoopBlocksDFS &LoopBlocks, ValueToValueMapTy &VMap,
367
- DominatorTree *DT, LoopInfo *LI, unsigned Count) {
388
+ DominatorTree *DT, LoopInfo *LI, unsigned Count,
389
+ std::optional<unsigned > OriginalTripCount,
390
+ std::optional<double > OriginalLoopProb) {
368
391
StringRef suffix = UseEpilogRemainder ? " epil" : " prol" ;
369
392
BasicBlock *Header = L->getHeader ();
370
393
BasicBlock *Latch = L->getLoopLatch ();
@@ -419,7 +442,8 @@ CloneLoopBlocks(Loop *L, Value *NewIter, const bool UseEpilogRemainder,
419
442
Builder.CreateAdd (NewIdx, One, NewIdx->getName () + " .next" );
420
443
Value *IdxCmp = Builder.CreateICmpNE (IdxNext, NewIter, NewIdx->getName () + " .cmp" );
421
444
MDNode *BranchWeights = nullptr ;
422
- if (hasBranchWeightMD (*LatchBR)) {
445
+ if (!(OriginalLoopProb && UseEpilogRemainder) &&
446
+ hasBranchWeightMD (*LatchBR)) {
423
447
uint32_t ExitWeight;
424
448
uint32_t BackEdgeWeight;
425
449
if (Count >= 3 ) {
@@ -437,7 +461,25 @@ CloneLoopBlocks(Loop *L, Value *NewIter, const bool UseEpilogRemainder,
437
461
MDBuilder MDB (Builder.getContext ());
438
462
BranchWeights = MDB.createBranchWeights (BackEdgeWeight, ExitWeight);
439
463
}
440
- Builder.CreateCondBr (IdxCmp, FirstLoopBB, InsertBot, BranchWeights);
464
+ BranchInst *RemainderLoopLatch =
465
+ Builder.CreateCondBr (IdxCmp, FirstLoopBB, InsertBot, BranchWeights);
466
+ if (OriginalLoopProb && UseEpilogRemainder) {
467
+ // Compute the total frequency of the original loop body from the
468
+ // remainder iterations. Once we've reached them, the first of them
469
+ // always executes, so it's frequency and probability are 1.
470
+ double FreqRemIters = 1 ;
471
+ if (Count > 2 ) {
472
+ double ProbReaching = 1 ;
473
+ for (unsigned N = Count - 2 ; N >= 1 ; --N) {
474
+ ProbReaching *= probOfNextInRemainder (*OriginalLoopProb, N);
475
+ FreqRemIters += ProbReaching;
476
+ }
477
+ }
478
+ // Solve for the loop probability that would produce that frequency.
479
+ // Sum(i=0..inf)(Prob^i) = 1/(1-Prob) = FreqRemIters.
480
+ double Prob = 1 - 1 / FreqRemIters;
481
+ setBranchProbability (RemainderLoopLatch, Prob, /* ForFirstTarget=*/ true );
482
+ }
441
483
NewIdx->addIncoming (Zero, InsertTop);
442
484
NewIdx->addIncoming (IdxNext, NewBB);
443
485
LatchBR->eraseFromParent ();
@@ -469,6 +511,8 @@ CloneLoopBlocks(Loop *L, Value *NewIter, const bool UseEpilogRemainder,
469
511
470
512
std::optional<MDNode *> NewLoopID = makeFollowupLoopID (
471
513
LoopID, {LLVMLoopUnrollFollowupAll, LLVMLoopUnrollFollowupRemainder});
514
+ if (OriginalTripCount && UseEpilogRemainder)
515
+ setLoopEstimatedTripCount (NewLoop, *OriginalTripCount % Count);
472
516
if (NewLoopID) {
473
517
NewLoop->setLoopID (*NewLoopID);
474
518
@@ -603,7 +647,8 @@ bool llvm::UnrollRuntimeLoopRemainder(
603
647
LoopInfo *LI, ScalarEvolution *SE, DominatorTree *DT, AssumptionCache *AC,
604
648
const TargetTransformInfo *TTI, bool PreserveLCSSA,
605
649
unsigned SCEVExpansionBudget, bool RuntimeUnrollMultiExit,
606
- Loop **ResultLoop) {
650
+ Loop **ResultLoop, std::optional<unsigned > OriginalTripCount,
651
+ std::optional<double > OriginalLoopProb) {
607
652
LLVM_DEBUG (dbgs () << " Trying runtime unrolling on Loop: \n " );
608
653
LLVM_DEBUG (L->dump ());
609
654
LLVM_DEBUG (UseEpilogRemainder ? dbgs () << " Using epilog remainder.\n "
@@ -823,12 +868,23 @@ bool llvm::UnrollRuntimeLoopRemainder(
823
868
BasicBlock *UnrollingLoop = UseEpilogRemainder ? NewPreHeader : PrologExit;
824
869
// Branch to either remainder (extra iterations) loop or unrolling loop.
825
870
MDNode *BranchWeights = nullptr ;
826
- if (hasBranchWeightMD (*Latch->getTerminator ())) {
871
+ if (!(OriginalLoopProb && UseEpilogRemainder) &&
872
+ hasBranchWeightMD (*Latch->getTerminator ())) {
827
873
// Assume loop is nearly always entered.
828
874
MDBuilder MDB (B.getContext ());
829
875
BranchWeights = MDB.createBranchWeights (EpilogHeaderWeights);
830
876
}
831
- B.CreateCondBr (BranchVal, RemainderLoop, UnrollingLoop, BranchWeights);
877
+ BranchInst *UnrollingLoopGuard =
878
+ B.CreateCondBr (BranchVal, RemainderLoop, UnrollingLoop, BranchWeights);
879
+ if (OriginalLoopProb && UseEpilogRemainder) {
880
+ // The original loop's first iteration always happens. Compute the
881
+ // probability of the original loop executing Count-1 iterations after that
882
+ // to complete the first iteration of the unrolled loop.
883
+ double ProbOne = *OriginalLoopProb;
884
+ double ProbRest = pow (ProbOne, Count - 1 );
885
+ setBranchProbability (UnrollingLoopGuard, ProbRest,
886
+ /* ForFirstTarget=*/ false );
887
+ }
832
888
PreHeaderBR->eraseFromParent ();
833
889
if (DT) {
834
890
if (UseEpilogRemainder)
@@ -855,9 +911,10 @@ bool llvm::UnrollRuntimeLoopRemainder(
855
911
// iterations. This function adds the appropriate CFG connections.
856
912
BasicBlock *InsertBot = UseEpilogRemainder ? LatchExit : PrologExit;
857
913
BasicBlock *InsertTop = UseEpilogRemainder ? EpilogPreHeader : PrologPreHeader;
858
- Loop *remainderLoop = CloneLoopBlocks (
859
- L, ModVal, UseEpilogRemainder, UnrollRemainder, InsertTop, InsertBot,
860
- NewPreHeader, NewBlocks, LoopBlocks, VMap, DT, LI, Count);
914
+ Loop *remainderLoop =
915
+ CloneLoopBlocks (L, ModVal, UseEpilogRemainder, UnrollRemainder, InsertTop,
916
+ InsertBot, NewPreHeader, NewBlocks, LoopBlocks, VMap, DT,
917
+ LI, Count, OriginalTripCount, OriginalLoopProb);
861
918
862
919
// Insert the cloned blocks into the function.
863
920
F->splice (InsertBot->getIterator (), F, NewBlocks[0 ]->getIterator (), F->end ());
@@ -956,7 +1013,8 @@ bool llvm::UnrollRuntimeLoopRemainder(
956
1013
// Connect the epilog code to the original loop and update the
957
1014
// PHI functions.
958
1015
ConnectEpilog (L, ModVal, NewExit, LatchExit, PreHeader, EpilogPreHeader,
959
- NewPreHeader, VMap, DT, LI, PreserveLCSSA, *SE, Count, *AC);
1016
+ NewPreHeader, VMap, DT, LI, PreserveLCSSA, *SE, Count, *AC,
1017
+ OriginalLoopProb);
960
1018
961
1019
// Update counter in loop for unrolling.
962
1020
// Use an incrementing IV. Pre-incr/post-incr is backedge/trip count.
0 commit comments