Skip to content

Commit 08001cf

Browse files
authored
[LAA,Loads] Use loop guards and max BTC if needed when checking deref. (#155672)
Remove the fall-back to constant max BTC if the backedge-taken-count cannot be computed. The constant max backedge-taken count is computed considering loop guards, so to avoid regressions we need to apply loop guards as needed. Also remove the special handling for Mul in willNotOverflow, as this should no longer be needed after 9143746 (#155300). PR: #155672
1 parent 2364736 commit 08001cf

File tree

5 files changed

+69
-64
lines changed

5 files changed

+69
-64
lines changed

llvm/include/llvm/Analysis/LoopAccessAnalysis.h

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -183,10 +183,12 @@ class MemoryDepChecker {
183183
MemoryDepChecker(PredicatedScalarEvolution &PSE, AssumptionCache *AC,
184184
DominatorTree *DT, const Loop *L,
185185
const DenseMap<Value *, const SCEV *> &SymbolicStrides,
186-
unsigned MaxTargetVectorWidthInBits)
186+
unsigned MaxTargetVectorWidthInBits,
187+
std::optional<ScalarEvolution::LoopGuards> &LoopGuards)
187188
: PSE(PSE), AC(AC), DT(DT), InnermostLoop(L),
188189
SymbolicStrides(SymbolicStrides),
189-
MaxTargetVectorWidthInBits(MaxTargetVectorWidthInBits) {}
190+
MaxTargetVectorWidthInBits(MaxTargetVectorWidthInBits),
191+
LoopGuards(LoopGuards) {}
190192

191193
/// Register the location (instructions are given increasing numbers)
192194
/// of a write access.
@@ -373,7 +375,7 @@ class MemoryDepChecker {
373375
PointerBounds;
374376

375377
/// Cache for the loop guards of InnermostLoop.
376-
std::optional<ScalarEvolution::LoopGuards> LoopGuards;
378+
std::optional<ScalarEvolution::LoopGuards> &LoopGuards;
377379

378380
/// Check whether there is a plausible dependence between the two
379381
/// accesses.
@@ -531,8 +533,9 @@ class RuntimePointerChecking {
531533
AliasSetId(AliasSetId), Expr(Expr), NeedsFreeze(NeedsFreeze) {}
532534
};
533535

534-
RuntimePointerChecking(MemoryDepChecker &DC, ScalarEvolution *SE)
535-
: DC(DC), SE(SE) {}
536+
RuntimePointerChecking(MemoryDepChecker &DC, ScalarEvolution *SE,
537+
std::optional<ScalarEvolution::LoopGuards> &LoopGuards)
538+
: DC(DC), SE(SE), LoopGuards(LoopGuards) {}
536539

537540
/// Reset the state of the pointer runtime information.
538541
void reset() {
@@ -646,6 +649,9 @@ class RuntimePointerChecking {
646649
/// Holds a pointer to the ScalarEvolution analysis.
647650
ScalarEvolution *SE;
648651

652+
/// Cache for the loop guards of the loop.
653+
std::optional<ScalarEvolution::LoopGuards> &LoopGuards;
654+
649655
/// Set of run-time checks required to establish independence of
650656
/// otherwise may-aliasing pointers in the loop.
651657
SmallVector<RuntimePointerCheck, 4> Checks;
@@ -821,6 +827,9 @@ class LoopAccessInfo {
821827

822828
Loop *TheLoop;
823829

830+
/// Cache for the loop guards of TheLoop.
831+
std::optional<ScalarEvolution::LoopGuards> LoopGuards;
832+
824833
/// Determines whether we should generate partial runtime checks when not all
825834
/// memory accesses could be analyzed.
826835
bool AllowPartial;
@@ -938,7 +947,8 @@ LLVM_ABI std::pair<const SCEV *, const SCEV *> getStartAndEndForAccess(
938947
const SCEV *MaxBTC, ScalarEvolution *SE,
939948
DenseMap<std::pair<const SCEV *, Type *>,
940949
std::pair<const SCEV *, const SCEV *>> *PointerBounds,
941-
DominatorTree *DT, AssumptionCache *AC);
950+
DominatorTree *DT, AssumptionCache *AC,
951+
std::optional<ScalarEvolution::LoopGuards> &LoopGuards);
942952

943953
class LoopAccessInfoManager {
944954
/// The cache.

llvm/lib/Analysis/Loads.cpp

Lines changed: 10 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,6 @@
2626

2727
using namespace llvm;
2828

29-
static cl::opt<bool>
30-
UseSymbolicMaxBTCForDerefInLoop("use-symbolic-maxbtc-deref-loop",
31-
cl::init(false));
32-
3329
static bool isAligned(const Value *Base, Align Alignment,
3430
const DataLayout &DL) {
3531
return Base->getPointerAlignment(DL) >= Alignment;
@@ -335,18 +331,10 @@ bool llvm::isDereferenceableAndAlignedInLoop(
335331
: SE.getBackedgeTakenCount(L);
336332
if (isa<SCEVCouldNotCompute>(MaxBECount))
337333
return false;
338-
339-
if (isa<SCEVCouldNotCompute>(BECount) && !UseSymbolicMaxBTCForDerefInLoop) {
340-
// TODO: Support symbolic max backedge taken counts for loops without
341-
// computable backedge taken counts.
342-
MaxBECount =
343-
Predicates
344-
? SE.getPredicatedConstantMaxBackedgeTakenCount(L, *Predicates)
345-
: SE.getConstantMaxBackedgeTakenCount(L);
346-
}
347-
348-
const auto &[AccessStart, AccessEnd] = getStartAndEndForAccess(
349-
L, PtrScev, LI->getType(), BECount, MaxBECount, &SE, nullptr, &DT, AC);
334+
std::optional<ScalarEvolution::LoopGuards> LoopGuards;
335+
const auto &[AccessStart, AccessEnd] =
336+
getStartAndEndForAccess(L, PtrScev, LI->getType(), BECount, MaxBECount,
337+
&SE, nullptr, &DT, AC, LoopGuards);
350338
if (isa<SCEVCouldNotCompute>(AccessStart) ||
351339
isa<SCEVCouldNotCompute>(AccessEnd))
352340
return false;
@@ -355,10 +343,13 @@ bool llvm::isDereferenceableAndAlignedInLoop(
355343
const SCEV *PtrDiff = SE.getMinusSCEV(AccessEnd, AccessStart);
356344
if (isa<SCEVCouldNotCompute>(PtrDiff))
357345
return false;
358-
ScalarEvolution::LoopGuards LoopGuards =
359-
ScalarEvolution::LoopGuards::collect(AddRec->getLoop(), SE);
346+
347+
if (!LoopGuards)
348+
LoopGuards.emplace(
349+
ScalarEvolution::LoopGuards::collect(AddRec->getLoop(), SE));
350+
360351
APInt MaxPtrDiff =
361-
SE.getUnsignedRangeMax(SE.applyLoopGuards(PtrDiff, LoopGuards));
352+
SE.getUnsignedRangeMax(SE.applyLoopGuards(PtrDiff, *LoopGuards));
362353

363354
Value *Base = nullptr;
364355
APInt AccessSize;

llvm/lib/Analysis/LoopAccessAnalysis.cpp

Lines changed: 39 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -193,30 +193,28 @@ RuntimeCheckingPtrGroup::RuntimeCheckingPtrGroup(
193193
/// Returns \p A + \p B, if it is guaranteed not to unsigned wrap. Otherwise
194194
/// return nullptr. \p A and \p B must have the same type.
195195
static const SCEV *addSCEVNoOverflow(const SCEV *A, const SCEV *B,
196-
ScalarEvolution &SE,
197-
const Instruction *CtxI) {
198-
if (!SE.willNotOverflow(Instruction::Add, /*IsSigned=*/false, A, B, CtxI))
196+
ScalarEvolution &SE) {
197+
if (!SE.willNotOverflow(Instruction::Add, /*IsSigned=*/false, A, B))
199198
return nullptr;
200199
return SE.getAddExpr(A, B);
201200
}
202201

203202
/// Returns \p A * \p B, if it is guaranteed not to unsigned wrap. Otherwise
204203
/// return nullptr. \p A and \p B must have the same type.
205204
static const SCEV *mulSCEVOverflow(const SCEV *A, const SCEV *B,
206-
ScalarEvolution &SE,
207-
const Instruction *CtxI) {
208-
if (!SE.willNotOverflow(Instruction::Mul, /*IsSigned=*/false, A, B, CtxI))
205+
ScalarEvolution &SE) {
206+
if (!SE.willNotOverflow(Instruction::Mul, /*IsSigned=*/false, A, B))
209207
return nullptr;
210208
return SE.getMulExpr(A, B);
211209
}
212210

213211
/// Return true, if evaluating \p AR at \p MaxBTC cannot wrap, because \p AR at
214212
/// \p MaxBTC is guaranteed inbounds of the accessed object.
215-
static bool
216-
evaluatePtrAddRecAtMaxBTCWillNotWrap(const SCEVAddRecExpr *AR,
217-
const SCEV *MaxBTC, const SCEV *EltSize,
218-
ScalarEvolution &SE, const DataLayout &DL,
219-
DominatorTree *DT, AssumptionCache *AC) {
213+
static bool evaluatePtrAddRecAtMaxBTCWillNotWrap(
214+
const SCEVAddRecExpr *AR, const SCEV *MaxBTC, const SCEV *EltSize,
215+
ScalarEvolution &SE, const DataLayout &DL, DominatorTree *DT,
216+
AssumptionCache *AC,
217+
std::optional<ScalarEvolution::LoopGuards> &LoopGuards) {
220218
auto *PointerBase = SE.getPointerBase(AR->getStart());
221219
auto *StartPtr = dyn_cast<SCEVUnknown>(PointerBase);
222220
if (!StartPtr)
@@ -234,12 +232,11 @@ evaluatePtrAddRecAtMaxBTCWillNotWrap(const SCEVAddRecExpr *AR,
234232
Type *WiderTy = SE.getWiderType(MaxBTC->getType(), Step->getType());
235233
const SCEV *DerefBytesSCEV = SE.getConstant(WiderTy, DerefBytes);
236234

237-
// Context which dominates the entire loop.
238-
auto *CtxI = L->getLoopPredecessor()->getTerminator();
239235
// Check if we have a suitable dereferenceable assumption we can use.
240236
if (!StartPtrV->canBeFreed()) {
241237
RetainedKnowledge DerefRK = getKnowledgeValidInContext(
242-
StartPtrV, {Attribute::Dereferenceable}, *AC, CtxI, DT);
238+
StartPtrV, {Attribute::Dereferenceable}, *AC,
239+
L->getLoopPredecessor()->getTerminator(), DT);
243240
if (DerefRK) {
244241
DerefBytesSCEV = SE.getUMaxExpr(
245242
DerefBytesSCEV, SE.getConstant(WiderTy, DerefRK.ArgValue));
@@ -263,23 +260,36 @@ evaluatePtrAddRecAtMaxBTCWillNotWrap(const SCEVAddRecExpr *AR,
263260
SE.getMinusSCEV(AR->getStart(), StartPtr), WiderTy);
264261

265262
const SCEV *OffsetAtLastIter =
266-
mulSCEVOverflow(MaxBTC, SE.getAbsExpr(Step, /*IsNSW=*/false), SE, CtxI);
267-
if (!OffsetAtLastIter)
268-
return false;
263+
mulSCEVOverflow(MaxBTC, SE.getAbsExpr(Step, /*IsNSW=*/false), SE);
264+
if (!OffsetAtLastIter) {
265+
// Re-try with constant max backedge-taken count if using the symbolic one
266+
// failed.
267+
MaxBTC = SE.getNoopOrZeroExtend(
268+
SE.getConstantMaxBackedgeTakenCount(AR->getLoop()), WiderTy);
269+
OffsetAtLastIter =
270+
mulSCEVOverflow(MaxBTC, SE.getAbsExpr(Step, /*IsNSW=*/false), SE);
271+
if (!OffsetAtLastIter)
272+
return false;
273+
}
269274

270275
const SCEV *OffsetEndBytes = addSCEVNoOverflow(
271-
OffsetAtLastIter, SE.getNoopOrZeroExtend(EltSize, WiderTy), SE, CtxI);
276+
OffsetAtLastIter, SE.getNoopOrZeroExtend(EltSize, WiderTy), SE);
272277
if (!OffsetEndBytes)
273278
return false;
274279

275280
if (IsKnownNonNegative) {
276281
// For positive steps, check if
277282
// (AR->getStart() - StartPtr) + (MaxBTC * Step) + EltSize <= DerefBytes,
278283
// while making sure none of the computations unsigned wrap themselves.
279-
const SCEV *EndBytes =
280-
addSCEVNoOverflow(StartOffset, OffsetEndBytes, SE, CtxI);
284+
const SCEV *EndBytes = addSCEVNoOverflow(StartOffset, OffsetEndBytes, SE);
281285
if (!EndBytes)
282286
return false;
287+
288+
if (!LoopGuards)
289+
LoopGuards.emplace(
290+
ScalarEvolution::LoopGuards::collect(AR->getLoop(), SE));
291+
292+
EndBytes = SE.applyLoopGuards(EndBytes, *LoopGuards);
283293
return SE.isKnownPredicate(CmpInst::ICMP_ULE, EndBytes, DerefBytesSCEV);
284294
}
285295

@@ -296,7 +306,8 @@ std::pair<const SCEV *, const SCEV *> llvm::getStartAndEndForAccess(
296306
const SCEV *MaxBTC, ScalarEvolution *SE,
297307
DenseMap<std::pair<const SCEV *, Type *>,
298308
std::pair<const SCEV *, const SCEV *>> *PointerBounds,
299-
DominatorTree *DT, AssumptionCache *AC) {
309+
DominatorTree *DT, AssumptionCache *AC,
310+
std::optional<ScalarEvolution::LoopGuards> &LoopGuards) {
300311
std::pair<const SCEV *, const SCEV *> *PtrBoundsPair;
301312
if (PointerBounds) {
302313
auto [Iter, Ins] = PointerBounds->insert(
@@ -332,7 +343,7 @@ std::pair<const SCEV *, const SCEV *> llvm::getStartAndEndForAccess(
332343
// separately checks that accesses cannot wrap, so unsigned max
333344
// represents an upper bound.
334345
if (evaluatePtrAddRecAtMaxBTCWillNotWrap(AR, MaxBTC, EltSizeSCEV, *SE, DL,
335-
DT, AC)) {
346+
DT, AC, LoopGuards)) {
336347
ScEnd = AR->evaluateAtIteration(MaxBTC, *SE);
337348
} else {
338349
ScEnd = SE->getAddExpr(
@@ -381,7 +392,7 @@ void RuntimePointerChecking::insert(Loop *Lp, Value *Ptr, const SCEV *PtrExpr,
381392
const SCEV *BTC = PSE.getBackedgeTakenCount();
382393
const auto &[ScStart, ScEnd] = getStartAndEndForAccess(
383394
Lp, PtrExpr, AccessTy, BTC, SymbolicMaxBTC, PSE.getSE(),
384-
&DC.getPointerBounds(), DC.getDT(), DC.getAC());
395+
&DC.getPointerBounds(), DC.getDT(), DC.getAC(), LoopGuards);
385396
assert(!isa<SCEVCouldNotCompute>(ScStart) &&
386397
!isa<SCEVCouldNotCompute>(ScEnd) &&
387398
"must be able to compute both start and end expressions");
@@ -1987,13 +1998,13 @@ bool MemoryDepChecker::areAccessesCompletelyBeforeOrAfter(const SCEV *Src,
19871998
ScalarEvolution &SE = *PSE.getSE();
19881999
const auto &[SrcStart_, SrcEnd_] =
19892000
getStartAndEndForAccess(InnermostLoop, Src, SrcTy, BTC, SymbolicMaxBTC,
1990-
&SE, &PointerBounds, DT, AC);
2001+
&SE, &PointerBounds, DT, AC, LoopGuards);
19912002
if (isa<SCEVCouldNotCompute>(SrcStart_) || isa<SCEVCouldNotCompute>(SrcEnd_))
19922003
return false;
19932004

19942005
const auto &[SinkStart_, SinkEnd_] =
19952006
getStartAndEndForAccess(InnermostLoop, Sink, SinkTy, BTC, SymbolicMaxBTC,
1996-
&SE, &PointerBounds, DT, AC);
2007+
&SE, &PointerBounds, DT, AC, LoopGuards);
19972008
if (isa<SCEVCouldNotCompute>(SinkStart_) ||
19982009
isa<SCEVCouldNotCompute>(SinkEnd_))
19992010
return false;
@@ -3040,8 +3051,9 @@ LoopAccessInfo::LoopAccessInfo(Loop *L, ScalarEvolution *SE,
30403051
TTI->getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector) * 2;
30413052

30423053
DepChecker = std::make_unique<MemoryDepChecker>(
3043-
*PSE, AC, DT, L, SymbolicStrides, MaxTargetVectorWidthInBits);
3044-
PtrRtChecking = std::make_unique<RuntimePointerChecking>(*DepChecker, SE);
3054+
*PSE, AC, DT, L, SymbolicStrides, MaxTargetVectorWidthInBits, LoopGuards);
3055+
PtrRtChecking =
3056+
std::make_unique<RuntimePointerChecking>(*DepChecker, SE, LoopGuards);
30453057
if (canAnalyzeLoop())
30463058
CanVecMem = analyzeLoop(AA, LI, TLI, DT);
30473059
}

llvm/lib/Analysis/ScalarEvolution.cpp

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2338,23 +2338,15 @@ bool ScalarEvolution::willNotOverflow(Instruction::BinaryOps BinOp, bool Signed,
23382338
// Can we use context to prove the fact we need?
23392339
if (!CtxI)
23402340
return false;
2341+
// TODO: Support mul.
2342+
if (BinOp == Instruction::Mul)
2343+
return false;
23412344
auto *RHSC = dyn_cast<SCEVConstant>(RHS);
23422345
// TODO: Lift this limitation.
23432346
if (!RHSC)
23442347
return false;
23452348
APInt C = RHSC->getAPInt();
23462349
unsigned NumBits = C.getBitWidth();
2347-
if (BinOp == Instruction::Mul) {
2348-
// Multiplying by 0 or 1 never overflows
2349-
if (C.isZero() || C.isOne())
2350-
return true;
2351-
if (Signed)
2352-
return false;
2353-
APInt Limit = APInt::getMaxValue(NumBits).udiv(C);
2354-
// To avoid overflow, we need to make sure that LHS <= MAX / C.
2355-
return isKnownPredicateAt(ICmpInst::ICMP_ULE, LHS, getConstant(Limit),
2356-
CtxI);
2357-
}
23582350
bool IsSub = (BinOp == Instruction::Sub);
23592351
bool IsNegativeConst = (Signed && C.isNegative());
23602352
// Compute the direction and magnitude by which we need to check overflow.

llvm/test/Transforms/LoopVectorize/vect.stats.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
; RUN: opt < %s -passes=loop-vectorize -force-vector-interleave=4 -force-vector-width=4 -debug-only=loop-vectorize -enable-early-exit-vectorization -use-symbolic-maxbtc-deref-loop --disable-output -stats -S 2>&1 | FileCheck %s
1+
; RUN: opt < %s -passes=loop-vectorize -force-vector-interleave=4 -force-vector-width=4 -debug-only=loop-vectorize -enable-early-exit-vectorization --disable-output -stats -S 2>&1 | FileCheck %s
22
; REQUIRES: asserts
33

44
; We have 3 loops, two of them are vectorizable (with one being early-exit

0 commit comments

Comments
 (0)