Skip to content

Commit 00926a6

Browse files
authored
[SCEV][LAA] Support multiplication overflow computation (#155236)
Add support for identifying multiplication overflow in SCEV. This is needed in LoopAccessAnalysis and that limitation was worked around by 484417a. This allows early-exit vectorization to work as expected in vect.stats.ll test without needing the workaround.
1 parent 1eb6c2b commit 00926a6

File tree

4 files changed

+31
-14
lines changed

4 files changed

+31
-14
lines changed

llvm/lib/Analysis/Loads.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,10 @@
2626

2727
using namespace llvm;
2828

29+
static cl::opt<bool>
30+
UseSymbolicMaxBTCForDerefInLoop("use-symbolic-maxbtc-deref-loop",
31+
cl::init(false));
32+
2933
static bool isAligned(const Value *Base, Align Alignment,
3034
const DataLayout &DL) {
3135
return Base->getPointerAlignment(DL) >= Alignment;
@@ -332,14 +336,15 @@ bool llvm::isDereferenceableAndAlignedInLoop(
332336
if (isa<SCEVCouldNotCompute>(MaxBECount))
333337
return false;
334338

335-
if (isa<SCEVCouldNotCompute>(BECount)) {
339+
if (isa<SCEVCouldNotCompute>(BECount) && !UseSymbolicMaxBTCForDerefInLoop) {
336340
// TODO: Support symbolic max backedge taken counts for loops without
337341
// computable backedge taken counts.
338342
MaxBECount =
339343
Predicates
340344
? SE.getPredicatedConstantMaxBackedgeTakenCount(L, *Predicates)
341345
: SE.getConstantMaxBackedgeTakenCount(L);
342346
}
347+
343348
const auto &[AccessStart, AccessEnd] = getStartAndEndForAccess(
344349
L, PtrScev, LI->getType(), BECount, MaxBECount, &SE, nullptr, &DT, AC);
345350
if (isa<SCEVCouldNotCompute>(AccessStart) ||

llvm/lib/Analysis/LoopAccessAnalysis.cpp

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -193,17 +193,19 @@ RuntimeCheckingPtrGroup::RuntimeCheckingPtrGroup(
193193
/// Returns \p A + \p B, if it is guaranteed not to unsigned wrap. Otherwise
194194
/// return nullptr. \p A and \p B must have the same type.
195195
static const SCEV *addSCEVNoOverflow(const SCEV *A, const SCEV *B,
196-
ScalarEvolution &SE) {
197-
if (!SE.willNotOverflow(Instruction::Add, /*IsSigned=*/false, A, B))
196+
ScalarEvolution &SE,
197+
const Instruction *CtxI) {
198+
if (!SE.willNotOverflow(Instruction::Add, /*IsSigned=*/false, A, B, CtxI))
198199
return nullptr;
199200
return SE.getAddExpr(A, B);
200201
}
201202

202203
/// Returns \p A * \p B, if it is guaranteed not to unsigned wrap. Otherwise
203204
/// return nullptr. \p A and \p B must have the same type.
204205
static const SCEV *mulSCEVOverflow(const SCEV *A, const SCEV *B,
205-
ScalarEvolution &SE) {
206-
if (!SE.willNotOverflow(Instruction::Mul, /*IsSigned=*/false, A, B))
206+
ScalarEvolution &SE,
207+
const Instruction *CtxI) {
208+
if (!SE.willNotOverflow(Instruction::Mul, /*IsSigned=*/false, A, B, CtxI))
207209
return nullptr;
208210
return SE.getMulExpr(A, B);
209211
}
@@ -232,11 +234,12 @@ evaluatePtrAddRecAtMaxBTCWillNotWrap(const SCEVAddRecExpr *AR,
232234
Type *WiderTy = SE.getWiderType(MaxBTC->getType(), Step->getType());
233235
const SCEV *DerefBytesSCEV = SE.getConstant(WiderTy, DerefBytes);
234236

237+
// Context which dominates the entire loop.
238+
auto *CtxI = L->getLoopPredecessor()->getTerminator();
235239
// Check if we have a suitable dereferencable assumption we can use.
236240
if (!StartPtrV->canBeFreed()) {
237241
RetainedKnowledge DerefRK = getKnowledgeValidInContext(
238-
StartPtrV, {Attribute::Dereferenceable}, *AC,
239-
L->getLoopPredecessor()->getTerminator(), DT);
242+
StartPtrV, {Attribute::Dereferenceable}, *AC, CtxI, DT);
240243
if (DerefRK) {
241244
DerefBytesSCEV = SE.getUMaxExpr(
242245
DerefBytesSCEV, SE.getConstant(WiderTy, DerefRK.ArgValue));
@@ -260,20 +263,21 @@ evaluatePtrAddRecAtMaxBTCWillNotWrap(const SCEVAddRecExpr *AR,
260263
SE.getMinusSCEV(AR->getStart(), StartPtr), WiderTy);
261264

262265
const SCEV *OffsetAtLastIter =
263-
mulSCEVOverflow(MaxBTC, SE.getAbsExpr(Step, /*IsNSW=*/false), SE);
266+
mulSCEVOverflow(MaxBTC, SE.getAbsExpr(Step, /*IsNSW=*/false), SE, CtxI);
264267
if (!OffsetAtLastIter)
265268
return false;
266269

267270
const SCEV *OffsetEndBytes = addSCEVNoOverflow(
268-
OffsetAtLastIter, SE.getNoopOrZeroExtend(EltSize, WiderTy), SE);
271+
OffsetAtLastIter, SE.getNoopOrZeroExtend(EltSize, WiderTy), SE, CtxI);
269272
if (!OffsetEndBytes)
270273
return false;
271274

272275
if (IsKnownNonNegative) {
273276
// For positive steps, check if
274277
// (AR->getStart() - StartPtr) + (MaxBTC * Step) + EltSize <= DerefBytes,
275278
// while making sure none of the computations unsigned wrap themselves.
276-
const SCEV *EndBytes = addSCEVNoOverflow(StartOffset, OffsetEndBytes, SE);
279+
const SCEV *EndBytes =
280+
addSCEVNoOverflow(StartOffset, OffsetEndBytes, SE, CtxI);
277281
if (!EndBytes)
278282
return false;
279283
return SE.isKnownPredicate(CmpInst::ICMP_ULE, EndBytes, DerefBytesSCEV);

llvm/lib/Analysis/ScalarEvolution.cpp

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2338,15 +2338,23 @@ bool ScalarEvolution::willNotOverflow(Instruction::BinaryOps BinOp, bool Signed,
23382338
// Can we use context to prove the fact we need?
23392339
if (!CtxI)
23402340
return false;
2341-
// TODO: Support mul.
2342-
if (BinOp == Instruction::Mul)
2343-
return false;
23442341
auto *RHSC = dyn_cast<SCEVConstant>(RHS);
23452342
// TODO: Lift this limitation.
23462343
if (!RHSC)
23472344
return false;
23482345
APInt C = RHSC->getAPInt();
23492346
unsigned NumBits = C.getBitWidth();
2347+
if (BinOp == Instruction::Mul) {
2348+
// Multiplying by 0 or 1 never overflows
2349+
if (C.isZero() || C.isOne())
2350+
return true;
2351+
if (Signed)
2352+
return false;
2353+
APInt Limit = APInt::getMaxValue(NumBits).udiv(C);
2354+
// To avoid overflow, we need to make sure that LHS <= MAX / C.
2355+
return isKnownPredicateAt(ICmpInst::ICMP_ULE, LHS, getConstant(Limit),
2356+
CtxI);
2357+
}
23502358
bool IsSub = (BinOp == Instruction::Sub);
23512359
bool IsNegativeConst = (Signed && C.isNegative());
23522360
// Compute the direction and magnitude by which we need to check overflow.

llvm/test/Transforms/LoopVectorize/vect.stats.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
; RUN: opt < %s -passes=loop-vectorize -force-vector-interleave=4 -force-vector-width=4 -debug-only=loop-vectorize -enable-early-exit-vectorization --disable-output -stats -S 2>&1 | FileCheck %s
1+
; RUN: opt < %s -passes=loop-vectorize -force-vector-interleave=4 -force-vector-width=4 -debug-only=loop-vectorize -enable-early-exit-vectorization -use-symbolic-maxbtc-deref-loop --disable-output -stats -S 2>&1 | FileCheck %s
22
; REQUIRES: asserts
33

44
; We have 3 loops, two of them are vectorizable (with one being early-exit

0 commit comments

Comments
 (0)