Skip to content
Closed
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions llvm/include/llvm/Transforms/Utils/LoopRotationUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#ifndef LLVM_TRANSFORMS_UTILS_LOOPROTATIONUTILS_H
#define LLVM_TRANSFORMS_UTILS_LOOPROTATIONUTILS_H

#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Compiler.h"

namespace llvm {
Expand All @@ -32,12 +33,14 @@ class TargetTransformInfo;
/// header. If the loop header's size exceeds the threshold, the loop rotation
/// will give up. The flag IsUtilMode controls the heuristic used in the
/// LoopRotation. If it is true, the profitability heuristic will be ignored.
LLVM_ABI bool LoopRotation(Loop *L, LoopInfo *LI,
const TargetTransformInfo *TTI, AssumptionCache *AC,
DominatorTree *DT, ScalarEvolution *SE,
MemorySSAUpdater *MSSAU, const SimplifyQuery &SQ,
bool RotationOnly, unsigned Threshold,
bool IsUtilMode, bool PrepareForLTO = false);
/// The ProfitabilityCheck function overrides the general profitability check.
LLVM_ABI bool LoopRotation(
Loop *L, LoopInfo *LI, const TargetTransformInfo *TTI, AssumptionCache *AC,
DominatorTree *DT, ScalarEvolution *SE, MemorySSAUpdater *MSSAU,
const SimplifyQuery &SQ, bool RotationOnly, unsigned Threshold,
bool IsUtilMode, bool PrepareForLTO = false,
function_ref<bool(Loop *, ScalarEvolution *)> ProfitabilityCheck =
[](Loop *, ScalarEvolution *) { return false; });

} // namespace llvm

Expand Down
5 changes: 4 additions & 1 deletion llvm/include/llvm/Transforms/Utils/UnrollLoop.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,14 @@ LLVM_ABI const Loop *addClonedBlockToLoopInfo(BasicBlock *OriginalBB,
LoopInfo *LI,
NewLoopsMap &NewLoops);

/// Represents the result of a \c UnrollLoop invocation.
/// Represents the result of a \c UnrollLoop and \c UnrollAndJamLoop invocation.
enum class LoopUnrollResult {
/// The loop was not modified.
Unmodified,

/// The loop was modified, but not unrolled.
Modified,

/// The loop was partially unrolled -- we still have a loop, but with a
/// smaller trip count. We may also have emitted epilogue loop if the loop
/// had a non-constant trip count.
Expand Down
5 changes: 3 additions & 2 deletions llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1357,8 +1357,9 @@ tryToUnrollLoop(Loop *L, DominatorTree &DT, LoopInfo *LI, ScalarEvolution &SE,
ULO.RuntimeUnrollMultiExit = UP.RuntimeUnrollMultiExit;
LoopUnrollResult UnrollResult = UnrollLoop(
L, ULO, LI, &SE, &DT, &AC, &TTI, &ORE, PreserveLCSSA, &RemainderLoop, AA);
if (UnrollResult == LoopUnrollResult::Unmodified)
return LoopUnrollResult::Unmodified;
if (UnrollResult == LoopUnrollResult::Unmodified ||
UnrollResult == LoopUnrollResult::Modified)
return UnrollResult;

if (RemainderLoop) {
std::optional<MDNode *> RemainderLoopID =
Expand Down
28 changes: 16 additions & 12 deletions llvm/lib/Transforms/Utils/LoopRotationUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,16 +69,19 @@ class LoopRotate {
bool RotationOnly;
bool IsUtilMode;
bool PrepareForLTO;
function_ref<bool(Loop *, ScalarEvolution *)> ProfitabilityCheck;

public:
LoopRotate(unsigned MaxHeaderSize, LoopInfo *LI,
const TargetTransformInfo *TTI, AssumptionCache *AC,
DominatorTree *DT, ScalarEvolution *SE, MemorySSAUpdater *MSSAU,
const SimplifyQuery &SQ, bool RotationOnly, bool IsUtilMode,
bool PrepareForLTO)
bool PrepareForLTO,
function_ref<bool(Loop *, ScalarEvolution *)> ProfitabilityCheck)
: MaxHeaderSize(MaxHeaderSize), LI(LI), TTI(TTI), AC(AC), DT(DT), SE(SE),
MSSAU(MSSAU), SQ(SQ), RotationOnly(RotationOnly),
IsUtilMode(IsUtilMode), PrepareForLTO(PrepareForLTO) {}
IsUtilMode(IsUtilMode), PrepareForLTO(PrepareForLTO),
ProfitabilityCheck(ProfitabilityCheck) {}
bool processLoop(Loop *L);

private:
Expand Down Expand Up @@ -440,9 +443,9 @@ bool LoopRotate::rotateLoop(Loop *L, bool SimplifiedLatch) {

// Rotate if either the loop latch does *not* exit the loop, or if the loop
// latch was just simplified. Or if we think it will be profitable.
if (L->isLoopExiting(OrigLatch) && !SimplifiedLatch && IsUtilMode == false &&
!profitableToRotateLoopExitingLatch(L) &&
!canRotateDeoptimizingLatchExit(L))
if (L->isLoopExiting(OrigLatch) && !SimplifiedLatch &&
IsUtilMode == false && !profitableToRotateLoopExitingLatch(L) &&
!canRotateDeoptimizingLatchExit(L) && !ProfitabilityCheck(L, SE))
return Rotated;

// Check size of original header and reject loop if it is very big or we can't
Expand Down Expand Up @@ -1053,13 +1056,14 @@ bool LoopRotate::processLoop(Loop *L) {


/// The utility to convert a loop into a loop with bottom test.
bool llvm::LoopRotation(Loop *L, LoopInfo *LI, const TargetTransformInfo *TTI,
AssumptionCache *AC, DominatorTree *DT,
ScalarEvolution *SE, MemorySSAUpdater *MSSAU,
const SimplifyQuery &SQ, bool RotationOnly = true,
unsigned Threshold = unsigned(-1),
bool IsUtilMode = true, bool PrepareForLTO) {
bool llvm::LoopRotation(
Loop *L, LoopInfo *LI, const TargetTransformInfo *TTI, AssumptionCache *AC,
DominatorTree *DT, ScalarEvolution *SE, MemorySSAUpdater *MSSAU,
const SimplifyQuery &SQ, bool RotationOnly = true,
unsigned Threshold = unsigned(-1), bool IsUtilMode = true,
bool PrepareForLTO,
function_ref<bool(Loop *, ScalarEvolution *)> ProfitabilityCheck) {
LoopRotate LR(Threshold, LI, TTI, AC, DT, SE, MSSAU, SQ, RotationOnly,
IsUtilMode, PrepareForLTO);
IsUtilMode, PrepareForLTO, ProfitabilityCheck);
return LR.processLoop(L);
}
162 changes: 103 additions & 59 deletions llvm/lib/Transforms/Utils/LoopUnroll.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Transforms/Utils/LoopRotationUtils.h"
#include "llvm/Transforms/Utils/LoopSimplify.h"
#include "llvm/Transforms/Utils/LoopUtils.h"
#include "llvm/Transforms/Utils/SimplifyIndVar.h"
Expand Down Expand Up @@ -486,12 +487,7 @@ llvm::UnrollLoop(Loop *L, UnrollLoopOptions ULO, LoopInfo *LI,

// All these values should be taken only after peeling because they might have
// changed.
BasicBlock *Preheader = L->getLoopPreheader();
BasicBlock *Header = L->getHeader();
BasicBlock *LatchBlock = L->getLoopLatch();
SmallVector<BasicBlock *, 4> ExitBlocks;
L->getExitBlocks(ExitBlocks);
std::vector<BasicBlock *> OriginalLoopBlocks = L->getBlocks();

const unsigned MaxTripCount = SE->getSmallConstantMaxTripCount(L);
const bool MaxOrZero = SE->isBackedgeTakenCountMaxOrZero(L);
Expand All @@ -504,42 +500,6 @@ llvm::UnrollLoop(Loop *L, UnrollLoopOptions ULO, LoopInfo *LI,
if (MaxTripCount && ULO.Count > MaxTripCount)
ULO.Count = MaxTripCount;

struct ExitInfo {
unsigned TripCount;
unsigned TripMultiple;
unsigned BreakoutTrip;
bool ExitOnTrue;
BasicBlock *FirstExitingBlock = nullptr;
SmallVector<BasicBlock *> ExitingBlocks;
};
DenseMap<BasicBlock *, ExitInfo> ExitInfos;
SmallVector<BasicBlock *, 4> ExitingBlocks;
L->getExitingBlocks(ExitingBlocks);
for (auto *ExitingBlock : ExitingBlocks) {
// The folding code is not prepared to deal with non-branch instructions
// right now.
auto *BI = dyn_cast<BranchInst>(ExitingBlock->getTerminator());
if (!BI)
continue;

ExitInfo &Info = ExitInfos[ExitingBlock];
Info.TripCount = SE->getSmallConstantTripCount(L, ExitingBlock);
Info.TripMultiple = SE->getSmallConstantTripMultiple(L, ExitingBlock);
if (Info.TripCount != 0) {
Info.BreakoutTrip = Info.TripCount % ULO.Count;
Info.TripMultiple = 0;
} else {
Info.BreakoutTrip = Info.TripMultiple =
(unsigned)std::gcd(ULO.Count, Info.TripMultiple);
}
Info.ExitOnTrue = !L->contains(BI->getSuccessor(0));
Info.ExitingBlocks.push_back(ExitingBlock);
LLVM_DEBUG(dbgs() << " Exiting block %" << ExitingBlock->getName()
<< ": TripCount=" << Info.TripCount
<< ", TripMultiple=" << Info.TripMultiple
<< ", BreakoutTrip=" << Info.BreakoutTrip << "\n");
}

// Are we eliminating the loop control altogether? Note that we can know
// we're eliminating the backedge without knowing exactly which iteration
// of the unrolled body exits.
Expand All @@ -552,17 +512,6 @@ llvm::UnrollLoop(Loop *L, UnrollLoopOptions ULO, LoopInfo *LI,
if (CompletelyUnroll)
ULO.Runtime = false;

// Go through all exits of L and see if there are any phi-nodes there. We just
// conservatively assume that they're inserted to preserve LCSSA form, which
// means that complete unrolling might break this form. We need to either fix
// it in-place after the transformation, or entirely rebuild LCSSA. TODO: For
// now we just recompute LCSSA for the outer loop, but it should be possible
// to fix it in-place.
bool NeedToFixLCSSA =
PreserveLCSSA && CompletelyUnroll &&
any_of(ExitBlocks,
[](const BasicBlock *BB) { return isa<PHINode>(BB->begin()); });

// The current loop unroll pass can unroll loops that have
// (1) single latch; and
// (2a) latch is unconditional; or
Expand All @@ -587,21 +536,116 @@ llvm::UnrollLoop(Loop *L, UnrollLoopOptions ULO, LoopInfo *LI,
UnrollRuntimeEpilog.getNumOccurrences() ? UnrollRuntimeEpilog
: isEpilogProfitable(L);

if (ULO.Runtime &&
!UnrollRuntimeLoopRemainder(L, ULO.Count, ULO.AllowExpensiveTripCount,
EpilogProfitability, ULO.UnrollRemainder,
ULO.ForgetAllSCEV, LI, SE, DT, AC, TTI,
PreserveLCSSA, ULO.SCEVExpansionBudget,
ULO.RuntimeUnrollMultiExit, RemainderLoop)) {
bool LoopRotated = false;
bool ReminderUnrolled = false;
if (ULO.Runtime) {
// Call unroll with disabled rotation, to see if it is possible without it.
ReminderUnrolled = UnrollRuntimeLoopRemainder(
L, ULO.Count, ULO.AllowExpensiveTripCount, EpilogProfitability,
ULO.UnrollRemainder, ULO.ForgetAllSCEV, LI, SE, DT, AC, TTI,
PreserveLCSSA, ULO.SCEVExpansionBudget, ULO.RuntimeUnrollMultiExit,
RemainderLoop);

// If unroll is not possible, then try with loop rotation.
if (!ReminderUnrolled) {
BasicBlock *OrigHeader = L->getHeader();
BranchInst *BI = dyn_cast<BranchInst>(OrigHeader->getTerminator());
if (BI && !BI->isUnconditional() &&
isa<SCEVCouldNotCompute>(SE->getExitCount(L, L->getLoopLatch())) &&
!isa<SCEVCouldNotCompute>(SE->getExitCount(L, OrigHeader))) {
LLVM_DEBUG(
dbgs() << " Rotating loop to make the exit count computable.\n");
SimplifyQuery SQ{OrigHeader->getDataLayout()};
SQ.TLI = nullptr;
SQ.DT = DT;
SQ.AC = AC;
LoopRotated =
llvm::LoopRotation(L, LI, TTI, AC, DT, SE,
/*MemorySSAUpdater*/ nullptr, SQ,
/*RotationOnly*/ false, /*Threshold*/ 16,
/*IsUtilMode*/ false, /*PrepareForLTO*/ false,
[](Loop *, ScalarEvolution *) { return true; });
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this need to be done here, or could UnrollRuntimeLoopRemainder do it if needed? That would make it easier to keep the checks in sync and could avoid calling UnrollRuntimeLoopRemainder twice.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

UnrollRuntimeLoopRemainder could do it, but it would still need to redo all of its checks after the rotation. This means UnrollRuntimeLoopRemainder would have to take a flag indicating whether rotation is allowed (so that UnrollAndJam does not trigger rotation); it would then do all the unroll checks and, if they failed, rotate and do the checks again (and then unroll).
So doing it inside of the remainder unroll would duplicate the checks, which is similar to the current state, where the call is duplicated.
But I see your point: it would make the remainder unroll more powerful and it could be used later on in other parts, so I can change that.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are there that many checks? I'm not sure we need to re-run the checks after rotation.

Can't we just do the checks once, using either the latch or header, if latch isn't exiting?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(LoopUnroll itself works on both rotated and unrotated loops in a similar way)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now that I have tested it, I think having it this way is cleaner and nicer. When I tried to move the rotation inside of UnrollRuntimeLoopRemainder, I tried separating the legality checks into a function/lambda to call, but it relies on a lot of resources (Latch, Header, LatchBR, LatchExit, TripCountSC...) which are also used later on in the unrolling.
So it currently initializes some values, does a check, initializes others, does a check, and so forth, and then it unrolls. If I were to add rotation to this, then even with a lambda that captures everything, I would have to initialize all of them and then do the checks inside of the lambda, rotate, and then re-initialize all of the values and run the checks again. It becomes a mess and is less efficient than just calling UnrollRuntimeLoopRemainder twice, since it currently can exit early without initializing all the values.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for checking @mark-sed! If you still have it, could you share a diff showing what that would look like?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do not, but I tried to quickly put it together now. It will not compile and some things could be optimized, but it would be something like this, and I much prefer the current implementation. The reason it looks like this is that the unrolling uses the values that are being checked, and many of them need to be recalculated after the rotation.

It would be something like this:

diff --git a/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp b/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp
index 36c976e23eed..d1a41e6e73a5 100644
--- a/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp
+++ b/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp
@@ -616,106 +616,157 @@ bool llvm::UnrollRuntimeLoopRemainder(
     return false;
   }
 
+  // Use Scalar Evolution to compute the trip count. This allows more loops to
+  // be unrolled than relying on induction var simplification.
+  if (!SE)
+    return false;
+
   // Guaranteed by LoopSimplifyForm.
   BasicBlock *Latch = L->getLoopLatch();
   BasicBlock *Header = L->getHeader();
 
   BranchInst *LatchBR = cast<BranchInst>(Latch->getTerminator());
-
-  if (!LatchBR || LatchBR->isUnconditional()) {
-    // The loop-rotate pass can be helpful to avoid this in many cases.
-    LLVM_DEBUG(
-        dbgs()
-        << "Loop latch not terminated by a conditional branch.\n");
-    return false;
-  }
-
   unsigned ExitIndex = LatchBR->getSuccessor(0) == Header ? 1 : 0;
   BasicBlock *LatchExit = LatchBR->getSuccessor(ExitIndex);
 
-  if (L->contains(LatchExit)) {
-    // Cloning the loop basic blocks (`CloneLoopBlocks`) requires that one of the
-    // targets of the Latch be an exit block out of the loop.
-    LLVM_DEBUG(
-        dbgs()
-        << "One of the loop latch successors must be the exit block.\n");
-    return false;
-  }
-
   // These are exit blocks other than the target of the latch exiting block.
   SmallVector<BasicBlock *, 4> OtherExits;
   L->getUniqueNonLatchExitBlocks(OtherExits);
-  // Support only single exit and exiting block unless multi-exit loop
-  // unrolling is enabled.
-  if (!L->getExitingBlock() || OtherExits.size()) {
-    // We rely on LCSSA form being preserved when the exit blocks are transformed.
-    // (Note that only an off-by-default mode of the old PM disables PreserveLCCA.)
-    if (!PreserveLCSSA)
+
+  BasicBlock *PreHeader = L->getLoopPreheader();
+  BranchInst *PreHeaderBR = cast<BranchInst>(PreHeader->getTerminator());
+  const DataLayout &DL = Header->getDataLayout();
+
+  auto canUnrollReminder = [&]() {
+    // Only unroll loops with a computable trip count.
+    // We calculate the backedge count by using getExitCount on the Latch block,
+    // which is proven to be the only exiting block in this loop. This is same as
+    // calculating getBackedgeTakenCount on the loop (which computes SCEV for all
+    // exiting blocks).
+    const SCEV *BECountSC = SE->getExitCount(L, Latch);
+
+    // Add 1 since the backedge count doesn't include the first loop iteration.
+    // (Note that overflow can occur, this is handled explicitly below)
+    SCEV *TripCountSC =
+        SE->getAddExpr(BECountSC, SE->getConstant(BECountSC->getType(), 1));
+
+    if (!LatchBR || LatchBR->isUnconditional()) {
+      // The loop-rotate pass can be helpful to avoid this in many cases.
+      LLVM_DEBUG(
+          dbgs()
+          << "Loop latch not terminated by a conditional branch.\n");
       return false;
+    }
 
-    // Priority goes to UnrollRuntimeMultiExit if it's supplied.
-    if (UnrollRuntimeMultiExit.getNumOccurrences()) {
-      if (!UnrollRuntimeMultiExit)
-        return false;
-    } else {
-      // Otherwise perform multi-exit unrolling, if either the target indicates
-      // it is profitable or the general profitability heuristics apply.
-      if (!RuntimeUnrollMultiExit &&
-          !canProfitablyRuntimeUnrollMultiExitLoop(L, BPI, TTI, OtherExits,
-                                                   LatchExit,
-                                                   UseEpilogRemainder)) {
-        LLVM_DEBUG(dbgs() << "Multiple exit/exiting blocks in loop and "
-                             "multi-exit unrolling not enabled!\n");
+    if (L->contains(LatchExit)) {
+      // Cloning the loop basic blocks (`CloneLoopBlocks`) requires that one of the
+      // targets of the Latch be an exit block out of the loop.
+      LLVM_DEBUG(
+          dbgs()
+          << "One of the loop latch successors must be the exit block.\n");
+      return false;
+    }
+
+    // Support only single exit and exiting block unless multi-exit loop
+    // unrolling is enabled.
+    if (!L->getExitingBlock() || OtherExits.size()) {
+      // We rely on LCSSA form being preserved when the exit blocks are transformed.
+      // (Note that only an off-by-default mode of the old PM disables PreserveLCCA.)
+      if (!PreserveLCSSA)
         return false;
+
+      // Priority goes to UnrollRuntimeMultiExit if it's supplied.
+      if (UnrollRuntimeMultiExit.getNumOccurrences()) {
+        if (!UnrollRuntimeMultiExit)
+          return false;
+      } else {
+        // Otherwise perform multi-exit unrolling, if either the target indicates
+        // it is profitable or the general profitability heuristics apply.
+        if (!RuntimeUnrollMultiExit &&
+            !canProfitablyRuntimeUnrollMultiExitLoop(L, BPI, TTI, OtherExits,
+                                                    LatchExit,
+                                                    UseEpilogRemainder)) {
+          LLVM_DEBUG(dbgs() << "Multiple exit/exiting blocks in loop and "
+                              "multi-exit unrolling not enabled!\n");
+          return false;
+        }
       }
     }
-  }
-  // Use Scalar Evolution to compute the trip count. This allows more loops to
-  // be unrolled than relying on induction var simplification.
-  if (!SE)
-    return false;
 
-  // Only unroll loops with a computable trip count.
-  // We calculate the backedge count by using getExitCount on the Latch block,
-  // which is proven to be the only exiting block in this loop. This is same as
-  // calculating getBackedgeTakenCount on the loop (which computes SCEV for all
-  // exiting blocks).
-  const SCEV *BECountSC = SE->getExitCount(L, Latch);
-  if (isa<SCEVCouldNotCompute>(BECountSC)) {
-    LLVM_DEBUG(dbgs() << "Could not compute exit block SCEV\n");
-    return false;
+    if (isa<SCEVCouldNotCompute>(BECountSC)) {
+      LLVM_DEBUG(dbgs() << "Could not compute exit block SCEV\n");
+      return false;
+    }
+
+    unsigned BEWidth = cast<IntegerType>(BECountSC->getType())->getBitWidth();
+
+    if (isa<SCEVCouldNotCompute>(TripCountSC)) {
+      LLVM_DEBUG(dbgs() << "Could not compute trip count SCEV.\n");
+      return false;
+    }
+
+    SCEVExpander Expander(*SE, DL, "loop-unroll");
+    if (!AllowExpensiveTripCount &&
+        Expander.isHighCostExpansion(TripCountSC, L, SCEVExpansionBudget, TTI,
+                                    PreHeaderBR)) {
+      LLVM_DEBUG(dbgs() << "High cost for expanding trip count scev!\n");
+      return false;
+    }
+
+    // This constraint lets us deal with an overflowing trip count easily; see the
+    // comment on ModVal below.
+    if (Log2_32(Count) > BEWidth) {
+      LLVM_DEBUG(
+          dbgs()
+          << "Count failed constraint on overflow trip count calculation.\n");
+      return false;
+    }
+
+    return true;
+  };
+
+  bool LoopRotated = false;
+  if (!canUnrollReminder()) {
+    BasicBlock *OrigHeader = L->getHeader();
+    BranchInst *BI = dyn_cast<BranchInst>(OrigHeader->getTerminator());
+    if (BI && !BI->isUnconditional() &&
+        isa<SCEVCouldNotCompute>(SE->getExitCount(L, L->getLoopLatch())) &&
+        !isa<SCEVCouldNotCompute>(SE->getExitCount(L, OrigHeader))) {
+      LLVM_DEBUG(
+          dbgs() << "  Rotating loop to make the exit count computable.\n");
+      SimplifyQuery SQ{OrigHeader->getDataLayout()};
+      SQ.TLI = nullptr;
+      SQ.DT = DT;
+      SQ.AC = AC;
+      LoopRotated =
+          llvm::LoopRotation(L, LI, TTI, AC, DT, SE,
+                              /*MemorySSAUpdater*/ nullptr, SQ,
+                              /*RotationOnly*/ false, /*Threshold*/ 16,
+                              /*IsUtilMode*/ false, /*PrepareForLTO*/ false,
+                              [](Loop *, ScalarEvolution *) { return true; });
+    }
   }
+  if (!LoopRotated)
+    return false; 
 
-  unsigned BEWidth = cast<IntegerType>(BECountSC->getType())->getBitWidth();
+  Latch = L->getLoopLatch();
+  Header = L->getHeader();
 
-  // Add 1 since the backedge count doesn't include the first loop iteration.
-  // (Note that overflow can occur, this is handled explicitly below)
-  const SCEV *TripCountSC =
+  LatchBR = cast<BranchInst>(Latch->getTerminator());
+  ExitIndex = LatchBR->getSuccessor(0) == Header ? 1 : 0;
+  LatchExit = LatchBR->getSuccessor(ExitIndex);
+
+  L->getUniqueNonLatchExitBlocks(OtherExits);
+  const SCEV *BECountSC = SE->getExitCount(L, Latch);
+  TripCountSC =
       SE->getAddExpr(BECountSC, SE->getConstant(BECountSC->getType(), 1));
-  if (isa<SCEVCouldNotCompute>(TripCountSC)) {
-    LLVM_DEBUG(dbgs() << "Could not compute trip count SCEV.\n");
-    return false;
-  }
 
-  BasicBlock *PreHeader = L->getLoopPreheader();
-  BranchInst *PreHeaderBR = cast<BranchInst>(PreHeader->getTerminator());
-  const DataLayout &DL = Header->getDataLayout();
+  PreHeader = L->getLoopPreheader();
+  PreHeaderBR = cast<BranchInst>(PreHeader->getTerminator());
   SCEVExpander Expander(*SE, DL, "loop-unroll");
-  if (!AllowExpensiveTripCount &&
-      Expander.isHighCostExpansion(TripCountSC, L, SCEVExpansionBudget, TTI,
-                                   PreHeaderBR)) {
-    LLVM_DEBUG(dbgs() << "High cost for expanding trip count scev!\n");
-    return false;
-  }
 
-  // This constraint lets us deal with an overflowing trip count easily; see the
-  // comment on ModVal below.
-  if (Log2_32(Count) > BEWidth) {
-    LLVM_DEBUG(
-        dbgs()
-        << "Count failed constraint on overflow trip count calculation.\n");
+  if (!canUnrollReminder())
     return false;
-  }
 
   // Loop structure is the following:
   //

}
if (LoopRotated) {
// Loop was rotated, try unrolling.
ReminderUnrolled = UnrollRuntimeLoopRemainder(
L, ULO.Count, ULO.AllowExpensiveTripCount, EpilogProfitability,
ULO.UnrollRemainder, ULO.ForgetAllSCEV, LI, SE, DT, AC, TTI,
PreserveLCSSA, ULO.SCEVExpansionBudget, ULO.RuntimeUnrollMultiExit,
RemainderLoop);
}
}
// Latch block needs to be updated.
LatchBlock = L->getLoopLatch();
LatchIsExiting = L->isLoopExiting(LatchBlock);
}

if (ULO.Runtime && !ReminderUnrolled) {
if (ULO.Force)
ULO.Runtime = false;
else {
LLVM_DEBUG(dbgs() << "Won't unroll; remainder loop could not be "
"generated when assuming runtime trip count\n");
return LoopUnrollResult::Unmodified;
// Loop might have been rotated inside of UnrollRuntimeLoopRemainder and
// this needs to be propagated.
return LoopRotated ? LoopUnrollResult::Modified
: LoopUnrollResult::Unmodified;
}
}

BasicBlock *Preheader = L->getLoopPreheader();
BasicBlock *Header = L->getHeader();
SmallVector<BasicBlock *, 4> ExitBlocks;
L->getExitBlocks(ExitBlocks);
std::vector<BasicBlock *> OriginalLoopBlocks = L->getBlocks();

// Go through all exits of L and see if there are any phi-nodes there. We just
// conservatively assume that they're inserted to preserve LCSSA form, which
// means that complete unrolling might break this form. We need to either fix
// it in-place after the transformation, or entirely rebuild LCSSA. TODO: For
// now we just recompute LCSSA for the outer loop, but it should be possible
// to fix it in-place.
bool NeedToFixLCSSA =
PreserveLCSSA && CompletelyUnroll &&
any_of(ExitBlocks,
[](const BasicBlock *BB) { return isa<PHINode>(BB->begin()); });

struct ExitInfo {
unsigned TripCount;
unsigned TripMultiple;
unsigned BreakoutTrip;
bool ExitOnTrue;
BasicBlock *FirstExitingBlock = nullptr;
SmallVector<BasicBlock *> ExitingBlocks;
};
DenseMap<BasicBlock *, ExitInfo> ExitInfos;
SmallVector<BasicBlock *, 4> ExitingBlocks;
L->getExitingBlocks(ExitingBlocks);
for (auto *ExitingBlock : ExitingBlocks) {
// The folding code is not prepared to deal with non-branch instructions
// right now.
auto *BI = dyn_cast<BranchInst>(ExitingBlock->getTerminator());
if (!BI)
continue;

ExitInfo &Info = ExitInfos[ExitingBlock];
Info.TripCount = SE->getSmallConstantTripCount(L, ExitingBlock);
Info.TripMultiple = SE->getSmallConstantTripMultiple(L, ExitingBlock);
if (Info.TripCount != 0) {
Info.BreakoutTrip = Info.TripCount % ULO.Count;
Info.TripMultiple = 0;
} else {
Info.BreakoutTrip = Info.TripMultiple =
(unsigned)std::gcd(ULO.Count, Info.TripMultiple);
}
Info.ExitOnTrue = !L->contains(BI->getSuccessor(0));
Info.ExitingBlocks.push_back(ExitingBlock);
LLVM_DEBUG(dbgs() << " Exiting block %" << ExitingBlock->getName()
<< ": TripCount=" << Info.TripCount
<< ", TripMultiple=" << Info.TripMultiple
<< ", BreakoutTrip=" << Info.BreakoutTrip << "\n");
}

using namespace ore;
// Report the unrolling decision.
if (CompletelyUnroll) {
Expand Down
4 changes: 2 additions & 2 deletions llvm/test/Transforms/LoopUnroll/full-unroll-avoid-partial.ll
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@
; LOOP-UNROLL-LABEL: Loop Unroll: F[pragma_unroll] Loop %for.body
; LOOP-UNROLL-NEXT: Loop Size = 9
; LOOP-UNROLL-NEXT: runtime unrolling with count: 8
; LOOP-UNROLL-NEXT: Exiting block %for.body: TripCount=0, TripMultiple=1, BreakoutTrip=1
; LOOP-UNROLL-NEXT: Trying runtime unrolling on Loop:
; LOOP-UNROLL-NEXT: Loop at depth 1 containing: %for.body<header><latch><exiting>
; LOOP-UNROLL-NEXT: Using epilog remainder.
; LOOP-UNROLL-NEXT: Exiting block %for.body: TripCount=0, TripMultiple=8, BreakoutTrip=8
; LOOP-UNROLL-NEXT: UNROLLING loop %for.body by 8 with run-time trip count!

; LOOP-UNROLL-FULL-LABEL: Loop Unroll: F[pragma_unroll] Loop %for.body
Expand Down Expand Up @@ -49,10 +49,10 @@ for.body: ; preds = %for.body.preheader,

; LOOP-UNROLL-LABEL: Loop Unroll: F[pragma_unroll_count1] Loop %for.body
; LOOP-UNROLL-NEXT: Loop Size = 9
; LOOP-UNROLL-NEXT: Exiting block %for.body: TripCount=0, TripMultiple=1, BreakoutTrip=1
; LOOP-UNROLL-NEXT: Trying runtime unrolling on Loop:
; LOOP-UNROLL-NEXT: Loop at depth 1 containing: %for.body<header><latch><exiting>
; LOOP-UNROLL-NEXT: Using epilog remainder.
; LOOP-UNROLL-NEXT: Exiting block %for.body: TripCount=0, TripMultiple=1, BreakoutTrip=1
; LOOP-UNROLL-NEXT: UNROLLING loop %for.body by 5 with run-time trip count!

; LOOP-UNROLL-FULL-LABEL: Loop Unroll: F[pragma_unroll_count1] Loop %for.body
Expand Down
Loading