diff --git a/llvm/include/llvm/Transforms/Utils/UnrollLoop.h b/llvm/include/llvm/Transforms/Utils/UnrollLoop.h index 797c082333a76..9c36ff7568919 100644 --- a/llvm/include/llvm/Transforms/Utils/UnrollLoop.h +++ b/llvm/include/llvm/Transforms/Utils/UnrollLoop.h @@ -17,7 +17,9 @@ #include "llvm/ADT/DenseMap.h" #include "llvm/Analysis/CodeMetrics.h" +#include "llvm/Analysis/LazyValueInfo.h" #include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/IR/Constants.h" #include "llvm/Support/InstructionCost.h" namespace llvm { @@ -110,6 +112,16 @@ void simplifyLoopAfterUnroll(Loop *L, bool SimplifyIVs, LoopInfo *LI, AAResults *AA = nullptr); MDNode *GetUnrollMetadata(MDNode *LoopID, StringRef Name); +LoopUnrollResult tryUnrollLoopIntoSwitch(Loop &L, LazyValueInfo &LVI, + ScalarEvolution &SE, LoopInfo &LI, + DominatorTree &DT, + bool ForgetAllSCEV = false); +LoopUnrollResult UnrollLoopIntoSwitch( + Loop &L, unsigned UnrollCount, Value *SwitchValue, + ConstantInt *FirstSwitchValue, + std::function nextSwitchValue, + ScalarEvolution &SE, LoopInfo &LI, DominatorTree &DT, + bool ForgetAllSCEV = false); TargetTransformInfo::UnrollingPreferences gatherUnrollingPreferences( Loop *L, ScalarEvolution &SE, const TargetTransformInfo &TTI, diff --git a/llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp b/llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp index cbc35b6dd4292..01db3deec21db 100644 --- a/llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp +++ b/llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp @@ -24,6 +24,7 @@ #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/BlockFrequencyInfo.h" #include "llvm/Analysis/CodeMetrics.h" +#include "llvm/Analysis/LazyValueInfo.h" #include "llvm/Analysis/LoopAnalysisManager.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopPass.h" @@ -1598,6 +1599,7 @@ PreservedAnalyses LoopUnrollPass::run(Function &F, auto &DT = AM.getResult(F); auto &AC = AM.getResult(F); auto &ORE = AM.getResult(F); + LazyValueInfo &LVI = AM.getResult(F); AAResults &AA = AM.getResult(F); LoopAnalysisManager *LAM = nullptr; @@ -1645,17 +1647,21 @@ PreservedAnalyses LoopUnrollPass::run(Function &F, if (PSI && PSI->hasHugeWorkingSetSize()) LocalAllowPeeling = false; std::string LoopName = std::string(L.getName()); + LoopUnrollResult Result = + tryUnrollLoopIntoSwitch(L, LVI, SE, LI, DT, /*PreserveLCSSA*/ true); + // The API here is quite complex to call and we allow to select some // flavors of unrolling during construction time (by setting UnrollOpts). - LoopUnrollResult Result = tryToUnrollLoop( - &L, DT, &LI, SE, TTI, AC, ORE, BFI, PSI, - /*PreserveLCSSA*/ true, UnrollOpts.OptLevel, /*OnlyFullUnroll*/ false, - UnrollOpts.OnlyWhenForced, UnrollOpts.ForgetSCEV, - /*Count*/ std::nullopt, - /*Threshold*/ std::nullopt, UnrollOpts.AllowPartial, - UnrollOpts.AllowRuntime, UnrollOpts.AllowUpperBound, LocalAllowPeeling, - UnrollOpts.AllowProfileBasedPeeling, UnrollOpts.FullUnrollMaxCount, - &AA); + if (Result == LoopUnrollResult::Unmodified) + Result = tryToUnrollLoop( + &L, DT, &LI, SE, TTI, AC, ORE, BFI, PSI, + /*PreserveLCSSA*/ true, UnrollOpts.OptLevel, /*OnlyFullUnroll*/ false, + UnrollOpts.OnlyWhenForced, UnrollOpts.ForgetSCEV, + /*Count*/ std::nullopt, + /*Threshold*/ std::nullopt, UnrollOpts.AllowPartial, + UnrollOpts.AllowRuntime, UnrollOpts.AllowUpperBound, + LocalAllowPeeling, UnrollOpts.AllowProfileBasedPeeling, + UnrollOpts.FullUnrollMaxCount, &AA); Changed |= Result != LoopUnrollResult::Unmodified; // The parent must not be damaged by unrolling! diff --git a/llvm/lib/Transforms/Utils/LoopUnroll.cpp b/llvm/lib/Transforms/Utils/LoopUnroll.cpp index b90addcef69e6..b1709048df5b6 100644 --- a/llvm/lib/Transforms/Utils/LoopUnroll.cpp +++ b/llvm/lib/Transforms/Utils/LoopUnroll.cpp @@ -1105,3 +1105,306 @@ MDNode *llvm::GetUnrollMetadata(MDNode *LoopID, StringRef Name) { } return nullptr; } + +static int64_t getElementCountInRange(const APInt &Lower, const APInt &Upper) { + assert(Lower.getBitWidth() == Upper.getBitWidth() && "Bitwidths must match"); + assert(Lower.getBitWidth() <= 64 && "Bitwidth is too big"); + uint64_t LowerVal = Lower.getZExtValue(); + uint64_t UpperVal = Upper.getZExtValue(); + if (UpperVal >= LowerVal) { + return UpperVal - LowerVal + 1; + } + // The values wrap around + uint64_t MaxVal = ((uint64_t)1 << Lower.getBitWidth()) - 1; + return MaxVal - LowerVal + UpperVal + 1; +} + +LoopUnrollResult llvm::tryUnrollLoopIntoSwitch(Loop &L, LazyValueInfo &LVI, + ScalarEvolution &SE, + LoopInfo &LI, DominatorTree &DT, + bool ForgetAllSCEV) { + auto Bounds = L.getBounds(SE); + if (!Bounds) { + LLVM_DEBUG( + dbgs() << " Can't unroll into switch; loop bounds are unknown.\n"); + return LoopUnrollResult::Unmodified; + } + if (L.getHeader() != L.getLoopLatch()) { + LLVM_DEBUG( + dbgs() << " Can't unroll into switch; loop header != loop latch.\n"); + return LoopUnrollResult::Unmodified; + } + + ConstantInt *FinalIVValue = + dyn_cast_or_null(&Bounds->getFinalIVValue()); + // We can extend it in the future to support more complex cases + if (!FinalIVValue) { + LLVM_DEBUG( + dbgs() + << " Can't unroll into switch; final value of IV is not constant.\n"); + return LoopUnrollResult::Unmodified; + } + ConstantInt *StepSizeValue = + dyn_cast_or_null(Bounds->getStepValue()); + if (!StepSizeValue) { + LLVM_DEBUG( + dbgs() << " Can't unroll into switch; step size is not constant.\n"); + return LoopUnrollResult::Unmodified; + } + auto Direction = Bounds->getDirection(); + if (Direction == Loop::LoopBounds::Direction::Unknown) { + LLVM_DEBUG( + dbgs() << " Can't unroll into switch; direction of IV is unknown.\n"); + return LoopUnrollResult::Unmodified; + } + uint64_t StepSize; + if (Direction == Loop::LoopBounds::Direction::Increasing) { + StepSize = StepSizeValue->getValue().getZExtValue(); + } else { + StepSize = StepSizeValue->getValue().abs().getZExtValue(); + } + if (StepSize > 1) { + // TODO: To handle steps > 1 we need: + // 1) Make sure that FirstSwitchValue is a multiple of StepSize + // 2) Handle the case when SwitchValue is not a multiple of StepSize + // (potentially fallback to unrolled loop) + LLVM_DEBUG(dbgs() << " Can't unroll into switch; step size is too big.\n"); + return LoopUnrollResult::Unmodified; + } + + Value &InitialValue = Bounds->getInitialIVValue(); + if (!InitialValue.getType()->isIntegerTy() || + InitialValue.getType()->getIntegerBitWidth() > 64) { + LLVM_DEBUG(dbgs() << " Can't unroll into switch; loop induction variable " + "is not an integer.\n"); + return LoopUnrollResult::Unmodified; + } + BasicBlock *PreHeader = L.getLoopPreheader(); + if (!PreHeader) { + LLVM_DEBUG( + dbgs() << " Can't unroll into switch; loop has no preheader.\n"); + return LoopUnrollResult::Unmodified; + } + auto InitialValueRange = + LVI.getConstantRangeOnEdge(&InitialValue, PreHeader, L.getHeader()); + if (InitialValueRange.isFullSet()) { + LLVM_DEBUG(dbgs() << " Can't unroll into switch; no bound for number of " + "iterations available.\n"); + return LoopUnrollResult::Unmodified; + } + OverflowingBinaryOperator *StepInstOp = + dyn_cast_or_null(&Bounds->getStepInst()); + assert(StepInstOp && "Step instruction is not overflowing binary operator"); + + if (InitialValueRange.contains(FinalIVValue->getValue()) && + !StepInstOp->hasNoSignedWrap() && !StepInstOp->hasNoUnsignedWrap()) { + LLVM_DEBUG(dbgs() << " Can't unroll into switch; cannot establish upper " + "on the number of iterations.\n"); + return LoopUnrollResult::Unmodified; + } + + uint64_t UnrollCount = 0; + ConstantInt *FirstSwitchValue; + if (Direction == Loop::LoopBounds::Direction::Decreasing) { + const APInt &UpperBound = InitialValueRange.getUpper(); + UnrollCount = + (getElementCountInRange(FinalIVValue->getValue(), UpperBound) - 1) / + StepSize; + FirstSwitchValue = dyn_cast( + ConstantInt::get(FinalIVValue->getType(), + UpperBound - APInt(UpperBound.getBitWidth(), 1))); + } else { + const APInt &LowerBound = InitialValueRange.getLower(); + UnrollCount = + getElementCountInRange(LowerBound, FinalIVValue->getValue()) / StepSize; + FirstSwitchValue = dyn_cast( + ConstantInt::get(FinalIVValue->getType(), LowerBound)); + } + LLVM_DEBUG(dbgs() << " Unroll count: " << UnrollCount << "\n"); + // How should we determin the max numbe of iterations? + if (UnrollCount > 20) { + LLVM_DEBUG( + dbgs() + << " Can't unroll into switch; number of iterations is too big.\n"); + return LoopUnrollResult::Unmodified; + } + + return UnrollLoopIntoSwitch( + L, UnrollCount, &InitialValue, FirstSwitchValue, + [&](ConstantInt *CurrentSwitchValue) { + return dyn_cast(ConstantFoldBinaryInstruction( + StepInstOp->getOpcode(), CurrentSwitchValue, StepSizeValue)); + }, + SE, LI, DT, ForgetAllSCEV); +} + +// Unrolls the loop with known upper bound N on the number of iterations +// into N copies of body block chained together and a switch to jump to +// the (N - n)-th block, where n is the number of iterations calculated +// at runtime. +// For example, the following C++ code: +// __builtin_assume(n <= 10); +// do{ +// // loop body +// } while(--n >=0); +// will be transformed into the following: +// switch i32 %n, label %sw.0 [ +// i32 10, label %sw.10 +// i32 9, label %sw.9 +// ... +// i32 1, label %sw.1 +// ] +// sw.10: +// %n.10 = phi i32 [ %n, %entry ] +// ... +// br label %sw.9 +// sw.9: +// %n.9 = phi i32 [ %n, %entry ], [ %n.10, %sw.10 ] +// ... +// br label %sw.8 +// ... +// sw.0: +// %n.0 = phi i32 [ %n, %entry ], [ %n.1, %sw.1] +// ... +// br label %exit +// +// UnrollCount is the number of iterations to unroll the loop into (11 in the +// example above). +// SwitchValue value that should map to the starting block (n in +// the example above). +// FirstSwitchValue is the value of SwitchValue for the +// first iteration (10 in the example above). +// nextSwitchValue is a function that +// takes the value of SwitchValue for the previous iteration and returns the +// value of SwitchValue for the next iteration (n - 1 in the example above). +LoopUnrollResult llvm::UnrollLoopIntoSwitch( + Loop &L, unsigned UnrollCount, Value *SwitchValue, + ConstantInt *FirstSwitchValue, + std::function nextSwitchValue, + ScalarEvolution &SE, LoopInfo &LI, DominatorTree &DT, bool ForgetAllSCEV) { + if (UnrollCount <= 1) { + return LoopUnrollResult::Unmodified; + } + if (!SwitchValue->getType()->isIntegerTy()) { + LLVM_DEBUG(dbgs() << " Can't unroll into switch; loop induction variable " + "is not an integer.\n"); + return LoopUnrollResult::Unmodified; + } + if (!L.getLoopPreheader()) { + LLVM_DEBUG( + dbgs() + << " Can't unroll into switch; loop preheader-insertion failed.\n"); + return LoopUnrollResult::Unmodified; + } + + if (!L.getLoopLatch()) { + LLVM_DEBUG( + dbgs() + << " Can't unroll into switch; loop exit-block-insertion failed.\n"); + return LoopUnrollResult::Unmodified; + } + + // Loops with indirectbr cannot be cloned. + if (!L.isSafeToClone()) { + LLVM_DEBUG( + dbgs() << " Can't unroll into switch; Loop body cannot be cloned.\n"); + return LoopUnrollResult::Unmodified; + } + + if (L.getHeader()->hasAddressTaken()) { + // The loop-rotate pass can be helpful to avoid this in many cases. + LLVM_DEBUG( + dbgs() << " Won't unroll loop: address of header block is taken.\n"); + return LoopUnrollResult::Unmodified; + } + + if (!L.getExitBlock()) { + LLVM_DEBUG(dbgs() << " Can't unroll; loop has multiple exit blocks.\n"); + return LoopUnrollResult::Unmodified; + } + + if (L.getHeader() != L.getLoopLatch()) { + LLVM_DEBUG( + dbgs() << " Can't unroll into switch; loop header != loop latch.\n"); + return LoopUnrollResult::Unmodified; + } + + BasicBlock *PreHeader = L.getLoopPreheader(); + BasicBlock *Body = L.getHeader(); + BasicBlock *ExitBlock = L.getExitBlock(); + + DenseMap LoopBackedgeValuesOriginal; + for (PHINode &PH : Body->phis()) { + int Index = PH.getBasicBlockIndex(Body); + Value *IncomingValue = PH.getIncomingValue(Index); + assert(IncomingValue && "PHI node has no incoming values from the latch"); + LoopBackedgeValuesOriginal[&PH] = IncomingValue; + PH.removeIncomingValue(Index, true); + } + + DenseMap OrigValueToExitPHIMap; + for (PHINode &PH : ExitBlock->phis()) { + assert(PH.getNumIncomingValues() == 1 && + "Exit block has multiple incoming values"); + OrigValueToExitPHIMap[&PH] = PH.getIncomingValue(0); + } + DenseMap PrevValueToExitPHIMap(OrigValueToExitPHIMap); + + // Replace terminator in preheader with a switch + BranchInst *PreHeaderBranchInst = + dyn_cast_or_null(PreHeader->getTerminator()); + assert(PreHeaderBranchInst && "Preheader terminator is not a branch."); + SwitchInst *SwitchInst = SwitchInst::Create(SwitchValue, Body, UnrollCount); + ReplaceInstWithInst(PreHeaderBranchInst, SwitchInst); + SwitchInst->addCase(FirstSwitchValue, Body); + ConstantInt *PrevSwitchValue = FirstSwitchValue; + + if (ForgetAllSCEV) + SE.forgetAllLoops(); + else { + SE.forgetTopmostLoop(&L); + SE.forgetBlockAndLoopDispositions(); + } + + BasicBlock *PrevBody = Body; + DenseMap LoopBackedgeValuesPrevious( + LoopBackedgeValuesOriginal); + for (unsigned IterCount = 1; IterCount < UnrollCount; IterCount++) { + ValueToValueMapTy VMap; + BasicBlock *NewBody = CloneBasicBlock(Body, VMap, "." + Twine(IterCount)); + auto BlockInsertPt = std::next(PrevBody->getIterator()); + Body->getParent()->insert(BlockInsertPt, NewBody); + if (IterCount != UnrollCount - 1) { + ConstantInt *CurrentSwitchValue = nextSwitchValue(PrevSwitchValue); + SwitchInst->addCase(CurrentSwitchValue, NewBody); + PrevSwitchValue = CurrentSwitchValue; + } + remapInstructionsInBlocks({NewBody}, VMap); + for (auto &[OrigPHINode, PrevIncomingValue] : LoopBackedgeValuesPrevious) { + PHINode *NewPHINode = dyn_cast(VMap[OrigPHINode]); + NewPHINode->addIncoming(PrevIncomingValue, PrevBody); + LoopBackedgeValuesPrevious[OrigPHINode] = + VMap[LoopBackedgeValuesOriginal[OrigPHINode]]; + } + ReplaceInstWithInst(PrevBody->getTerminator(), BranchInst::Create(NewBody)); + PrevBody = NewBody; + DT.addNewBlock(NewBody, PreHeader); + + // Update exit values map + for (auto &[ExitPHINode, ExitValue] : PrevValueToExitPHIMap) { + PrevValueToExitPHIMap[ExitPHINode] = + VMap[OrigValueToExitPHIMap[ExitPHINode]]; + } + } + ReplaceInstWithInst(PrevBody->getTerminator(), BranchInst::Create(ExitBlock)); + for (auto &[ExitPHINode, ExitValue] : PrevValueToExitPHIMap) { + ExitPHINode->setIncomingBlock(0, PrevBody); + ExitPHINode->setIncomingValue(0, ExitValue); + SE.forgetLcssaPhiWithNewPredecessor(&L, ExitPHINode); + } + SwitchInst->setDefaultDest(PrevBody); + LI.erase(&L); + DT.changeImmediateDominator(ExitBlock, PrevBody); + + return LoopUnrollResult::FullyUnrolled; +}