[SCEV] Infer loop max trip count from memory accesses #70361
@@ -249,6 +249,10 @@ static cl::opt<bool> UseContextForNoWrapFlagInference(
| cl::desc("Infer nuw/nsw flags using context where suitable"), | ||
| cl::init(true)); | ||
|
|
||
| static cl::opt<bool> UseMemoryAccessUBForBEInference( | ||
| "scalar-evolution-infer-max-trip-count-from-memory-access", cl::Hidden, | ||
| cl::desc("Infer loop max trip count from memory access"), cl::init(false)); | ||

//===----------------------------------------------------------------------===//
// SCEV class definitions
//===----------------------------------------------------------------------===//

@@ -8260,6 +8264,210 @@ ScalarEvolution::getSmallConstantTripMultiple(const Loop *L,
  return getSmallConstantTripMultiple(L, ExitCount);
}

/// Collect all load/store instructions that must be executed in every
/// iteration of loop \p L.
static void
collectExecLoadStoreInsideLoop(const Loop *L, DominatorTree &DT,
                               SmallVector<Instruction *, 4> &MemInsts) {
  // It is difficult to tell whether a load/store instruction is executed on
  // every iteration of an irregular loop.
  if (!L->isLoopSimplifyForm() || !L->isInnermost())
    return;

  const BasicBlock *LoopLatch = L->getLoopLatch();
  assert(LoopLatch && "normal form loop doesn't have a latch");
  assert(L->getExitingBlock() == LoopLatch);

  // Do not continue if a sanitizer is enabled: the program may contain
  // accesses that are out of bounds on purpose and are meant to be diagnosed
  // at runtime, so they must not be treated as UB here.
  const Function *F = LoopLatch->getParent();
  if (F->hasFnAttribute(Attribute::SanitizeAddress) ||
      F->hasFnAttribute(Attribute::SanitizeThread) ||
      F->hasFnAttribute(Attribute::SanitizeMemory) ||
      F->hasFnAttribute(Attribute::SanitizeHWAddress) ||
      F->hasFnAttribute(Attribute::SanitizeMemTag))
    return;

  for (auto *BB : L->getBlocks()) {
    // A memory access only bounds the latch's execution count if its block
    // executes on every iteration, i.e. dominates the latch. Blocks like
    // OtherBB below must be skipped:
    //            Entry
    //              │
    //        ┌─────▼─────┐
    //        │Loop Header◄─────┐
    //        └──┬──────┬─┘     │
    //           │      │       │
    //  ┌────────▼──┐ ┌─▼─────┐ │
    //  │MemAccessBB│ │OtherBB│ │
    //  └────────┬──┘ └─┬─────┘ │
    //           │      │       │
    //         ┌─▼──────▼─┐     │
    //         │Loop Latch├─────┘
    //         └────┬─────┘
    //              ▼
    //             Exit
    if (!DT.dominates(BB, LoopLatch))
      continue;

    for (Instruction &I : *BB) {
      if (isa<LoadInst>(&I) || isa<StoreInst>(&I))
        MemInsts.push_back(&I);
    }
  }
}
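(Illustration, not part of the patch.) Only loads and stores that execute on every iteration are collected; in a source-level sketch like the following, the unconditional store qualifies while the guarded one does not:

// Hypothetical C++ example: names are illustrative.
void collectExample(int *A, int N) {
  for (int I = 0; I != N; ++I) {
    A[I] = 0;   // block dominates the latch: collected
    if (I & 1)
      A[I] = 1; // guarded block does not dominate the latch: skipped
  }
}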

/// Return a SCEV representing the memory size of pointer \p V.
static const SCEV *getCertainSizeOfMem(const SCEV *V, Type *RTy,
                                       const DataLayout &DL,
                                       const TargetLibraryInfo &TLI,
                                       ScalarEvolution *SE) {
  const SCEVUnknown *PtrBase = dyn_cast<SCEVUnknown>(V);
  if (!PtrBase)
    return nullptr;
  Value *Ptr = PtrBase->getValue();
  uint64_t Size = 0;
  if (!llvm::getObjectSize(Ptr, Size, DL, &TLI))
    return nullptr;
  return SE->getConstant(RTy, Size);
}
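(Illustration, not part of the patch.) llvm::getObjectSize can typically resolve the size of fixed-size globals and allocas, but not of arbitrary incoming pointers, so the helper only produces a constant in cases like the first two below:

// Hypothetical C++ example: names are illustrative.
char GlobalBuf[64];      // base object size is known: 64 bytes
void sizeExample(char *P) {
  char Local[16];        // fixed-size alloca: size is known, 16 bytes
  Local[0] = *P;         // the object behind P has unknown size, so
  GlobalBuf[0] = *P;     // getCertainSizeOfMem would fail for P itself
}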

/// Get the range of the index represented by \p AddRec, expressed as a
/// number of steps.
static const SCEV *getIndexRange(const SCEVAddRecExpr *AddRec,
                                 ScalarEvolution *SE) {
  const SCEV *Range = SE->getConstant(SE->getUnsignedRangeMax(AddRec) -
                                      SE->getUnsignedRangeMin(AddRec));
  const SCEV *Step = AddRec->getStepRecurrence(*SE);
  if (SE->isKnownNegative(Step))
    Step = SE->getNegativeSCEV(Step);
  return SE->getUDivCeilSCEV(Range, Step);
}
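Worked example with hypothetical numbers: if the AddRec's unsigned range is known to be [0, 96] and its step is 4, the helper returns ceil((96 - 0) / 4) = 24, i.e. the index can advance at most 24 steps before leaving its known range.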

/// Get the underlying SCEVAddExpr from a cast expression if possible.
const SCEV *peelCastExpr(const SCEVCastExpr *S, ScalarEvolution *SE) {

Author comment: Do we have an existing utility that can already do this?

  const SCEV *Op = S->getOperand();
  if (isa<SCEVCouldNotCompute>(Op))
    return Op;
  if (isa<SCEVAddExpr>(Op))
    return Op;
  if (isa<SCEVCastExpr>(Op))
    return peelCastExpr(cast<SCEVCastExpr>(Op), SE);
  return SE->getCouldNotCompute();
}
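In effect this recurses through nested SCEV casts (zext, sext, truncate), e.g. peeling (zext (sext (%a + %b))) down to the inner add expression, and yields CouldNotCompute for anything that does not bottom out in a SCEVAddExpr.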

static Value *peelExt(Value *V) {

Author comment: Do we have an existing utility that can already do this?

  if (isa<ZExtInst>(V) || isa<SExtInst>(V))
    return peelExt(cast<Instruction>(V)->getOperand(0));
  return V;
}

static bool isIndexInductionVariable(PHINode *InductionVar, Value *Index) {

Author comments: Do we have an existing utility that can already do this? Also: This might be too conservative; it can't handle uses such as …

  if (InductionVar == Index)
    return true;
  if (peelExt(Index) == InductionVar)
    return true;
  return false;
}
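For instance, this accepts a GEP index that is the induction PHI itself or the PHI wrapped in zext/sext, but, per the note above, it rejects anything derived from the PHI by arithmetic, since peelExt must lead back to the PHI exactly.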

/// Check whether the index can wrap and, if so, whether we can still infer a
/// max trip count given \p MaxExecCount, the count inferred from the memory
/// access itself.
static const SCEV *checkIndexRange(Value *Ptr, PHINode *InductionVar,
                                   ScalarEvolution *SE,
                                   const SCEVConstant *MaxExecCount) {
  SmallVector<const SCEV *> InferCountColl;
  auto *PtrGEP = dyn_cast<GetElementPtrInst>(Ptr);
  if (!PtrGEP)
    return SE->getCouldNotCompute();
  for (Value *Index : PtrGEP->indices()) {
    Value *V = Index;
    if (!isIndexInductionVariable(InductionVar, Index))
      continue;
    if (isa<ZExtInst>(V) || isa<SExtInst>(V))
      V = cast<Instruction>(Index)->getOperand(0);
    auto *SCEV = SE->getSCEV(V);
    if (isa<SCEVCouldNotCompute>(SCEV))
      return SE->getCouldNotCompute();
    if (isa<SCEVCastExpr>(SCEV)) {
      SCEV = peelCastExpr(cast<SCEVCastExpr>(SCEV), SE);
      if (isa<SCEVCouldNotCompute>(SCEV))
        return SE->getCouldNotCompute();
    }
    auto *AddRec = dyn_cast<SCEVAddRecExpr>(SCEV);
    if (!AddRec)
      return SE->getCouldNotCompute();
    auto *IndexRange = getIndexRange(AddRec, SE);
    auto *IndexRangeC = dyn_cast<SCEVConstant>(IndexRange);
    if (!IndexRangeC)
      return SE->getCouldNotCompute();
    InferCountColl.push_back(IndexRange);
    break;
  }

  if (InferCountColl.empty())
    return SE->getCouldNotCompute();

  InferCountColl.push_back(MaxExecCount);

  return SE->getUMinFromMismatchedTypes(InferCountColl);
}
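Worked example with hypothetical numbers: for an i8 induction variable stepping one byte at a time through a 1024-byte buffer, the memory size alone would allow 1024 iterations, yet the 8-bit index can take at most 255 steps before wrapping; taking the umin of both bounds keeps the inferred max trip count sound.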

ScalarEvolution::ExitLimit
ScalarEvolution::computeExitLimitFromMemAccessImpl(const Loop *L) {
  SmallVector<Instruction *, 4> MemInsts;
  collectExecLoadStoreInsideLoop(L, DT, MemInsts);

  SmallVector<const SCEV *> InferCountColl;
  const DataLayout &DL = getDataLayout();

  for (Instruction *I : MemInsts) {
    Value *Ptr = getLoadStorePointerOperand(I);
    assert(Ptr && "empty pointer operand");
    auto *AddRec = dyn_cast<SCEVAddRecExpr>(getSCEV(Ptr));
    if (!AddRec || !AddRec->isAffine())
      continue;
    const SCEV *PtrBase = getPointerBase(AddRec);
    const SCEV *Step = AddRec->getStepRecurrence(*this);
    const SCEV *MemSize =
        getCertainSizeOfMem(PtrBase, Step->getType(), DL, TLI, this);
    if (!MemSize)
      continue;
    if (isKnownNegative(Step))
      Step = getNegativeSCEV(Step);
    // Now we can infer a max execution count: MemSize / Step, rounded up.
    auto *MaxExecCount = dyn_cast<SCEVConstant>(getUDivCeilSCEV(MemSize, Step));
    if (!MaxExecCount || MaxExecCount->getAPInt().getActiveBits() > 32)
      continue;
    auto *Res = checkIndexRange(Ptr, L->getInductionVariable(*this), this,
                                MaxExecCount);
    if (isa<SCEVCouldNotCompute>(Res))
      continue;
    InferCountColl.push_back(Res);
  }

  if (InferCountColl.empty())
    return getCouldNotCompute();

  const SCEV *Count = getUMinFromMismatchedTypes(InferCountColl);

  return {getCouldNotCompute(), Count, Count, /*MaxOrZero=*/false};
}
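(Illustration, not part of the patch.) End to end, the inference exploits the fact that an out-of-bounds access would be UB, so a strided access into a fixed-size object bounds the trip count even when the exit condition alone does not:

// Hypothetical C++ example: names and numbers are illustrative.
int Buf[25];                     // a 100-byte object
void tripExample(int N) {        // N is unknown at compile time
  for (int I = 0; I != N; ++I)   // the exit condition gives no constant max
    Buf[I] = 0;                  // 4-byte stride over 100 bytes: iteration 25
}                                // or later would be UB, so max trip <= 25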

ScalarEvolution::ExitLimit
ScalarEvolution::computeExitLimitFromMemAccessCached(ExitLimitCacheTy &Cache,
                                                     const Loop *L) {
  // We don't really need these values, but the cache key does.
  constexpr Value *ExitCond = nullptr;
  constexpr const bool ExitIfTrue = true;
  constexpr const bool ControlsOnlyExit = true;
  constexpr const bool AllowPredicates = true;

  if (auto MaybeEL = Cache.find(L, ExitCond, ExitIfTrue, ControlsOnlyExit,
                                AllowPredicates))
    return *MaybeEL;

  ExitLimit EL = computeExitLimitFromMemAccessImpl(L);
  Cache.insert(L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates, EL);
  return EL;
}

const SCEV *ScalarEvolution::getExitCount(const Loop *L,
                                          const BasicBlock *ExitingBlock,
                                          ExitCountKind Kind) {
@@ -8842,6 +9050,16 @@ ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
  if (!Latch || !DT.dominates(ExitingBlock, Latch))
    return getCouldNotCompute();

  // FIXME: For now we only analyze loops with a single exiting block that is
  // also the latch; for such loops it is easier to prove that the collected
  // memory accesses execute on every iteration.
  const SCEV *PotentiallyBetterConstantMax = getCouldNotCompute();
  if (UseMemoryAccessUBForBEInference && Latch == L->getExitingBlock()) {
    assert(Latch == ExitingBlock);
    auto EL = computeExitLimitFromMemAccess(L);
    PotentiallyBetterConstantMax = EL.ConstantMaxNotTaken;
  }
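Note that the memory-derived bound is only ever folded in below via umin with the condition-derived ConstantMaxNotTaken, never used as an exact count: the accesses prove the loop cannot run longer than the bound, not how long it actually runs.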

  bool IsOnlyExit = (L->getExitingBlock() != nullptr);
  Instruction *Term = ExitingBlock->getTerminator();
  if (BranchInst *BI = dyn_cast<BranchInst>(Term)) {
@@ -8850,9 +9068,14 @@ ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
    assert(ExitIfTrue == L->contains(BI->getSuccessor(1)) &&
           "It should have one successor in loop and one exit block!");
    // Proceed to the next level to examine the exit condition expression.
-    return computeExitLimitFromCond(L, BI->getCondition(), ExitIfTrue,
-                                    /*ControlsOnlyExit=*/IsOnlyExit,
-                                    AllowPredicates);
+    ExitLimit EL = computeExitLimitFromCond(L, BI->getCondition(), ExitIfTrue,
+                                            /*ControlsOnlyExit=*/IsOnlyExit,
+                                            AllowPredicates);
+    if (!isa<SCEVCouldNotCompute>(EL.ConstantMaxNotTaken) &&
+        !isa<SCEVCouldNotCompute>(PotentiallyBetterConstantMax))
+      EL.ConstantMaxNotTaken = getUMinFromMismatchedTypes(
+          EL.ConstantMaxNotTaken, PotentiallyBetterConstantMax);
+    return EL;
  }

  if (SwitchInst *SI = dyn_cast<SwitchInst>(Term)) {
@@ -8865,9 +9088,14 @@ ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
        Exit = SBB;
      }
    assert(Exit && "Exiting block must have at least one exit");
-    return computeExitLimitFromSingleExitSwitch(
-        L, SI, Exit,
-        /*ControlsOnlyExit=*/IsOnlyExit);
+    ExitLimit EL =
+        computeExitLimitFromSingleExitSwitch(L, SI, Exit,
+                                             /*ControlsOnlyExit=*/IsOnlyExit);
+    if (!isa<SCEVCouldNotCompute>(EL.ConstantMaxNotTaken) &&
+        !isa<SCEVCouldNotCompute>(PotentiallyBetterConstantMax))
+      EL.ConstantMaxNotTaken = getUMinFromMismatchedTypes(
+          EL.ConstantMaxNotTaken, PotentiallyBetterConstantMax);
+    return EL;
  }

  return getCouldNotCompute();
@@ -8881,6 +9109,13 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCond(
                                  ControlsOnlyExit, AllowPredicates);
}

ScalarEvolution::ExitLimit
ScalarEvolution::computeExitLimitFromMemAccess(const Loop *L) {
  ScalarEvolution::ExitLimitCacheTy Cache(L, /*ExitIfTrue=*/true,
                                          /*AllowPredicates=*/true);
  return computeExitLimitFromMemAccessCached(Cache, L);
}

std::optional<ScalarEvolution::ExitLimit>
ScalarEvolution::ExitLimitCache::find(const Loop *L, Value *ExitCond,
                                      bool ExitIfTrue, bool ControlsOnlyExit,
@@ -13579,6 +13814,17 @@ static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE,
| OS << ": "; | ||
| OS << "Trip multiple is " << SE->getSmallConstantTripMultiple(L) << "\n"; | ||
| } | ||
|
|
||
| if (UseMemoryAccessUBForBEInference) { | ||
| unsigned SmallMaxTrip = SE->getSmallConstantMaxTripCount(L); | ||
| OS << "Loop "; | ||
| L->getHeader()->printAsOperand(OS, /*PrintType=*/false); | ||
| OS << ": "; | ||
| if (SmallMaxTrip) | ||
| OS << "Small constant max trip is " << SmallMaxTrip << "\n"; | ||
| else | ||
| OS << "Small constant max trip couldn't be computed.\n"; | ||
| } | ||
}
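With the flag enabled, the analysis printout thus gains one line per loop, along the lines of (hypothetical loop name and value) `Loop %for.body: Small constant max trip is 25`.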

namespace llvm {