@@ -249,6 +249,10 @@ static cl::opt<bool> UseContextForNoWrapFlagInference(
     cl::desc("Infer nuw/nsw flags using context where suitable"),
     cl::init(true));
 
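+// When enabled, SCEV may tighten a loop's constant max backedge-taken count
+// using the fact that an out-of-bounds access to a fixed-size object would be
+// UB. For example, a loop that stores to `int a[128]` with unit stride on
+// every iteration can run at most 128 times.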
+static cl::opt<bool> UseMemoryAccessUBForBEInference(
+    "scalar-evolution-infer-max-trip-count-from-memory-access", cl::Hidden,
+    cl::desc("Infer loop max trip count from memory access"), cl::init(false));
+
 //===----------------------------------------------------------------------===//
 // SCEV class definitions
 //===----------------------------------------------------------------------===//
@@ -8191,6 +8195,189 @@ ScalarEvolution::getSmallConstantTripMultiple(const Loop *L,
   return getSmallConstantTripMultiple(L, ExitCount);
 }
 
+/// Collect all load/store instructions that must be executed on every
+/// iteration of loop \p L.
+static void
+collectExecLoadStoreInsideLoop(const Loop *L, DominatorTree &DT,
+                               SmallVector<Instruction *, 4> &MemInsts) {
+  // It is difficult to tell whether a load/store executes on every iteration
+  // of an irregular loop, so only handle innermost loops in simplified form.
+  if (!L->isLoopSimplifyForm() || !L->isInnermost())
+    return;
+
+  const BasicBlock *LoopLatch = L->getLoopLatch();
+  assert(LoopLatch && "normal form loop doesn't have a latch");
+  assert(L->getExitingBlock() == LoopLatch &&
+         "expected the latch to be the only exiting block");
+
+  // Do not infer a bound when a sanitizer is enabled: the program may contain
+  // out-of-bounds accesses that the sanitizer is meant to diagnose.
+  const Function *F = LoopLatch->getParent();
+  if (F->hasFnAttribute(Attribute::SanitizeAddress) ||
+      F->hasFnAttribute(Attribute::SanitizeThread) ||
+      F->hasFnAttribute(Attribute::SanitizeMemory) ||
+      F->hasFnAttribute(Attribute::SanitizeHWAddress) ||
+      F->hasFnAttribute(Attribute::SanitizeMemTag))
+    return;
+
+  for (auto *BB : L->getBlocks()) {
+    // Only consider blocks that dominate the latch, so that the number of
+    // times a memory access in the block executes matches the number of
+    // times the latch executes. MemAccessBB below must be skipped:
+    //             Entry
+    //               │
+    //         ┌─────▼─────┐
+    //         │Loop Header◄─────┐
+    //         └──┬──────┬─┘     │
+    //            │      │       │
+    //   ┌────────▼──┐ ┌─▼─────┐ │
+    //   │MemAccessBB│ │OtherBB│ │
+    //   └────────┬──┘ └─┬─────┘ │
+    //            │      │       │
+    //          ┌─▼──────▼─┐     │
+    //          │Loop Latch├─────┘
+    //          └────┬─────┘
+    //               ▼
+    //              Exit
+    if (!DT.dominates(BB, LoopLatch))
+      continue;
+
+    for (Instruction &I : *BB) {
+      if (isa<LoadInst>(&I) || isa<StoreInst>(&I))
+        MemInsts.push_back(&I);
+    }
+  }
+}
+
+/// Return a SCEV representing the memory size of pointer \p V, or nullptr if
+/// the object size cannot be determined.
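+/// For example, if \p V is the base of `int a[128]` (4-byte int), the result
+/// is the constant 512, i.e. the object size in bytes.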
+static const SCEV *getCertainSizeOfMem(const SCEV *V, Type *RTy,
+                                       const DataLayout &DL,
+                                       const TargetLibraryInfo &TLI,
+                                       ScalarEvolution *SE) {
+  const SCEVUnknown *PtrBase = dyn_cast<SCEVUnknown>(V);
+  if (!PtrBase)
+    return nullptr;
+  Value *Ptr = PtrBase->getValue();
+  uint64_t Size = 0;
+  if (!llvm::getObjectSize(Ptr, Size, DL, &TLI))
+    return nullptr;
+  return SE->getConstant(RTy, Size);
+}
+
+/// Return the number of steps needed to walk the unsigned range of the index
+/// \p AddRec, i.e. ceil((umax - umin) / step).
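+/// For example, for an index {0,+,1} with unsigned range [0,128) this yields
+/// ceil(127 / 1) = 127.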
+static const SCEV *getIndexRange(const SCEVAddRecExpr *AddRec,
+                                 ScalarEvolution *SE) {
+  const SCEV *Range = SE->getConstant(SE->getUnsignedRangeMax(AddRec) -
+                                      SE->getUnsignedRangeMin(AddRec));
+  const SCEV *Step = AddRec->getStepRecurrence(*SE);
+  return SE->getUDivCeilSCEV(Range, Step);
+}
+
+/// Get the underlying SCEVAddRecExpr from a cast expression if possible.
+static const SCEV *peelCastExpr(const SCEVCastExpr *S, ScalarEvolution *SE) {
+  const SCEV *Op = S->getOperand();
+  if (isa<SCEVCouldNotCompute>(Op))
+    return Op;
+  if (isa<SCEVAddRecExpr>(Op))
+    return Op;
+  if (isa<SCEVCastExpr>(Op))
+    return peelCastExpr(cast<SCEVCastExpr>(Op), SE);
+  return SE->getCouldNotCompute();
+}
+
+/// Check whether the GEP indices of \p Ptr can wrap, and return the unsigned
+/// minimum of the index-range bound and the max trip count \p MaxExecCount
+/// inferred from the memory access, or CouldNotCompute if no bound is found.
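+/// E.g. for a hypothetical `gep [128 x i32], ptr %a, i64 0, i64 %iv`, the
+/// index %iv is analyzed; if its unsigned range is [0,100), the result is
+/// umin(99, MaxExecCount).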
+static const SCEV *checkIndexRange(Value *Ptr, ScalarEvolution *SE,
+                                   const SCEVConstant *MaxExecCount) {
+  SmallVector<const SCEV *> InferCountColl;
+  auto *PtrGEP = dyn_cast<GetElementPtrInst>(Ptr);
+  if (!PtrGEP)
+    return SE->getCouldNotCompute();
+  for (Value *Index : PtrGEP->indices()) {
+    Value *V = Index;
+    if (isa<ZExtInst>(V) || isa<SExtInst>(V))
+      V = cast<Instruction>(Index)->getOperand(0);
+    const SCEV *IdxSCEV = SE->getSCEV(V);
+    if (isa<SCEVCouldNotCompute>(IdxSCEV))
+      return SE->getCouldNotCompute();
+    if (isa<SCEVConstant>(IdxSCEV))
+      continue;
+    if (isa<SCEVCastExpr>(IdxSCEV)) {
+      IdxSCEV = peelCastExpr(cast<SCEVCastExpr>(IdxSCEV), SE);
+      if (isa<SCEVCouldNotCompute>(IdxSCEV))
+        return SE->getCouldNotCompute();
+    }
+    // The index must be an add recurrence; anything else (e.g. a
+    // loop-invariant value) defeats the analysis.
+    auto *AddRec = dyn_cast<SCEVAddRecExpr>(IdxSCEV);
+    if (!AddRec)
+      return SE->getCouldNotCompute();
+    const SCEV *IndexRange = getIndexRange(AddRec, SE);
+    auto *IndexRangeC = dyn_cast<SCEVConstant>(IndexRange);
+    if (!IndexRangeC)
+      return SE->getCouldNotCompute();
+    // Only the first non-constant index is analyzed.
+    InferCountColl.push_back(IndexRange);
+    break;
+  }
+
+  if (InferCountColl.empty())
+    return SE->getCouldNotCompute();
+
+  InferCountColl.push_back(MaxExecCount);
+
+  return SE->getUMinFromMismatchedTypes(InferCountColl);
+}
+
+ScalarEvolution::ExitLimit
+ScalarEvolution::computeExitLimitFromMemAccessImpl(const Loop *L) {
+  SmallVector<Instruction *, 4> MemInsts;
+  collectExecLoadStoreInsideLoop(L, DT, MemInsts);
+
+  SmallVector<const SCEV *> InferCountColl;
+  const DataLayout &DL = getDataLayout();
+
+  for (Instruction *I : MemInsts) {
+    Value *Ptr = getLoadStorePointerOperand(I);
+    assert(Ptr && "empty pointer operand");
+    auto *AddRec = dyn_cast<SCEVAddRecExpr>(getSCEV(Ptr));
+    if (!AddRec || !AddRec->isAffine())
+      continue;
+    const SCEV *PtrBase = getPointerBase(AddRec);
+    const SCEV *Step = AddRec->getStepRecurrence(*this);
+    const SCEV *MemSize =
+        getCertainSizeOfMem(PtrBase, Step->getType(), DL, TLI, this);
+    if (!MemSize)
+      continue;
+    // Now we can infer a max execution count of ceil(MemSize / Step).
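+    // E.g. a 512-byte object accessed with a 4-byte stride stays in bounds
+    // for at most ceil(512 / 4) = 128 accesses; any further access would be
+    // out of bounds and therefore UB.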
+    auto *MaxExecCount =
+        dyn_cast<SCEVConstant>(getUDivCeilSCEV(MemSize, Step));
+    if (!MaxExecCount || MaxExecCount->getAPInt().getActiveBits() > 32)
+      continue;
+    auto *Res = checkIndexRange(Ptr, this, MaxExecCount);
+    if (isa<SCEVCouldNotCompute>(Res))
+      continue;
+    InferCountColl.push_back(Res);
+  }
+
+  if (InferCountColl.empty())
+    return getCouldNotCompute();
+
+  const SCEV *Count = getUMinFromMismatchedTypes(InferCountColl);
+
+  return {getCouldNotCompute(), Count, Count, /*MaxOrZero=*/false};
+}
+
+ScalarEvolution::ExitLimit
+ScalarEvolution::computeExitLimitFromMemAccessCached(ExitLimitCacheTy &Cache,
+                                                     const Loop *L) {
+  // These values are meaningless for a memory-access-based exit limit, but
+  // the cache key requires them.
+  constexpr Value *ExitCond = nullptr;
+  constexpr bool ExitIfTrue = true;
+  constexpr bool ControlsOnlyExit = true;
+  constexpr bool AllowPredicates = true;
+
+  if (auto MaybeEL = Cache.find(L, ExitCond, ExitIfTrue, ControlsOnlyExit,
+                                AllowPredicates))
+    return *MaybeEL;
+
+  ExitLimit EL = computeExitLimitFromMemAccessImpl(L);
+  Cache.insert(L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates, EL);
+  return EL;
+}
+
 const SCEV *ScalarEvolution::getExitCount(const Loop *L,
                                           const BasicBlock *ExitingBlock,
                                           ExitCountKind Kind) {
@@ -8773,6 +8960,16 @@ ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
   if (!Latch || !DT.dominates(ExitingBlock, Latch))
     return getCouldNotCompute();
 
+  // FIXME: For now, only handle loops with a single exiting block that is
+  // also the latch; for such loops it is easy to prove that a memory access
+  // executes on every iteration.
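+  // If the exit condition only proves, say, a constant max of 2^32 - 1, but a
+  // store to a 128-element array executes on every iteration, the constant
+  // max computed below is tightened to umin(2^32 - 1, 128) = 128.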
+  const SCEV *PotentiallyBetterConstantMax = getCouldNotCompute();
+  if (UseMemoryAccessUBForBEInference && Latch == L->getExitingBlock()) {
+    assert(Latch == ExitingBlock && "exiting block must be the latch");
+    ExitLimit EL = computeExitLimitFromMemAccess(L);
+    PotentiallyBetterConstantMax = EL.ConstantMaxNotTaken;
+  }
+
   bool IsOnlyExit = (L->getExitingBlock() != nullptr);
   Instruction *Term = ExitingBlock->getTerminator();
   if (BranchInst *BI = dyn_cast<BranchInst>(Term)) {
@@ -8781,9 +8978,13 @@ ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
     assert(ExitIfTrue == L->contains(BI->getSuccessor(1)) &&
            "It should have one successor in loop and one exit block!");
     // Proceed to the next level to examine the exit condition expression.
-    return computeExitLimitFromCond(L, BI->getCondition(), ExitIfTrue,
-                                    /*ControlsOnlyExit=*/IsOnlyExit,
-                                    AllowPredicates);
+    ExitLimit EL = computeExitLimitFromCond(L, BI->getCondition(), ExitIfTrue,
+                                            /*ControlsOnlyExit=*/IsOnlyExit,
+                                            AllowPredicates);
+    if (!isa<SCEVCouldNotCompute>(PotentiallyBetterConstantMax))
+      EL.ConstantMaxNotTaken = getUMinFromMismatchedTypes(
+          EL.ConstantMaxNotTaken, PotentiallyBetterConstantMax);
+    return EL;
   }
 
   if (SwitchInst *SI = dyn_cast<SwitchInst>(Term)) {
@@ -8796,9 +8997,13 @@ ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
         Exit = SBB;
       }
     assert(Exit && "Exiting block must have at least one exit");
-    return computeExitLimitFromSingleExitSwitch(
-        L, SI, Exit,
-        /*ControlsOnlyExit=*/IsOnlyExit);
+    ExitLimit EL =
+        computeExitLimitFromSingleExitSwitch(L, SI, Exit,
+                                             /*ControlsOnlyExit=*/IsOnlyExit);
+    if (!isa<SCEVCouldNotCompute>(PotentiallyBetterConstantMax))
+      EL.ConstantMaxNotTaken = getUMinFromMismatchedTypes(
+          EL.ConstantMaxNotTaken, PotentiallyBetterConstantMax);
+    return EL;
   }
 
   return getCouldNotCompute();
@@ -8812,6 +9017,13 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCond(
                                         ControlsOnlyExit, AllowPredicates);
 }
 
+ScalarEvolution::ExitLimit
+ScalarEvolution::computeExitLimitFromMemAccess(const Loop *L) {
+  ScalarEvolution::ExitLimitCacheTy Cache(L, /*ExitIfTrue=*/true,
+                                          /*AllowPredicates=*/true);
+  return computeExitLimitFromMemAccessCached(Cache, L);
+}
+
 std::optional<ScalarEvolution::ExitLimit>
 ScalarEvolution::ExitLimitCache::find(const Loop *L, Value *ExitCond,
                                       bool ExitIfTrue, bool ControlsOnlyExit,
@@ -13482,6 +13694,17 @@ static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE,
     OS << ": ";
     OS << "Trip multiple is " << SE->getSmallConstantTripMultiple(L) << "\n";
   }
+
+  if (UseMemoryAccessUBForBEInference) {
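+    // Prints, e.g., "Loop %for.body: Small constant max trip is 128".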
+    unsigned SmallMaxTrip = SE->getSmallConstantMaxTripCount(L);
+    OS << "Loop ";
+    L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
+    OS << ": ";
+    if (SmallMaxTrip)
+      OS << "Small constant max trip is " << SmallMaxTrip << "\n";
+    else
+      OS << "Small constant max trip couldn't be computed.\n";
+  }
 }
 
 namespace llvm {