@@ -249,6 +249,10 @@ static cl::opt<bool> UseContextForNoWrapFlagInference(
     cl::desc("Infer nuw/nsw flags using context where suitable"),
     cl::init(true));
 
+static cl::opt<bool> UseMemoryAccessUBForBEInference(
+    "scalar-evolution-infer-max-trip-count-from-memory-access", cl::Hidden,
+    cl::desc("Infer loop max trip count from memory access"), cl::init(false));
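+
+// Example usage of the flag above (illustrative, assuming a standard `opt`
+// invocation): the inference is exercised via the analysis printer, e.g.
+//   opt -passes='print<scalar-evolution>' \
+//       -scalar-evolution-infer-max-trip-count-from-memory-access \
+//       -disable-output input.ll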
+
 //===----------------------------------------------------------------------===//
 // SCEV class definitions
 //===----------------------------------------------------------------------===//
@@ -8253,6 +8257,190 @@ ScalarEvolution::getSmallConstantTripMultiple(const Loop *L,
   return getSmallConstantTripMultiple(L, ExitCount);
 }
 
+/// Collect all load/store instructions that must be executed in every
+/// iteration of loop \p L.
+static void
+collectExecLoadStoreInsideLoop(const Loop *L, DominatorTree &DT,
+                               SmallVector<Instruction *, 4> &MemInsts) {
+  // It is hard to tell whether a load/store executes on every iteration of an
+  // irregular loop, so restrict the analysis to innermost loops in simplify
+  // form.
+  if (!L->isLoopSimplifyForm() || !L->isInnermost())
+    return;
+
+  const BasicBlock *LoopLatch = L->getLoopLatch();
+  assert(LoopLatch && "Loop in simplify form must have a latch");
+  assert(L->getExitingBlock() == LoopLatch &&
+         "The latch must be the only exiting block");
+
+  // Bail out if a sanitizer is enabled: sanitized builds expect out-of-bounds
+  // accesses to be diagnosed rather than treated as UB, so we must not use
+  // them to bound the trip count.
+  const Function *F = LoopLatch->getParent();
+  if (F->hasFnAttribute(Attribute::SanitizeAddress) ||
+      F->hasFnAttribute(Attribute::SanitizeThread) ||
+      F->hasFnAttribute(Attribute::SanitizeMemory) ||
+      F->hasFnAttribute(Attribute::SanitizeHWAddress) ||
+      F->hasFnAttribute(Attribute::SanitizeMemTag))
+    return;
+  for (auto *BB : L->getBlocks()) {
+    // A block is executed on every iteration that reaches the latch iff it
+    // dominates the latch. In the CFG below, MemAccessBB is only
+    // conditionally executed and must be skipped:
+    //            Entry
+    //              │
+    //        ┌─────▼─────┐
+    //        │Loop Header◄─────┐
+    //        └──┬──────┬─┘     │
+    //           │      │       │
+    //  ┌────────▼──┐ ┌─▼─────┐ │
+    //  │MemAccessBB│ │OtherBB│ │
+    //  └────────┬──┘ └─┬─────┘ │
+    //           │      │       │
+    //         ┌─▼──────▼─┐     │
+    //         │Loop Latch├─────┘
+    //         └────┬─────┘
+    //              ▼
+    //            Exit
+    if (!DT.dominates(BB, LoopLatch))
+      continue;
+
+    for (Instruction &I : *BB) {
+      if (isa<LoadInst>(&I) || isa<StoreInst>(&I))
+        MemInsts.push_back(&I);
+    }
+  }
+}
+
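+// A sketch of what gets collected (illustrative, not from the original
+// patch): in
+//
+//   for (int i = 0; i != n; ++i) {
+//     if (i & 1)
+//       a[i] = 0; // Conditional; its block does not dominate the latch.
+//     b[i] = 0;   // Unconditional; executed on every iteration.
+//   }
+//
+// only the store to b[i] is collected.
+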
+/// Return a SCEV representing the size in bytes of the object that pointer
+/// \p V points to, if that size can be determined exactly.
+static const SCEV *getCertainSizeOfMem(const SCEV *V, Type *RTy,
+                                       const DataLayout &DL,
+                                       const TargetLibraryInfo &TLI,
+                                       ScalarEvolution *SE) {
+  const SCEVUnknown *PtrBase = dyn_cast<SCEVUnknown>(V);
+  if (!PtrBase)
+    return nullptr;
+  Value *Ptr = PtrBase->getValue();
+  uint64_t Size = 0;
+  if (!llvm::getObjectSize(Ptr, Size, DL, &TLI))
+    return nullptr;
+  return SE->getConstant(RTy, Size);
+}
+
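+// For instance (illustrative): if V is the base pointer of `int a[100]`, then
+// getObjectSize reports 400 bytes and this helper returns the constant 400.
+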
+/// Return the number of steps \p AddRec needs to traverse its unsigned range.
+static const SCEV *getIndexRange(const SCEVAddRecExpr *AddRec,
+                                 ScalarEvolution *SE) {
+  const SCEV *Range = SE->getConstant(SE->getUnsignedRangeMax(AddRec) -
+                                      SE->getUnsignedRangeMin(AddRec));
+  const SCEV *Step = AddRec->getStepRecurrence(*SE);
+  return SE->getUDivCeilSCEV(Range, Step);
+}
+
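+// E.g. (illustrative): an index {0,+,2} whose unsigned range is [0, 199]
+// yields ceil((199 - 0) / 2) = 100 steps.
+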
+/// Get the underlying SCEVAddRecExpr from a cast expression if possible.
+static const SCEV *peelCastExpr(const SCEVCastExpr *S, ScalarEvolution *SE) {
+  const SCEV *Op = S->getOperand();
+  if (isa<SCEVCouldNotCompute>(Op))
+    return Op;
+  // The caller can only make use of an add recurrence, so peel casts until we
+  // reach one.
+  if (isa<SCEVAddRecExpr>(Op))
+    return Op;
+  if (isa<SCEVCastExpr>(Op))
+    return peelCastExpr(cast<SCEVCastExpr>(Op), SE);
+  return SE->getCouldNotCompute();
+}
+
+/// Check whether the indices of \p Ptr can wrap; if their ranges are known,
+/// combine the range-based bound with \p MaxExecCount, the max trip count
+/// inferred from the memory access itself.
+static const SCEV *checkIndexRange(Value *Ptr, ScalarEvolution *SE,
+                                   const SCEVConstant *MaxExecCount) {
+  SmallVector<const SCEV *> InferCountColl;
+  auto *PtrGEP = dyn_cast<GetElementPtrInst>(Ptr);
+  if (!PtrGEP)
+    return SE->getCouldNotCompute();
+  for (Value *Index : PtrGEP->indices()) {
+    Value *V = Index;
+    if (isa<ZExtInst>(V) || isa<SExtInst>(V))
+      V = cast<Instruction>(Index)->getOperand(0);
+    const SCEV *IndexSCEV = SE->getSCEV(V);
+    if (isa<SCEVCouldNotCompute>(IndexSCEV))
+      return SE->getCouldNotCompute();
+    if (isa<SCEVConstant>(IndexSCEV))
+      continue;
+    if (isa<SCEVCastExpr>(IndexSCEV)) {
+      IndexSCEV = peelCastExpr(cast<SCEVCastExpr>(IndexSCEV), SE);
+      if (isa<SCEVCouldNotCompute>(IndexSCEV))
+        return SE->getCouldNotCompute();
+    }
+    auto *AddRec = dyn_cast<SCEVAddRecExpr>(IndexSCEV);
+    if (!AddRec)
+      return SE->getCouldNotCompute();
+    const SCEV *IndexRange = getIndexRange(AddRec, SE);
+    if (!isa<SCEVConstant>(IndexRange))
+      return SE->getCouldNotCompute();
+    // Only the first non-constant index is analyzed.
+    InferCountColl.push_back(IndexRange);
+    break;
+  }
+
+  if (InferCountColl.empty())
+    return SE->getCouldNotCompute();
+
+  InferCountColl.push_back(MaxExecCount);
+
+  return SE->getUMinFromMismatchedTypes(InferCountColl);
+}
+
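+// A sketch of the wrap check (illustrative): for a GEP index `zext i8 %iv to
+// i64` with %iv = {0,+,1}, the i8 index spans at most 255 steps before it
+// would wrap, so the result is umin(255, MaxExecCount) rather than
+// MaxExecCount alone.
+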
+ScalarEvolution::ExitLimit
+ScalarEvolution::computeExitLimitFromMemAccessImpl(const Loop *L) {
+  SmallVector<Instruction *, 4> MemInsts;
+  collectExecLoadStoreInsideLoop(L, DT, MemInsts);
+
+  SmallVector<const SCEV *> InferCountColl;
+  const DataLayout &DL = getDataLayout();
+
+  for (Instruction *I : MemInsts) {
+    Value *Ptr = getLoadStorePointerOperand(I);
+    assert(Ptr && "load/store must have a pointer operand");
+    auto *AddRec = dyn_cast<SCEVAddRecExpr>(getSCEV(Ptr));
+    if (!AddRec || !AddRec->isAffine())
+      continue;
+    const SCEV *PtrBase = getPointerBase(AddRec);
+    const SCEV *Step = AddRec->getStepRecurrence(*this);
+    // The division below treats the step as unsigned, so the inference is
+    // only valid for pointers that advance by a known positive amount.
+    if (!isKnownPositive(Step))
+      continue;
+    const SCEV *MemSize =
+        getCertainSizeOfMem(PtrBase, Step->getType(), DL, TLI, this);
+    if (!MemSize)
+      continue;
+    // Now we can infer a max execution count: ceil(MemSize / Step).
+    auto *MaxExecCount = dyn_cast<SCEVConstant>(getUDivCeilSCEV(MemSize, Step));
+    if (!MaxExecCount || MaxExecCount->getAPInt().getActiveBits() > 32)
+      continue;
+    auto *Res = checkIndexRange(Ptr, this, MaxExecCount);
+    if (isa<SCEVCouldNotCompute>(Res))
+      continue;
+    InferCountColl.push_back(Res);
+  }
+
+  if (InferCountColl.empty())
+    return getCouldNotCompute();
+
+  const SCEV *Count = getUMinFromMismatchedTypes(InferCountColl);
+
+  return {getCouldNotCompute(), Count, Count, /*MaxOrZero=*/false};
+}
+
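+// End-to-end sketch of the inference above (illustrative): given
+//
+//   int a[100];
+//   for (unsigned i = 0; i != n; ++i)
+//     a[i] = i;
+//
+// the store pointer is {a,+,4}, the object size is 400 bytes, and any access
+// beyond ceil(400 / 4) = 100 iterations would be out of bounds, which is UB.
+// The loop therefore executes at most 100 iterations, whatever n is.
+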
+ScalarEvolution::ExitLimit
+ScalarEvolution::computeExitLimitFromMemAccessCached(ExitLimitCacheTy &Cache,
+                                                     const Loop *L) {
+  // These values are meaningless for this query, but the cache key requires
+  // them.
+  constexpr Value *ExitCond = nullptr;
+  constexpr bool ExitIfTrue = true;
+  constexpr bool ControlsOnlyExit = true;
+  constexpr bool AllowPredicates = true;
+
+  if (auto MaybeEL = Cache.find(L, ExitCond, ExitIfTrue, ControlsOnlyExit,
+                                AllowPredicates))
+    return *MaybeEL;
+
+  ExitLimit EL = computeExitLimitFromMemAccessImpl(L);
+  Cache.insert(L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates, EL);
+  return EL;
+}
+
 const SCEV *ScalarEvolution::getExitCount(const Loop *L,
                                           const BasicBlock *ExitingBlock,
                                           ExitCountKind Kind) {
@@ -8835,6 +9023,16 @@ ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
   if (!Latch || !DT.dominates(ExitingBlock, Latch))
     return getCouldNotCompute();
 
+  // FIXME: For now we only handle loops whose single exiting block is also
+  // the latch; for such loops it is straightforward to find memory accesses
+  // that are executed on every iteration.
+  const SCEV *PotentiallyBetterConstantMax = getCouldNotCompute();
+  if (UseMemoryAccessUBForBEInference && Latch == L->getExitingBlock()) {
+    assert(Latch == ExitingBlock && "The latch must be the exiting block");
+    auto EL = computeExitLimitFromMemAccess(L);
+    PotentiallyBetterConstantMax = EL.ConstantMaxNotTaken;
+  }
+
   bool IsOnlyExit = (L->getExitingBlock() != nullptr);
   Instruction *Term = ExitingBlock->getTerminator();
   if (BranchInst *BI = dyn_cast<BranchInst>(Term)) {
@@ -8843,9 +9041,14 @@ ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
     assert(ExitIfTrue == L->contains(BI->getSuccessor(1)) &&
            "It should have one successor in loop and one exit block!");
     // Proceed to the next level to examine the exit condition expression.
-    return computeExitLimitFromCond(L, BI->getCondition(), ExitIfTrue,
-                                    /*ControlsOnlyExit=*/IsOnlyExit,
-                                    AllowPredicates);
+    ExitLimit EL = computeExitLimitFromCond(L, BI->getCondition(), ExitIfTrue,
+                                            /*ControlsOnlyExit=*/IsOnlyExit,
+                                            AllowPredicates);
+    // Take the tighter of the condition-based and memory-access-based bounds.
+    if (!isa<SCEVCouldNotCompute>(EL.ConstantMaxNotTaken) &&
+        !isa<SCEVCouldNotCompute>(PotentiallyBetterConstantMax))
+      EL.ConstantMaxNotTaken = getUMinFromMismatchedTypes(
+          EL.ConstantMaxNotTaken, PotentiallyBetterConstantMax);
+    return EL;
   }
 
   if (SwitchInst *SI = dyn_cast<SwitchInst>(Term)) {
@@ -8858,9 +9061,14 @@ ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
         Exit = SBB;
       }
     assert(Exit && "Exiting block must have at least one exit");
-    return computeExitLimitFromSingleExitSwitch(
-        L, SI, Exit,
-        /*ControlsOnlyExit=*/IsOnlyExit);
+    ExitLimit EL =
+        computeExitLimitFromSingleExitSwitch(L, SI, Exit,
+                                             /*ControlsOnlyExit=*/IsOnlyExit);
+    if (!isa<SCEVCouldNotCompute>(EL.ConstantMaxNotTaken) &&
+        !isa<SCEVCouldNotCompute>(PotentiallyBetterConstantMax))
+      EL.ConstantMaxNotTaken = getUMinFromMismatchedTypes(
+          EL.ConstantMaxNotTaken, PotentiallyBetterConstantMax);
+    return EL;
   }
 
   return getCouldNotCompute();
@@ -8874,6 +9082,13 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCond(
                                         ControlsOnlyExit, AllowPredicates);
 }
 
+ScalarEvolution::ExitLimit
+ScalarEvolution::computeExitLimitFromMemAccess(const Loop *L) {
+  ScalarEvolution::ExitLimitCacheTy Cache(L, /*ExitIfTrue=*/true,
+                                          /*AllowPredicates=*/true);
+  return computeExitLimitFromMemAccessCached(Cache, L);
+}
+
 std::optional<ScalarEvolution::ExitLimit>
 ScalarEvolution::ExitLimitCache::find(const Loop *L, Value *ExitCond,
                                       bool ExitIfTrue, bool ControlsOnlyExit,
@@ -13544,6 +13759,17 @@ static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE,
     OS << ": ";
     OS << "Trip multiple is " << SE->getSmallConstantTripMultiple(L) << "\n";
   }
+
+  if (UseMemoryAccessUBForBEInference) {
+    unsigned SmallMaxTrip = SE->getSmallConstantMaxTripCount(L);
+    OS << "Loop ";
+    L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
+    OS << ": ";
+    if (SmallMaxTrip)
+      OS << "Small constant max trip is " << SmallMaxTrip << "\n";
+    else
+      OS << "Small constant max trip couldn't be computed.\n";
+  }
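+  // With the flag enabled, the block above prints lines such as
+  // (illustrative output for a hypothetical loop):
+  //   Loop %for.body: Small constant max trip is 100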
 }
 
 namespace llvm {