@@ -249,6 +249,10 @@ static cl::opt<bool> UseContextForNoWrapFlagInference(
     cl::desc("Infer nuw/nsw flags using context where suitable"),
     cl::init(true));
 
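+// Illustrative usage only (a sketch, assuming a typical opt invocation):
+//   opt -passes='print<scalar-evolution>' -disable-output \
+//       -scalar-evolution-infer-max-trip-count-from-memory-access input.ll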
+static cl::opt<bool> UseMemoryAccessUBForBEInference(
+    "scalar-evolution-infer-max-trip-count-from-memory-access", cl::Hidden,
+    cl::desc("Infer loop max trip count from memory access"), cl::init(false));
+
 //===----------------------------------------------------------------------===//
 // SCEV class definitions
 //===----------------------------------------------------------------------===//
@@ -8253,6 +8257,194 @@ ScalarEvolution::getSmallConstantTripMultiple(const Loop *L,
   return getSmallConstantTripMultiple(L, ExitCount);
 }
 
+/// Collect all load/store instructions that must be executed in every
+/// iteration of loop \p L.
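+///
+/// For example (illustrative):
+///   for (i = 0; i != n; ++i) {
+///     a[i] = 0;           // collected: executes on every iteration
+///     if (c[i]) b[i] = 0; // skipped: only conditionally executed
+///   }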
+static void
+collectExecLoadStoreInsideLoop(const Loop *L, DominatorTree &DT,
+                               SmallVector<Instruction *, 4> &MemInsts) {
+  // It is hard to tell whether a load/store is executed on every iteration
+  // of an irregular loop, so only analyze innermost loops in simplify form.
+  if (!L->isLoopSimplifyForm() || !L->isInnermost())
+    return;
+
+  const BasicBlock *LoopLatch = L->getLoopLatch();
+  assert(LoopLatch && "normal form loop doesn't have a latch");
+  assert(L->getExitingBlock() == LoopLatch &&
+         "expected the latch to be the only exiting block");
+  // Bail out if a sanitizer is enabled: sanitized builds exist precisely to
+  // catch out-of-bounds accesses, so we must not optimize on the assumption
+  // that such accesses cannot happen.
+  const Function *F = LoopLatch->getParent();
+  if (F->hasFnAttribute(Attribute::SanitizeAddress) ||
+      F->hasFnAttribute(Attribute::SanitizeThread) ||
+      F->hasFnAttribute(Attribute::SanitizeMemory) ||
+      F->hasFnAttribute(Attribute::SanitizeHWAddress) ||
+      F->hasFnAttribute(Attribute::SanitizeMemTag))
+    return;
+
+  for (auto *BB : L->getBlocks()) {
+    // The max execution count of a memory-access block must match the max
+    // execution count of the latch, so skip blocks that do not dominate the
+    // latch, such as MemAccessBB below:
+    //              Entry
+    //                │
+    //         ┌─────▼─────┐
+    //         │Loop Header◄─────┐
+    //         └──┬──────┬─┘     │
+    //            │      │       │
+    //   ┌────────▼──┐ ┌─▼─────┐ │
+    //   │MemAccessBB│ │OtherBB│ │
+    //   └────────┬──┘ └─┬─────┘ │
+    //            │      │       │
+    //          ┌─▼──────▼─┐     │
+    //          │Loop Latch├─────┘
+    //          └────┬─────┘
+    //               ▼
+    //              Exit
+    if (!DT.dominates(BB, LoopLatch))
+      continue;
+
+    for (Instruction &I : *BB) {
+      if (isa<LoadInst>(&I) || isa<StoreInst>(&I))
+        MemInsts.push_back(&I);
+    }
+  }
+}
+
+/// Return a SCEV describing the size in bytes of the object that pointer
+/// \p V is based on, or nullptr if the size cannot be determined.
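+/// For example (illustrative), for a pointer based on a global
+/// `int A[100]`, this returns a constant 400 on a target with 4-byte ints.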
+static const SCEV *getCertainSizeOfMem(const SCEV *V, Type *RTy,
+                                       const DataLayout &DL,
+                                       const TargetLibraryInfo &TLI,
+                                       ScalarEvolution *SE) {
+  const SCEVUnknown *PtrBase = dyn_cast<SCEVUnknown>(V);
+  if (!PtrBase)
+    return nullptr;
+  Value *Ptr = PtrBase->getValue();
+  uint64_t Size = 0;
+  if (!llvm::getObjectSize(Ptr, Size, DL, &TLI))
+    return nullptr;
+  return SE->getConstant(RTy, Size);
+}
+
+/// Return the number of steps needed for the index \p AddRec to cover its
+/// unsigned range, i.e. ceil((umax - umin) / |step|).
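+/// For example (illustrative), for {0,+,4} with unsigned range [0, 396] this
+/// returns ceil(396 / 4) = 99.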
+static const SCEV *getIndexRange(const SCEVAddRecExpr *AddRec,
+                                 ScalarEvolution *SE) {
+  const SCEV *Range = SE->getConstant(SE->getUnsignedRangeMax(AddRec) -
+                                      SE->getUnsignedRangeMin(AddRec));
+  const SCEV *Step = AddRec->getStepRecurrence(*SE);
+  if (SE->isKnownNegative(Step))
+    Step = SE->getNegativeSCEV(Step);
+  return SE->getUDivCeilSCEV(Range, Step);
+}
+
+/// Get the underlying SCEVAddExpr from a cast expression if possible.
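+/// E.g. (illustrative), for (zext (sext (%a + %b))) this returns (%a + %b);
+/// if no add expression is found beneath the casts, SCEVCouldNotCompute is
+/// returned.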
+static const SCEV *peelCastExpr(const SCEVCastExpr *S, ScalarEvolution *SE) {
+  const SCEV *Op = S->getOperand();
+  if (isa<SCEVCouldNotCompute>(Op))
+    return Op;
+  if (isa<SCEVAddExpr>(Op))
+    return Op;
+  if (isa<SCEVCastExpr>(Op))
+    return peelCastExpr(cast<SCEVCastExpr>(Op), SE);
+  return SE->getCouldNotCompute();
+}
+
+/// Check whether the index used by \p Ptr can wrap, and return the umin of
+/// \p MaxExecCount (the max trip count inferred from memory size) and the
+/// bound implied by the index's range, or SCEVCouldNotCompute on failure.
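+/// For example (illustrative), with an i8 index the address pattern repeats
+/// after at most 256 steps, so the memory-size bound alone may not be the
+/// tightest constraint; the result is therefore the umin of both bounds.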
+static const SCEV *checkIndexRange(Value *Ptr, ScalarEvolution *SE,
+                                   const SCEVConstant *MaxExecCount) {
+  SmallVector<const SCEV *> InferCountColl;
+  auto *PtrGEP = dyn_cast<GetElementPtrInst>(Ptr);
+  if (!PtrGEP)
+    return SE->getCouldNotCompute();
+  for (Value *Index : PtrGEP->indices()) {
+    Value *V = Index;
+    // Look through a zext/sext of the index.
+    if (isa<ZExtInst>(V) || isa<SExtInst>(V))
+      V = cast<Instruction>(Index)->getOperand(0);
+    const SCEV *IndexSCEV = SE->getSCEV(V);
+    if (isa<SCEVCouldNotCompute>(IndexSCEV))
+      return SE->getCouldNotCompute();
+    // Constant indices do not limit the trip count.
+    if (isa<SCEVConstant>(IndexSCEV))
+      continue;
+    if (isa<SCEVCastExpr>(IndexSCEV)) {
+      IndexSCEV = peelCastExpr(cast<SCEVCastExpr>(IndexSCEV), SE);
+      if (isa<SCEVCouldNotCompute>(IndexSCEV))
+        return SE->getCouldNotCompute();
+    }
+    auto *AddRec = dyn_cast<SCEVAddRecExpr>(IndexSCEV);
+    if (!AddRec)
+      return SE->getCouldNotCompute();
+    const SCEV *IndexRange = getIndexRange(AddRec, SE);
+    if (!isa<SCEVConstant>(IndexRange))
+      return SE->getCouldNotCompute();
+    // Only the first non-constant index is considered.
+    InferCountColl.push_back(IndexRange);
+    break;
+  }
+
+  if (InferCountColl.empty())
+    return SE->getCouldNotCompute();
+
+  InferCountColl.push_back(MaxExecCount);
+
+  return SE->getUMinFromMismatchedTypes(InferCountColl);
+}
+
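+/// Infer a constant max trip count from memory accesses that execute on
+/// every iteration: an affine access that strides through an object of known
+/// size would be UB once it left the object, so the loop can run at most
+/// ceil(ObjectSize / Stride) times. Illustrative sketch:
+///   char a[16];
+///   for (i = 0; i != n; ++i) a[i] = 0; // constant max trip count <= 16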
+ScalarEvolution::ExitLimit
+ScalarEvolution::computeExitLimitFromMemAccessImpl(const Loop *L) {
+  SmallVector<Instruction *, 4> MemInsts;
+  collectExecLoadStoreInsideLoop(L, DT, MemInsts);
+
+  SmallVector<const SCEV *> InferCountColl;
+  const DataLayout &DL = getDataLayout();
+
+  for (Instruction *I : MemInsts) {
+    Value *Ptr = getLoadStorePointerOperand(I);
+    assert(Ptr && "empty pointer operand");
+    auto *AddRec = dyn_cast<SCEVAddRecExpr>(getSCEV(Ptr));
+    if (!AddRec || !AddRec->isAffine())
+      continue;
+    const SCEV *PtrBase = getPointerBase(AddRec);
+    const SCEV *Step = AddRec->getStepRecurrence(*this);
+    const SCEV *MemSize =
+        getCertainSizeOfMem(PtrBase, Step->getType(), DL, TLI, this);
+    if (!MemSize)
+      continue;
+    if (isKnownNegative(Step))
+      Step = getNegativeSCEV(Step);
+    // Now we can infer a max execution count: ceil(MemSize / Step). Only
+    // accept bounds that fit in 32 bits.
+    auto *MaxExecCount =
+        dyn_cast<SCEVConstant>(getUDivCeilSCEV(MemSize, Step));
+    if (!MaxExecCount || MaxExecCount->getAPInt().getActiveBits() > 32)
+      continue;
+    // Discard the bound if the index may wrap before the bound is reached.
+    auto *Res = checkIndexRange(Ptr, this, MaxExecCount);
+    if (isa<SCEVCouldNotCompute>(Res))
+      continue;
+    InferCountColl.push_back(Res);
+  }
+
+  if (InferCountColl.empty())
+    return getCouldNotCompute();
+
+  const SCEV *Count = getUMinFromMismatchedTypes(InferCountColl);
+
+  return {getCouldNotCompute(), Count, Count, /*MaxOrZero=*/false};
+}
+
+ScalarEvolution::ExitLimit
+ScalarEvolution::computeExitLimitFromMemAccessCached(ExitLimitCacheTy &Cache,
+                                                     const Loop *L) {
+  // These values are irrelevant for this query, but the cache key requires
+  // them.
+  constexpr Value *ExitCond = nullptr;
+  constexpr bool ExitIfTrue = true;
+  constexpr bool ControlsOnlyExit = true;
+  constexpr bool AllowPredicates = true;
+
+  if (auto MaybeEL = Cache.find(L, ExitCond, ExitIfTrue, ControlsOnlyExit,
+                                AllowPredicates))
+    return *MaybeEL;
+
+  ExitLimit EL = computeExitLimitFromMemAccessImpl(L);
+  Cache.insert(L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates, EL);
+  return EL;
+}
+
 const SCEV *ScalarEvolution::getExitCount(const Loop *L,
                                           const BasicBlock *ExitingBlock,
                                           ExitCountKind Kind) {
@@ -8835,6 +9027,16 @@ ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
   if (!Latch || !DT.dominates(ExitingBlock, Latch))
     return getCouldNotCompute();
 
+  // FIXME: For now we only analyze loops with a single exiting block that is
+  // also the latch; for such loops it is easy to show that the collected
+  // memory accesses are executed in every iteration.
+  const SCEV *PotentiallyBetterConstantMax = getCouldNotCompute();
+  if (UseMemoryAccessUBForBEInference && Latch == L->getExitingBlock()) {
+    assert(Latch == ExitingBlock && "latch must be the sole exiting block");
+    auto EL = computeExitLimitFromMemAccess(L);
+    PotentiallyBetterConstantMax = EL.ConstantMaxNotTaken;
+  }
+
   bool IsOnlyExit = (L->getExitingBlock() != nullptr);
   Instruction *Term = ExitingBlock->getTerminator();
   if (BranchInst *BI = dyn_cast<BranchInst>(Term)) {
@@ -8843,9 +9045,14 @@ ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
     assert(ExitIfTrue == L->contains(BI->getSuccessor(1)) &&
            "It should have one successor in loop and one exit block!");
     // Proceed to the next level to examine the exit condition expression.
-    return computeExitLimitFromCond(L, BI->getCondition(), ExitIfTrue,
-                                    /*ControlsOnlyExit=*/IsOnlyExit,
-                                    AllowPredicates);
+    ExitLimit EL = computeExitLimitFromCond(L, BI->getCondition(), ExitIfTrue,
+                                            /*ControlsOnlyExit=*/IsOnlyExit,
+                                            AllowPredicates);
+    // Tighten the constant max with the bound inferred from memory accesses,
+    // if any.
+    if (!isa<SCEVCouldNotCompute>(EL.ConstantMaxNotTaken) &&
+        !isa<SCEVCouldNotCompute>(PotentiallyBetterConstantMax))
+      EL.ConstantMaxNotTaken = getUMinFromMismatchedTypes(
+          EL.ConstantMaxNotTaken, PotentiallyBetterConstantMax);
+    return EL;
   }
 
   if (SwitchInst *SI = dyn_cast<SwitchInst>(Term)) {
@@ -8858,9 +9065,14 @@ ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
         Exit = SBB;
       }
     assert(Exit && "Exiting block must have at least one exit");
-    return computeExitLimitFromSingleExitSwitch(
-        L, SI, Exit,
-        /*ControlsOnlyExit=*/IsOnlyExit);
+    ExitLimit EL = computeExitLimitFromSingleExitSwitch(
+        L, SI, Exit, /*ControlsOnlyExit=*/IsOnlyExit);
+    // Tighten the constant max with the bound inferred from memory accesses,
+    // if any.
+    if (!isa<SCEVCouldNotCompute>(EL.ConstantMaxNotTaken) &&
+        !isa<SCEVCouldNotCompute>(PotentiallyBetterConstantMax))
+      EL.ConstantMaxNotTaken = getUMinFromMismatchedTypes(
+          EL.ConstantMaxNotTaken, PotentiallyBetterConstantMax);
+    return EL;
   }
 
   return getCouldNotCompute();
@@ -8874,6 +9086,13 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCond(
                                         ControlsOnlyExit, AllowPredicates);
 }
 
+ScalarEvolution::ExitLimit
+ScalarEvolution::computeExitLimitFromMemAccess(const Loop *L) {
+  ScalarEvolution::ExitLimitCacheTy Cache(L, /*ExitIfTrue=*/true,
+                                          /*AllowPredicates=*/true);
+  return computeExitLimitFromMemAccessCached(Cache, L);
+}
+
 std::optional<ScalarEvolution::ExitLimit>
 ScalarEvolution::ExitLimitCache::find(const Loop *L, Value *ExitCond,
                                       bool ExitIfTrue, bool ControlsOnlyExit,
@@ -13544,6 +13763,17 @@ static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE,
     OS << ": ";
     OS << "Trip multiple is " << SE->getSmallConstantTripMultiple(L) << "\n";
   }
+
+  if (UseMemoryAccessUBForBEInference) {
+    unsigned SmallMaxTrip = SE->getSmallConstantMaxTripCount(L);
+    OS << "Loop ";
+    L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
+    OS << ": ";
+    if (SmallMaxTrip)
+      OS << "Small constant max trip is " << SmallMaxTrip << "\n";
+    else
+      OS << "Small constant max trip couldn't be computed.\n";
+  }
 }
 
 namespace llvm {