@@ -249,6 +249,10 @@ static cl::opt<bool> UseContextForNoWrapFlagInference(
249249 cl::desc("Infer nuw/nsw flags using context where suitable"),
250250 cl::init(true));
251251
/// Off by default. When enabled, getSmallConstantMaxTripCount additionally
/// bounds the max trip count using memory accesses that must execute on every
/// iteration: running past the accessed object's known size would be UB, so
/// such accesses limit how often the loop can iterate.
static cl::opt<bool> UseMemoryAccessUBForBEInference(
    "scalar-evolution-infer-max-trip-count-from-memory-access", cl::Hidden,
    cl::desc("Infer loop max trip count from memory access"), cl::init(false));
252256//===----------------------------------------------------------------------===//
253257// SCEV class definitions
254258//===----------------------------------------------------------------------===//
@@ -8135,7 +8139,16 @@ ScalarEvolution::getSmallConstantTripCount(const Loop *L,
81358139unsigned ScalarEvolution::getSmallConstantMaxTripCount(const Loop *L) {
81368140 const auto *MaxExitCount =
81378141 dyn_cast<SCEVConstant>(getConstantMaxBackedgeTakenCount(L));
8138- return getConstantTripCount(MaxExitCount);
8142+ unsigned MaxExitCountN = getConstantTripCount(MaxExitCount);
8143+ if (UseMemoryAccessUBForBEInference) {
8144+ auto *MaxInferCount = getConstantMaxTripCountFromMemAccess(L);
8145+ if (auto *InferCount = dyn_cast<SCEVConstant>(MaxInferCount)) {
8146+ unsigned InferValue = InferCount->getValue()->getZExtValue();
8147+ MaxExitCountN =
8148+ MaxExitCountN == 0 ? InferValue : std::min(MaxExitCountN, InferValue);
8149+ }
8150+ }
8151+ return MaxExitCountN;
81398152}
81408153
81418154unsigned ScalarEvolution::getSmallConstantTripMultiple(const Loop *L) {
@@ -8190,6 +8203,167 @@ ScalarEvolution::getSmallConstantTripMultiple(const Loop *L,
81908203 return getSmallConstantTripMultiple(L, ExitCount);
81918204}
81928205
8206+ /// Collect all load/store instructions that must be executed in every iteration
8207+ /// of loop \p L .
8208+ static void
8209+ collectExecLoadStoreInsideLoop(const Loop *L, DominatorTree &DT,
8210+ SmallVector<Instruction *, 4> &MemInsts) {
8211+ // It is difficult to tell if the load/store instruction is executed on every
8212+ // iteration inside an irregular loop.
8213+ if (!L->isLoopSimplifyForm() || !L->isInnermost())
8214+ return;
8215+
8216+ // FIXME: To make the case more typical, we only analyze loops that have one
8217+ // exiting block and the block must be the latch. It is easier to capture
8218+ // loops with memory access that will be executed in every iteration.
8219+ const BasicBlock *LoopLatch = L->getLoopLatch();
8220+ assert(LoopLatch && "normal form loop doesn't have a latch");
8221+ if (L->getExitingBlock() != LoopLatch)
8222+ return;
8223+
8224+ // We will not continue if sanitizer is enabled.
8225+ const Function *F = LoopLatch->getParent();
8226+ if (F->hasFnAttribute(Attribute::SanitizeAddress) ||
8227+ F->hasFnAttribute(Attribute::SanitizeThread) ||
8228+ F->hasFnAttribute(Attribute::SanitizeMemory) ||
8229+ F->hasFnAttribute(Attribute::SanitizeHWAddress) ||
8230+ F->hasFnAttribute(Attribute::SanitizeMemTag))
8231+ return;
8232+
8233+ for (auto *BB : L->getBlocks()) {
8234+ // We need to make sure that max execution time of MemAccessBB in loop
8235+ // represents latch max excution time. The BB below should be skipped:
8236+ // Entry
8237+ // │
8238+ // ┌─────▼─────┐
8239+ // │Loop Header◄─────┐
8240+ // └──┬──────┬─┘ │
8241+ // │ │ │
8242+ // ┌────────▼──┐ ┌─▼─────┐ │
8243+ // │MemAccessBB│ │OtherBB│ │
8244+ // └────────┬──┘ └─┬─────┘ │
8245+ // │ │ │
8246+ // ┌─▼──────▼─┐ │
8247+ // │Loop Latch├─────┘
8248+ // └────┬─────┘
8249+ // ▼
8250+ // Exit
8251+ if (!DT.dominates(BB, LoopLatch))
8252+ continue;
8253+
8254+ for (Instruction &I : *BB) {
8255+ if (isa<LoadInst>(&I) || isa<StoreInst>(&I))
8256+ MemInsts.push_back(&I);
8257+ }
8258+ }
8259+ }
8260+
8261+ /// Return a SCEV representing the memory size of pointer \p V .
8262+ static const SCEV *getCertainSizeOfMem(const SCEV *V, Type *RTy,
8263+ const DataLayout &DL,
8264+ const TargetLibraryInfo &TLI,
8265+ ScalarEvolution *SE) {
8266+ const SCEVUnknown *PtrBase = dyn_cast<SCEVUnknown>(V);
8267+ if (!PtrBase)
8268+ return nullptr;
8269+ Value *Ptr = PtrBase->getValue();
8270+ uint64_t Size = 0;
8271+ if (!llvm::getObjectSize(Ptr, Size, DL, &TLI))
8272+ return nullptr;
8273+ return SE->getConstant(RTy, Size);
8274+ }
8275+
/// Get the range of given index represented by \p AddRec.
/// Computes (unsigned-range-max - unsigned-range-min) of \p AddRec and divides
/// it (unsigned, rounding up) by the recurrence step, i.e. how many steps fit
/// inside the index's value range.
/// NOTE(review): the step is used as an *unsigned* divisor. A negative or
/// symbolic step would make this quotient tiny/meaningless — callers appear to
/// be responsible for ensuring the step is a known positive constant; confirm.
static const SCEV *getIndexRange(const SCEVAddRecExpr *AddRec,
                                 ScalarEvolution *SE) {
  const SCEV *Range = SE->getConstant(SE->getUnsignedRangeMax(AddRec) -
                                      SE->getUnsignedRangeMin(AddRec));
  const SCEV *Step = AddRec->getStepRecurrence(*SE);
  return SE->getUDivCeilSCEV(Range, Step);
}
8284+
8285+ /// Check whether the index can wrap and if we can still infer max trip count
8286+ /// given the max trip count inferred from memory access.
8287+ static const SCEV *checkIndexWrap(Value *Ptr, ScalarEvolution *SE,
8288+ const SCEVConstant *MaxExecCount) {
8289+ SmallVector<const SCEV *> InferCountColl;
8290+ auto *PtrGEP = dyn_cast<GetElementPtrInst>(Ptr);
8291+ if (!PtrGEP)
8292+ return SE->getCouldNotCompute();
8293+ for (Value *Index : PtrGEP->indices()) {
8294+ Value *V = Index;
8295+ if (isa<ZExtInst>(V) || isa<SExtInst>(V))
8296+ V = cast<Instruction>(Index)->getOperand(0);
8297+ auto *SCEV = SE->getSCEV(V);
8298+ if (isa<SCEVCouldNotCompute>(SCEV))
8299+ return SE->getCouldNotCompute();
8300+ auto *AddRec = dyn_cast<SCEVAddRecExpr>(SCEV);
8301+ if (!AddRec)
8302+ continue;
8303+ auto *IndexRange = getIndexRange(AddRec, SE);
8304+ if (AddRec->hasNoSelfWrap()) {
8305+ InferCountColl.push_back(
8306+ SE->getUMinFromMismatchedTypes(IndexRange, MaxExecCount));
8307+ } else {
8308+ auto *IndexRangeC = dyn_cast<SCEVConstant>(IndexRange);
8309+ if (!IndexRangeC)
8310+ continue;
8311+ if (MaxExecCount->getValue()->getZExtValue() >
8312+ IndexRangeC->getValue()->getZExtValue())
8313+ InferCountColl.push_back(IndexRange);
8314+ else
8315+ InferCountColl.push_back(MaxExecCount);
8316+ }
8317+ }
8318+
8319+ if (InferCountColl.empty())
8320+ return SE->getCouldNotCompute();
8321+
8322+ return SE->getUMinFromMismatchedTypes(InferCountColl);
8323+ }
8324+
8325+ const SCEV *
8326+ ScalarEvolution::getConstantMaxTripCountFromMemAccess(const Loop *L) {
8327+ SmallVector<Instruction *, 4> MemInsts;
8328+ collectExecLoadStoreInsideLoop(L, DT, MemInsts);
8329+
8330+ SmallVector<const SCEV *> InferCountColl;
8331+ const DataLayout &DL = getDataLayout();
8332+
8333+ for (Instruction *I : MemInsts) {
8334+ Value *Ptr = getLoadStorePointerOperand(I);
8335+ assert(Ptr && "empty pointer operand");
8336+ auto *AddRec = dyn_cast<SCEVAddRecExpr>(getSCEV(Ptr));
8337+ if (!AddRec || !AddRec->isAffine())
8338+ continue;
8339+ const SCEV *PtrBase = getPointerBase(AddRec);
8340+ const SCEV *Step = AddRec->getStepRecurrence(*this);
8341+ const SCEV *MemSize =
8342+ getCertainSizeOfMem(PtrBase, Step->getType(), DL, TLI, this);
8343+ if (!MemSize)
8344+ continue;
8345+ // Now we can infer a max execution time by MemLength/StepLength.
8346+ auto *MaxExecCount = dyn_cast<SCEVConstant>(getUDivCeilSCEV(MemSize, Step));
8347+ if (!MaxExecCount || MaxExecCount->getAPInt().getActiveBits() > 32)
8348+ continue;
8349+ // Now we check the wrap. We can still explore the max trip count in the
8350+ // following two cases:
8351+ // 1. If the index can potentially wrap but the max trip count inferred from
8352+ // memory access is within the range of index.
8353+ // 2. If the index can't wrap, then the max trip count is:
8354+ // min(range of index, max value inferred from memory access).
8355+ auto *Res = checkIndexWrap(Ptr, this, MaxExecCount);
8356+ if (isa<SCEVCouldNotCompute>(Res))
8357+ continue;
8358+ InferCountColl.push_back(Res);
8359+ }
8360+
8361+ if (InferCountColl.empty())
8362+ return getCouldNotCompute();
8363+
8364+ return getUMinFromMismatchedTypes(InferCountColl);
8365+ }
8366+
81938367const SCEV *ScalarEvolution::getExitCount(const Loop *L,
81948368 const BasicBlock *ExitingBlock,
81958369 ExitCountKind Kind) {
@@ -13477,6 +13651,17 @@ static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE,
1347713651 OS << ": ";
1347813652 OS << "Trip multiple is " << SE->getSmallConstantTripMultiple(L) << "\n";
1347913653 }
13654+
13655+ if (UseMemoryAccessUBForBEInference) {
13656+ unsigned SmallMaxTrip = SE->getSmallConstantMaxTripCount(L);
13657+ OS << "Loop ";
13658+ L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13659+ OS << ": ";
13660+ if (SmallMaxTrip)
13661+ OS << "Small constant max trip is " << SmallMaxTrip << "\n";
13662+ else
13663+ OS << "Small constant max trip couldn't be computed.\n";
13664+ }
1348013665}
1348113666
1348213667namespace llvm {