@@ -249,6 +249,10 @@ static cl::opt<bool> UseContextForNoWrapFlagInference(
249249 cl::desc("Infer nuw/nsw flags using context where suitable"),
250250 cl::init(true));
251251
252+ static cl::opt<bool> UseMemoryAccessUBForBEInference(
253+ "scalar-evolution-infer-max-trip-count-from-memory-access", cl::Hidden,
254+ cl::desc("Infer loop max trip count from memory access"), cl::init(false));
255+
252256//===----------------------------------------------------------------------===//
253257// SCEV class definitions
254258//===----------------------------------------------------------------------===//
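As a rough illustration of what the new option is meant to enable (a hypothetical C++ snippet, not part of this patch): when every iteration of a loop accesses a fixed-size object through the induction variable, executing more iterations than the object can hold would be undefined behavior, so the object size bounds the maximum trip count even when the exit condition itself is opaque.

// Illustrative only: `n` is unknown, but every iteration stores to buf[i].
// An out-of-bounds store would be UB, so with
// -scalar-evolution-infer-max-trip-count-from-memory-access enabled SCEV may
// infer a constant max trip count of 32 (128-byte object / 4-byte stride).
void fill(int n) {
  int buf[32];
  for (int i = 0; i < n; ++i)
    buf[i] = i;
}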
@@ -8260,6 +8264,210 @@ ScalarEvolution::getSmallConstantTripMultiple(const Loop *L,
82608264 return getSmallConstantTripMultiple(L, ExitCount);
82618265}
82628266
8267+ /// Collect all load/store instructions that must be executed in every iteration
8268+ /// of loop \p L.
8269+ static void
8270+ collectExecLoadStoreInsideLoop(const Loop *L, DominatorTree &DT,
8271+ SmallVector<Instruction *, 4> &MemInsts) {
8272+ // It is difficult to tell whether a load/store is executed on every iteration
8273+ // of an irregular loop, so restrict this to innermost loops in simplified form.
8274+ if (!L->isLoopSimplifyForm() || !L->isInnermost())
8275+ return;
8276+
8277+ const BasicBlock *LoopLatch = L->getLoopLatch();
8278+ assert(LoopLatch && "normal form loop doesn't have a latch");
8279+ assert(L->getExitingBlock() == LoopLatch);
8280+
8281+ // Do not continue if a sanitizer is enabled.
8282+ const Function *F = LoopLatch->getParent();
8283+ if (F->hasFnAttribute(Attribute::SanitizeAddress) ||
8284+ F->hasFnAttribute(Attribute::SanitizeThread) ||
8285+ F->hasFnAttribute(Attribute::SanitizeMemory) ||
8286+ F->hasFnAttribute(Attribute::SanitizeHWAddress) ||
8287+ F->hasFnAttribute(Attribute::SanitizeMemTag))
8288+ return;
8289+
8290+ for (auto *BB : L->getBlocks()) {
8291+ // The max execution count of MemAccessBB in the loop must match that of the
8292+ // latch, so BBs that do not dominate it, like MemAccessBB below, are skipped:
8293+ // Entry
8294+ // │
8295+ // ┌─────▼─────┐
8296+ // │Loop Header◄─────┐
8297+ // └──┬──────┬─┘ │
8298+ // │ │ │
8299+ // ┌────────▼──┐ ┌─▼─────┐ │
8300+ // │MemAccessBB│ │OtherBB│ │
8301+ // └────────┬──┘ └─┬─────┘ │
8302+ // │ │ │
8303+ // ┌─▼──────▼─┐ │
8304+ // │Loop Latch├─────┘
8305+ // └────┬─────┘
8306+ // ▼
8307+ // Exit
8308+ if (!DT.dominates(BB, LoopLatch))
8309+ continue;
8310+
8311+ for (Instruction &I : *BB) {
8312+ if (isa<LoadInst>(&I) || isa<StoreInst>(&I))
8313+ MemInsts.push_back(&I);
8314+ }
8315+ }
8316+ }
8317+
8318+ /// Return a SCEV for the size of the memory object pointed to by \p V, if known.
8319+ static const SCEV *getCertainSizeOfMem(const SCEV *V, Type *RTy,
8320+ const DataLayout &DL,
8321+ const TargetLibraryInfo &TLI,
8322+ ScalarEvolution *SE) {
8323+ const SCEVUnknown *PtrBase = dyn_cast<SCEVUnknown>(V);
8324+ if (!PtrBase)
8325+ return nullptr;
8326+ Value *Ptr = PtrBase->getValue();
8327+ uint64_t Size = 0;
8328+ if (!llvm::getObjectSize(Ptr, Size, DL, &TLI))
8329+ return nullptr;
8330+ return SE->getConstant(RTy, Size);
8331+ }
8332+
8333+ /// Get the number of steps needed to span the range of the index \p AddRec.
8334+ static const SCEV *getIndexRange(const SCEVAddRecExpr *AddRec,
8335+ ScalarEvolution *SE) {
8336+ const SCEV *Range = SE->getConstant(SE->getUnsignedRangeMax(AddRec) -
8337+ SE->getUnsignedRangeMin(AddRec));
8338+ const SCEV *Step = AddRec->getStepRecurrence(*SE);
8339+ if (SE->isKnownNegative(Step))
8340+ Step = SE->getNegativeSCEV(Step);
8341+ return SE->getUDivCeilSCEV(Range, Step);
8342+ }
8343+
8344+ /// Get the underlying SCEVAddExpr from a cast expression, if possible.
8345+ static const SCEV *peelCastExpr(const SCEVCastExpr *S, ScalarEvolution *SE) {
8346+ const SCEV *Op = S->getOperand();
8347+ if (isa<SCEVCouldNotCompute>(Op))
8348+ return Op;
8349+ if (isa<SCEVAddExpr>(Op))
8350+ return Op;
8351+ if (isa<SCEVCastExpr>(Op))
8352+ return peelCastExpr(cast<SCEVCastExpr>(Op), SE);
8353+ return SE->getCouldNotCompute();
8354+ }
8355+
8356+ static Value *peelExt(Value *V) {
8357+ if (isa<ZExtInst>(V) || isa<SExtInst>(V))
8358+ return peelExt(cast<Instruction>(V)->getOperand(0));
8359+ return V;
8360+ }
8361+
8362+ static bool isIndexInductionVariable(PHINode *InductionVar, Value *Index) {
8363+ if (InductionVar == Index)
8364+ return true;
8365+ if (peelExt(Index) == InductionVar)
8366+ return true;
8367+ return false;
8368+ }
8369+
8370+ /// Check that the GEP index is based on the loop induction variable, and clamp
8371+ /// the max trip count inferred from memory access by the range of that index.
8372+ static const SCEV *checkIndexRange(Value *Ptr, PHINode *InductionVar,
8373+ ScalarEvolution *SE,
8374+ const SCEVConstant *MaxExecCount) {
8375+ SmallVector<const SCEV *> InferCountColl;
8376+ auto *PtrGEP = dyn_cast<GetElementPtrInst>(Ptr);
8377+ if (!PtrGEP)
8378+ return SE->getCouldNotCompute();
8379+ for (Value *Index : PtrGEP->indices()) {
8380+ Value *V = Index;
8381+ if (!isIndexInductionVariable(InductionVar, Index))
8382+ continue;
8383+ if (isa<ZExtInst>(V) || isa<SExtInst>(V))
8384+ V = cast<Instruction>(Index)->getOperand(0);
8385+ auto *SCEV = SE->getSCEV(V);
8386+ if (isa<SCEVCouldNotCompute>(SCEV))
8387+ return SE->getCouldNotCompute();
8388+ if (isa<SCEVCastExpr>(SCEV)) {
8389+ SCEV = peelCastExpr(cast<SCEVCastExpr>(SCEV), SE);
8390+ if (isa<SCEVCouldNotCompute>(SCEV))
8391+ return SE->getCouldNotCompute();
8392+ }
8393+ auto *AddRec = dyn_cast<SCEVAddRecExpr>(SCEV);
8394+ if (!AddRec)
8395+ return SE->getCouldNotCompute();
8396+ auto *IndexRange = getIndexRange(AddRec, SE);
8397+ auto *IndexRangeC = dyn_cast<SCEVConstant>(IndexRange);
8398+ if (!IndexRangeC)
8399+ return SE->getCouldNotCompute();
8400+ InferCountColl.push_back(IndexRange);
8401+ break;
8402+ }
8403+
8404+ if (InferCountColl.empty())
8405+ return SE->getCouldNotCompute();
8406+
8407+ InferCountColl.push_back(MaxExecCount);
8408+
8409+ return SE->getUMinFromMismatchedTypes(InferCountColl);
8410+ }
8411+
8412+ ScalarEvolution::ExitLimit
8413+ ScalarEvolution::computeExitLimitFromMemAccessImpl(const Loop *L) {
8414+ SmallVector<Instruction *, 4> MemInsts;
8415+ collectExecLoadStoreInsideLoop(L, DT, MemInsts);
8416+
8417+ SmallVector<const SCEV *> InferCountColl;
8418+ const DataLayout &DL = getDataLayout();
8419+
8420+ for (Instruction *I : MemInsts) {
8421+ Value *Ptr = getLoadStorePointerOperand(I);
8422+ assert(Ptr && "empty pointer operand");
8423+ auto *AddRec = dyn_cast<SCEVAddRecExpr>(getSCEV(Ptr));
8424+ if (!AddRec || !AddRec->isAffine())
8425+ continue;
8426+ const SCEV *PtrBase = getPointerBase(AddRec);
8427+ const SCEV *Step = AddRec->getStepRecurrence(*this);
8428+ const SCEV *MemSize =
8429+ getCertainSizeOfMem(PtrBase, Step->getType(), DL, TLI, this);
8430+ if (!MemSize)
8431+ continue;
8432+ if (isKnownNegative(Step))
8433+ Step = getNegativeSCEV(Step);
8434+ // Now we can infer a max execution count: ceil(MemSize / |Step|).
8435+ auto *MaxExecCount = dyn_cast<SCEVConstant>(getUDivCeilSCEV(MemSize, Step));
8436+ if (!MaxExecCount || MaxExecCount->getAPInt().getActiveBits() > 32)
8437+ continue;
8438+ auto *Res = checkIndexRange(Ptr, L->getInductionVariable(*this), this,
8439+ MaxExecCount);
8440+ if (isa<SCEVCouldNotCompute>(Res))
8441+ continue;
8442+ InferCountColl.push_back(Res);
8443+ }
8444+
8445+ if (InferCountColl.empty())
8446+ return getCouldNotCompute();
8447+
8448+ const SCEV *Count = getUMinFromMismatchedTypes(InferCountColl);
8449+
8450+ return {getCouldNotCompute(), Count, Count, /*MaxOrZero=*/false};
8451+ }
8452+
8453+ ScalarEvolution::ExitLimit
8454+ ScalarEvolution::computeExitLimitFromMemAccessCached(ExitLimitCacheTy &Cache,
8455+ const Loop *L) {
8456+ // These arguments are not meaningful here, but the cache key requires them.
8457+ constexpr Value *ExitCond = nullptr;
8458+ constexpr bool ExitIfTrue = true;
8459+ constexpr bool ControlsOnlyExit = true;
8460+ constexpr bool AllowPredicates = true;
8461+
8462+ if (auto MaybeEL = Cache.find(L, ExitCond, ExitIfTrue, ControlsOnlyExit,
8463+ AllowPredicates))
8464+ return *MaybeEL;
8465+
8466+ ExitLimit EL = computeExitLimitFromMemAccessImpl(L);
8467+ Cache.insert(L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates, EL);
8468+ return EL;
8469+ }
8470+
82638471const SCEV *ScalarEvolution::getExitCount(const Loop *L,
82648472 const BasicBlock *ExitingBlock,
82658473 ExitCountKind Kind) {
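A sketch of the wrap hazard that checkIndexRange above guards against (hypothetical snippet, not part of this patch): if the GEP index is narrower than the induction variable, it can wrap and revisit the same slots indefinitely, so the object size alone does not bound the trip count. Because the truncated index is not the induction variable itself (nor a plain sign/zero extension of it), checkIndexRange returns CouldNotCompute and the access contributes no bound.

// The index wraps at 256, cycling through buf[0..255] forever without UB, so
// a bound of sizeof(buf)/4 iterations would be wrong here and is not inferred.
void wraps(long n) {
  int buf[256];
  for (long i = 0; i < n; ++i)
    buf[(unsigned char)i] = 0;
}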
@@ -8842,6 +9050,16 @@ ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
88429050 if (!Latch || !DT.dominates(ExitingBlock, Latch))
88439051 return getCouldNotCompute();
88449052
9053+ // FIXME: For now we only analyze loops with a single exiting block which is
9054+ // also the latch; this is the common case in which memory accesses that
9055+ // dominate the latch are executed on every iteration.
9056+ const SCEV *PotentiallyBetterConstantMax = getCouldNotCompute();
9057+ if (UseMemoryAccessUBForBEInference && Latch == L->getExitingBlock()) {
9058+ assert(Latch == ExitingBlock && "latch must be the only exiting block");
9059+ auto EL = computeExitLimitFromMemAccess(L);
9060+ PotentiallyBetterConstantMax = EL.ConstantMaxNotTaken;
9061+ }
9062+
88459063 bool IsOnlyExit = (L->getExitingBlock() != nullptr);
88469064 Instruction *Term = ExitingBlock->getTerminator();
88479065 if (BranchInst *BI = dyn_cast<BranchInst>(Term)) {
@@ -8850,9 +9068,14 @@ ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
88509068 assert(ExitIfTrue == L->contains(BI->getSuccessor(1)) &&
88519069 "It should have one successor in loop and one exit block!");
88529070 // Proceed to the next level to examine the exit condition expression.
8853- return computeExitLimitFromCond(L, BI->getCondition(), ExitIfTrue,
8854- /*ControlsOnlyExit=*/IsOnlyExit,
8855- AllowPredicates);
9071+ ExitLimit EL = computeExitLimitFromCond(L, BI->getCondition(), ExitIfTrue,
9072+ /*ControlsOnlyExit=*/IsOnlyExit,
9073+ AllowPredicates);
9074+ if (!isa<SCEVCouldNotCompute>(EL.ConstantMaxNotTaken) &&
9075+ !isa<SCEVCouldNotCompute>(PotentiallyBetterConstantMax))
9076+ EL.ConstantMaxNotTaken = getUMinFromMismatchedTypes(
9077+ EL.ConstantMaxNotTaken, PotentiallyBetterConstantMax);
9078+ return EL;
88569079 }
88579080
88589081 if (SwitchInst *SI = dyn_cast<SwitchInst>(Term)) {
@@ -8865,9 +9088,14 @@ ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
88659088 Exit = SBB;
88669089 }
88679090 assert(Exit && "Exiting block must have at least one exit");
8868- return computeExitLimitFromSingleExitSwitch(
8869- L, SI, Exit,
8870- /*ControlsOnlyExit=*/IsOnlyExit);
9091+ ExitLimit EL =
9092+ computeExitLimitFromSingleExitSwitch(L, SI, Exit,
9093+ /*ControlsOnlyExit=*/IsOnlyExit);
9094+ if (!isa<SCEVCouldNotCompute>(EL.ConstantMaxNotTaken) &&
9095+ !isa<SCEVCouldNotCompute>(PotentiallyBetterConstantMax))
9096+ EL.ConstantMaxNotTaken = getUMinFromMismatchedTypes(
9097+ EL.ConstantMaxNotTaken, PotentiallyBetterConstantMax);
9098+ return EL;
88719099 }
88729100
88739101 return getCouldNotCompute();
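As a worked example of the clamping above (hypothetical numbers): if the exit condition only bounds the backedge-taken count by something like 2^31 - 1 because the bound is an unknown signed int, while a memory access executed on every iteration limits it to 32, getUMinFromMismatchedTypes selects 32 as the new ConstantMaxNotTaken; when either side is CouldNotCompute, the original limit is left untouched.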
@@ -8881,6 +9109,13 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCond(
88819109 ControlsOnlyExit, AllowPredicates);
88829110}
88839111
9112+ ScalarEvolution::ExitLimit
9113+ ScalarEvolution::computeExitLimitFromMemAccess(const Loop *L) {
9114+ ScalarEvolution::ExitLimitCacheTy Cache(L, /*ExitIfTrue=*/true,
9115+ /*AllowPredicates=*/true);
9116+ return computeExitLimitFromMemAccessCached(Cache, L);
9117+ }
9118+
88849119std::optional<ScalarEvolution::ExitLimit>
88859120ScalarEvolution::ExitLimitCache::find(const Loop *L, Value *ExitCond,
88869121 bool ExitIfTrue, bool ControlsOnlyExit,
@@ -13579,6 +13814,17 @@ static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE,
1357913814 OS << ": ";
1358013815 OS << "Trip multiple is " << SE->getSmallConstantTripMultiple(L) << "\n";
1358113816 }
13817+
13818+ if (UseMemoryAccessUBForBEInference) {
13819+ unsigned SmallMaxTrip = SE->getSmallConstantMaxTripCount(L);
13820+ OS << "Loop ";
13821+ L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13822+ OS << ": ";
13823+ if (SmallMaxTrip)
13824+ OS << "Small constant max trip is " << SmallMaxTrip << "\n";
13825+ else
13826+ OS << "Small constant max trip couldn't be computed.\n";
13827+ }
1358213828}
1358313829
1358413830namespace llvm {
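With the option enabled, the analysis printer emits one extra line per loop of the form below (the block name %for.body is illustrative):

Loop %for.body: Small constant max trip is 32

or, when no constant maximum could be established:

Loop %for.body: Small constant max trip couldn't be computed.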