@@ -8139,16 +8139,7 @@ ScalarEvolution::getSmallConstantTripCount(const Loop *L,
81398139unsigned ScalarEvolution::getSmallConstantMaxTripCount(const Loop *L) {
81408140 const auto *MaxExitCount =
81418141 dyn_cast<SCEVConstant>(getConstantMaxBackedgeTakenCount(L));
8142- unsigned MaxExitCountN = getConstantTripCount(MaxExitCount);
8143- if (UseMemoryAccessUBForBEInference) {
8144- auto *MaxInferCount = getConstantMaxTripCountFromMemAccess(L);
8145- if (auto *InferCount = dyn_cast<SCEVConstant>(MaxInferCount)) {
8146- unsigned InferValue = InferCount->getValue()->getZExtValue();
8147- MaxExitCountN =
8148- MaxExitCountN == 0 ? InferValue : std::min(MaxExitCountN, InferValue);
8149- }
8150- }
8151- return MaxExitCountN;
8142+ return getConstantTripCount(MaxExitCount);
81528143}
81538144
81548145unsigned ScalarEvolution::getSmallConstantTripMultiple(const Loop *L) {
@@ -8213,13 +8204,9 @@ collectExecLoadStoreInsideLoop(const Loop *L, DominatorTree &DT,
82138204 if (!L->isLoopSimplifyForm() || !L->isInnermost())
82148205 return;
82158206
8216- // FIXME: To make the case more typical, we only analyze loops that have one
8217- // exiting block and the block must be the latch. It is easier to capture
8218- // loops with memory access that will be executed in every iteration.
82198207 const BasicBlock *LoopLatch = L->getLoopLatch();
82208208 assert(LoopLatch && "normal form loop doesn't have a latch");
8221- if (L->getExitingBlock() != LoopLatch)
8222- return;
8209+ assert(L->getExitingBlock() == LoopLatch);
82238210
82248211 // We will not continue if sanitizer is enabled.
82258212 const Function *F = LoopLatch->getParent();
@@ -8322,8 +8309,8 @@ static const SCEV *checkIndexWrap(Value *Ptr, ScalarEvolution *SE,
83228309 return SE->getUMinFromMismatchedTypes(InferCountColl);
83238310}
83248311
8325- const SCEV *
8326- ScalarEvolution::getConstantMaxTripCountFromMemAccess (const Loop *L) {
8312+ ScalarEvolution::ExitLimit
8313+ ScalarEvolution::computeExitLimitFromMemAccessImpl (const Loop *L) {
83278314 SmallVector<Instruction *, 4> MemInsts;
83288315 collectExecLoadStoreInsideLoop(L, DT, MemInsts);
83298316
@@ -8361,7 +8348,27 @@ ScalarEvolution::getConstantMaxTripCountFromMemAccess(const Loop *L) {
83618348 if (InferCountColl.empty())
83628349 return getCouldNotCompute();
83638350
8364- return getUMinFromMismatchedTypes(InferCountColl);
8351+ const SCEV *Count = getUMinFromMismatchedTypes(InferCountColl);
8352+
8353+ return {getCouldNotCompute(), Count, Count, /*MaxOrZero=*/false};
8354+ }
8355+
8356+ ScalarEvolution::ExitLimit
8357+ ScalarEvolution::computeExitLimitFromMemAccessCached(ExitLimitCacheTy &Cache,
8358+ const Loop *L) {
8359+ // Placeholder key fields: the cache lookup/insert API requires them, but the memory-access-based limit is computed from the loop alone.
8360+ constexpr Value *ExitCond = nullptr;
8361+ constexpr const bool ExitIfTrue = true;
8362+ constexpr const bool ControlsOnlyExit = true;
8363+ constexpr const bool AllowPredicates = true;
8364+
8365+ if (auto MaybeEL = Cache.find(L, ExitCond, ExitIfTrue, ControlsOnlyExit,
8366+ AllowPredicates))
8367+ return *MaybeEL;
8368+
8369+ ExitLimit EL = computeExitLimitFromMemAccessImpl(L);
8370+ Cache.insert(L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates, EL);
8371+ return EL;
83658372}
83668373
83678374const SCEV *ScalarEvolution::getExitCount(const Loop *L,
@@ -8946,6 +8953,16 @@ ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
89468953 if (!Latch || !DT.dominates(ExitingBlock, Latch))
89478954 return getCouldNotCompute();
89488955
 8956+ // FIXME: For simplicity, we only analyze loops that have a single exiting
 8957+ // block, which must also be the latch. This makes it easier to identify
 8958+ // memory accesses that are executed on every iteration.
8959+ const SCEV *PotentiallyBetterConstantMax = getCouldNotCompute();
8960+ if (UseMemoryAccessUBForBEInference && Latch == L->getExitingBlock()) {
8961+ assert(Latch == ExitingBlock);
8962+ auto EL = computeExitLimitFromMemAccess(L);
8963+ PotentiallyBetterConstantMax = EL.ConstantMaxNotTaken;
8964+ }
8965+
89498966 bool IsOnlyExit = (L->getExitingBlock() != nullptr);
89508967 Instruction *Term = ExitingBlock->getTerminator();
89518968 if (BranchInst *BI = dyn_cast<BranchInst>(Term)) {
@@ -8954,9 +8971,13 @@ ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
89548971 assert(ExitIfTrue == L->contains(BI->getSuccessor(1)) &&
89558972 "It should have one successor in loop and one exit block!");
89568973 // Proceed to the next level to examine the exit condition expression.
8957- return computeExitLimitFromCond(L, BI->getCondition(), ExitIfTrue,
8958- /*ControlsOnlyExit=*/IsOnlyExit,
8959- AllowPredicates);
8974+ ExitLimit EL = computeExitLimitFromCond(L, BI->getCondition(), ExitIfTrue,
8975+ /*ControlsOnlyExit=*/IsOnlyExit,
8976+ AllowPredicates);
8977+ if (!isa<SCEVCouldNotCompute>(PotentiallyBetterConstantMax))
8978+ EL.ConstantMaxNotTaken = getUMinFromMismatchedTypes(
8979+ EL.ConstantMaxNotTaken, PotentiallyBetterConstantMax);
8980+ return EL;
89608981 }
89618982
89628983 if (SwitchInst *SI = dyn_cast<SwitchInst>(Term)) {
@@ -8969,9 +8990,13 @@ ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
89698990 Exit = SBB;
89708991 }
89718992 assert(Exit && "Exiting block must have at least one exit");
8972- return computeExitLimitFromSingleExitSwitch(
8973- L, SI, Exit,
8974- /*ControlsOnlyExit=*/IsOnlyExit);
8993+ ExitLimit EL =
8994+ computeExitLimitFromSingleExitSwitch(L, SI, Exit,
8995+ /*ControlsOnlyExit=*/IsOnlyExit);
8996+ if (!isa<SCEVCouldNotCompute>(PotentiallyBetterConstantMax))
8997+ EL.ConstantMaxNotTaken = getUMinFromMismatchedTypes(
8998+ EL.ConstantMaxNotTaken, PotentiallyBetterConstantMax);
8999+ return EL;
89759000 }
89769001
89779002 return getCouldNotCompute();
@@ -8985,6 +9010,13 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCond(
89859010 ControlsOnlyExit, AllowPredicates);
89869011}
89879012
9013+ ScalarEvolution::ExitLimit
9014+ ScalarEvolution::computeExitLimitFromMemAccess(const Loop *L) {
9015+ ScalarEvolution::ExitLimitCacheTy Cache(L, /*ExitIfTrue=*/ true,
9016+ /*AllowPredicates=*/ true);
9017+ return computeExitLimitFromMemAccessCached(Cache, L);
9018+ }
9019+
89889020std::optional<ScalarEvolution::ExitLimit>
89899021ScalarEvolution::ExitLimitCache::find(const Loop *L, Value *ExitCond,
89909022 bool ExitIfTrue, bool ControlsOnlyExit,
0 commit comments