@@ -193,30 +193,28 @@ RuntimeCheckingPtrGroup::RuntimeCheckingPtrGroup(
193193// / Returns \p A + \p B, if it is guaranteed not to unsigned wrap. Otherwise
194194// / return nullptr. \p A and \p B must have the same type.
195195static const SCEV *addSCEVNoOverflow (const SCEV *A, const SCEV *B,
196- ScalarEvolution &SE,
197- const Instruction *CtxI) {
198- if (!SE.willNotOverflow (Instruction::Add, /* IsSigned=*/ false , A, B, CtxI))
196+ ScalarEvolution &SE) {
197+ if (!SE.willNotOverflow (Instruction::Add, /* IsSigned=*/ false , A, B))
199198 return nullptr ;
200199 return SE.getAddExpr (A, B);
201200}
202201
203202// / Returns \p A * \p B, if it is guaranteed not to unsigned wrap. Otherwise
204203// / return nullptr. \p A and \p B must have the same type.
205204static const SCEV *mulSCEVOverflow (const SCEV *A, const SCEV *B,
206- ScalarEvolution &SE,
207- const Instruction *CtxI) {
208- if (!SE.willNotOverflow (Instruction::Mul, /* IsSigned=*/ false , A, B, CtxI))
205+ ScalarEvolution &SE) {
206+ if (!SE.willNotOverflow (Instruction::Mul, /* IsSigned=*/ false , A, B))
209207 return nullptr ;
210208 return SE.getMulExpr (A, B);
211209}
212210
213211// / Return true, if evaluating \p AR at \p MaxBTC cannot wrap, because \p AR at
214212// / \p MaxBTC is guaranteed inbounds of the accessed object.
215- static bool
216- evaluatePtrAddRecAtMaxBTCWillNotWrap ( const SCEVAddRecExpr *AR,
217- const SCEV *MaxBTC, const SCEV *EltSize ,
218- ScalarEvolution &SE, const DataLayout &DL ,
219- DominatorTree *DT, AssumptionCache *AC ) {
213+ static bool evaluatePtrAddRecAtMaxBTCWillNotWrap (
214+ const SCEVAddRecExpr *AR, const SCEV *MaxBTC, const SCEV *EltSize ,
215+ ScalarEvolution &SE, const DataLayout &DL, DominatorTree *DT ,
216+ AssumptionCache *AC ,
217+ std::optional<ScalarEvolution::LoopGuards> &LoopGuards ) {
220218 auto *PointerBase = SE.getPointerBase (AR->getStart ());
221219 auto *StartPtr = dyn_cast<SCEVUnknown>(PointerBase);
222220 if (!StartPtr)
@@ -234,12 +232,11 @@ evaluatePtrAddRecAtMaxBTCWillNotWrap(const SCEVAddRecExpr *AR,
234232 Type *WiderTy = SE.getWiderType (MaxBTC->getType (), Step->getType ());
235233 const SCEV *DerefBytesSCEV = SE.getConstant (WiderTy, DerefBytes);
236234
237- // Context which dominates the entire loop.
238- auto *CtxI = L->getLoopPredecessor ()->getTerminator ();
239235 // Check if we have a suitable dereferencable assumption we can use.
240236 if (!StartPtrV->canBeFreed ()) {
241237 RetainedKnowledge DerefRK = getKnowledgeValidInContext (
242- StartPtrV, {Attribute::Dereferenceable}, *AC, CtxI, DT);
238+ StartPtrV, {Attribute::Dereferenceable}, *AC,
239+ L->getLoopPredecessor ()->getTerminator (), DT);
243240 if (DerefRK) {
244241 DerefBytesSCEV = SE.getUMaxExpr (
245242 DerefBytesSCEV, SE.getConstant (WiderTy, DerefRK.ArgValue ));
@@ -263,23 +260,36 @@ evaluatePtrAddRecAtMaxBTCWillNotWrap(const SCEVAddRecExpr *AR,
263260 SE.getMinusSCEV (AR->getStart (), StartPtr), WiderTy);
264261
265262 const SCEV *OffsetAtLastIter =
266- mulSCEVOverflow (MaxBTC, SE.getAbsExpr (Step, /* IsNSW=*/ false ), SE, CtxI);
267- if (!OffsetAtLastIter)
268- return false ;
263+ mulSCEVOverflow (MaxBTC, SE.getAbsExpr (Step, /* IsNSW=*/ false ), SE);
264+ if (!OffsetAtLastIter) {
265+ // Re-try with constant max backedge-taken count if using the symbolic one
266+ // failed.
267+ MaxBTC = SE.getNoopOrZeroExtend (
268+ SE.getConstantMaxBackedgeTakenCount (AR->getLoop ()), WiderTy);
269+ OffsetAtLastIter =
270+ mulSCEVOverflow (MaxBTC, SE.getAbsExpr (Step, /* IsNSW=*/ false ), SE);
271+ if (!OffsetAtLastIter)
272+ return false ;
273+ }
269274
270275 const SCEV *OffsetEndBytes = addSCEVNoOverflow (
271- OffsetAtLastIter, SE.getNoopOrZeroExtend (EltSize, WiderTy), SE, CtxI );
276+ OffsetAtLastIter, SE.getNoopOrZeroExtend (EltSize, WiderTy), SE);
272277 if (!OffsetEndBytes)
273278 return false ;
274279
275280 if (IsKnownNonNegative) {
276281 // For positive steps, check if
277282 // (AR->getStart() - StartPtr) + (MaxBTC * Step) + EltSize <= DerefBytes,
278283 // while making sure none of the computations unsigned wrap themselves.
279- const SCEV *EndBytes =
280- addSCEVNoOverflow (StartOffset, OffsetEndBytes, SE, CtxI);
284+ const SCEV *EndBytes = addSCEVNoOverflow (StartOffset, OffsetEndBytes, SE);
281285 if (!EndBytes)
282286 return false ;
287+
288+ if (!LoopGuards)
289+ LoopGuards.emplace (
290+ ScalarEvolution::LoopGuards::collect (AR->getLoop (), SE));
291+
292+ EndBytes = SE.applyLoopGuards (EndBytes, *LoopGuards);
283293 return SE.isKnownPredicate (CmpInst::ICMP_ULE, EndBytes, DerefBytesSCEV);
284294 }
285295
@@ -296,7 +306,8 @@ std::pair<const SCEV *, const SCEV *> llvm::getStartAndEndForAccess(
296306 const SCEV *MaxBTC, ScalarEvolution *SE,
297307 DenseMap<std::pair<const SCEV *, Type *>,
298308 std::pair<const SCEV *, const SCEV *>> *PointerBounds,
299- DominatorTree *DT, AssumptionCache *AC) {
309+ DominatorTree *DT, AssumptionCache *AC,
310+ std::optional<ScalarEvolution::LoopGuards> &LoopGuards) {
300311 std::pair<const SCEV *, const SCEV *> *PtrBoundsPair;
301312 if (PointerBounds) {
302313 auto [Iter, Ins] = PointerBounds->insert (
@@ -332,7 +343,7 @@ std::pair<const SCEV *, const SCEV *> llvm::getStartAndEndForAccess(
332343 // separately checks that accesses cannot not wrap, so unsigned max
333344 // represents an upper bound.
334345 if (evaluatePtrAddRecAtMaxBTCWillNotWrap (AR, MaxBTC, EltSizeSCEV, *SE, DL,
335- DT, AC)) {
346+ DT, AC, LoopGuards )) {
336347 ScEnd = AR->evaluateAtIteration (MaxBTC, *SE);
337348 } else {
338349 ScEnd = SE->getAddExpr (
@@ -381,7 +392,7 @@ void RuntimePointerChecking::insert(Loop *Lp, Value *Ptr, const SCEV *PtrExpr,
381392 const SCEV *BTC = PSE.getBackedgeTakenCount ();
382393 const auto &[ScStart, ScEnd] = getStartAndEndForAccess (
383394 Lp, PtrExpr, AccessTy, BTC, SymbolicMaxBTC, PSE.getSE (),
384- &DC.getPointerBounds (), DC.getDT (), DC.getAC ());
395+ &DC.getPointerBounds (), DC.getDT (), DC.getAC (), LoopGuards );
385396 assert (!isa<SCEVCouldNotCompute>(ScStart) &&
386397 !isa<SCEVCouldNotCompute>(ScEnd) &&
387398 " must be able to compute both start and end expressions" );
@@ -1987,13 +1998,13 @@ bool MemoryDepChecker::areAccessesCompletelyBeforeOrAfter(const SCEV *Src,
19871998 ScalarEvolution &SE = *PSE.getSE ();
19881999 const auto &[SrcStart_, SrcEnd_] =
19892000 getStartAndEndForAccess (InnermostLoop, Src, SrcTy, BTC, SymbolicMaxBTC,
1990- &SE, &PointerBounds, DT, AC);
2001+ &SE, &PointerBounds, DT, AC, LoopGuards );
19912002 if (isa<SCEVCouldNotCompute>(SrcStart_) || isa<SCEVCouldNotCompute>(SrcEnd_))
19922003 return false ;
19932004
19942005 const auto &[SinkStart_, SinkEnd_] =
19952006 getStartAndEndForAccess (InnermostLoop, Sink, SinkTy, BTC, SymbolicMaxBTC,
1996- &SE, &PointerBounds, DT, AC);
2007+ &SE, &PointerBounds, DT, AC, LoopGuards );
19972008 if (isa<SCEVCouldNotCompute>(SinkStart_) ||
19982009 isa<SCEVCouldNotCompute>(SinkEnd_))
19992010 return false ;
@@ -3040,8 +3051,9 @@ LoopAccessInfo::LoopAccessInfo(Loop *L, ScalarEvolution *SE,
30403051 TTI->getRegisterBitWidth (TargetTransformInfo::RGK_FixedWidthVector) * 2 ;
30413052
30423053 DepChecker = std::make_unique<MemoryDepChecker>(
3043- *PSE, AC, DT, L, SymbolicStrides, MaxTargetVectorWidthInBits);
3044- PtrRtChecking = std::make_unique<RuntimePointerChecking>(*DepChecker, SE);
3054+ *PSE, AC, DT, L, SymbolicStrides, MaxTargetVectorWidthInBits, LoopGuards);
3055+ PtrRtChecking =
3056+ std::make_unique<RuntimePointerChecking>(*DepChecker, SE, LoopGuards);
30453057 if (canAnalyzeLoop ())
30463058 CanVecMem = analyzeLoop (AA, LI, TLI, DT);
30473059}
0 commit comments