diff --git a/llvm/lib/Analysis/Loads.cpp b/llvm/lib/Analysis/Loads.cpp
index 9a2c9ba63ec7e..7a8fbbd0fb919 100644
--- a/llvm/lib/Analysis/Loads.cpp
+++ b/llvm/lib/Analysis/Loads.cpp
@@ -26,6 +26,10 @@
 
 using namespace llvm;
 
+static cl::opt<bool>
+    UseSymbolicMaxBTCForDerefInLoop("use-symbolic-maxbtc-deref-loop",
+                                    cl::init(false));
+
 static bool isAligned(const Value *Base, Align Alignment,
                       const DataLayout &DL) {
   return Base->getPointerAlignment(DL) >= Alignment;
@@ -332,7 +336,7 @@ bool llvm::isDereferenceableAndAlignedInLoop(
   if (isa<SCEVCouldNotCompute>(MaxBECount))
     return false;
 
-  if (isa<SCEVCouldNotCompute>(BECount)) {
+  if (isa<SCEVCouldNotCompute>(BECount) && !UseSymbolicMaxBTCForDerefInLoop) {
     // TODO: Support symbolic max backedge taken counts for loops without
     // computable backedge taken counts.
     MaxBECount =
@@ -340,6 +344,7 @@ bool llvm::isDereferenceableAndAlignedInLoop(
             ? SE.getPredicatedConstantMaxBackedgeTakenCount(L, *Predicates)
            : SE.getConstantMaxBackedgeTakenCount(L);
   }
+
   const auto &[AccessStart, AccessEnd] = getStartAndEndForAccess(
       L, PtrScev, LI->getType(), BECount, MaxBECount, &SE, nullptr, &DT, AC);
   if (isa<SCEVCouldNotCompute>(AccessStart) ||
diff --git a/llvm/lib/Analysis/LoopAccessAnalysis.cpp b/llvm/lib/Analysis/LoopAccessAnalysis.cpp
index bceddd0325276..258fa982ed1d0 100644
--- a/llvm/lib/Analysis/LoopAccessAnalysis.cpp
+++ b/llvm/lib/Analysis/LoopAccessAnalysis.cpp
@@ -193,8 +193,9 @@ RuntimeCheckingPtrGroup::RuntimeCheckingPtrGroup(
 /// Returns \p A + \p B, if it is guaranteed not to unsigned wrap. Otherwise
 /// return nullptr. \p A and \p B must have the same type.
 static const SCEV *addSCEVNoOverflow(const SCEV *A, const SCEV *B,
-                                     ScalarEvolution &SE) {
-  if (!SE.willNotOverflow(Instruction::Add, /*IsSigned=*/false, A, B))
+                                     ScalarEvolution &SE,
+                                     const Instruction *CtxI) {
+  if (!SE.willNotOverflow(Instruction::Add, /*IsSigned=*/false, A, B, CtxI))
     return nullptr;
   return SE.getAddExpr(A, B);
 }
@@ -202,8 +203,9 @@ static const SCEV *addSCEVNoOverflow(const SCEV *A, const SCEV *B,
 /// Returns \p A * \p B, if it is guaranteed not to unsigned wrap. Otherwise
 /// return nullptr. \p A and \p B must have the same type.
 static const SCEV *mulSCEVOverflow(const SCEV *A, const SCEV *B,
-                                   ScalarEvolution &SE) {
-  if (!SE.willNotOverflow(Instruction::Mul, /*IsSigned=*/false, A, B))
+                                   ScalarEvolution &SE,
+                                   const Instruction *CtxI) {
+  if (!SE.willNotOverflow(Instruction::Mul, /*IsSigned=*/false, A, B, CtxI))
     return nullptr;
   return SE.getMulExpr(A, B);
 }
@@ -232,11 +234,12 @@ evaluatePtrAddRecAtMaxBTCWillNotWrap(const SCEVAddRecExpr *AR,
 
   Type *WiderTy = SE.getWiderType(MaxBTC->getType(), Step->getType());
   const SCEV *DerefBytesSCEV = SE.getConstant(WiderTy, DerefBytes);
+  // Context which dominates the entire loop.
+  auto *CtxI = L->getLoopPredecessor()->getTerminator();
   // Check if we have a suitable dereferencable assumption we can use.
   if (!StartPtrV->canBeFreed()) {
     RetainedKnowledge DerefRK = getKnowledgeValidInContext(
-        StartPtrV, {Attribute::Dereferenceable}, *AC,
-        L->getLoopPredecessor()->getTerminator(), DT);
+        StartPtrV, {Attribute::Dereferenceable}, *AC, CtxI, DT);
     if (DerefRK) {
       DerefBytesSCEV = SE.getUMaxExpr(
           DerefBytesSCEV, SE.getConstant(WiderTy, DerefRK.ArgValue));
@@ -260,12 +263,12 @@ evaluatePtrAddRecAtMaxBTCWillNotWrap(const SCEVAddRecExpr *AR,
       SE.getMinusSCEV(AR->getStart(), StartPtr), WiderTy);
 
   const SCEV *OffsetAtLastIter =
-      mulSCEVOverflow(MaxBTC, SE.getAbsExpr(Step, /*IsNSW=*/false), SE);
+      mulSCEVOverflow(MaxBTC, SE.getAbsExpr(Step, /*IsNSW=*/false), SE, CtxI);
   if (!OffsetAtLastIter)
     return false;
 
   const SCEV *OffsetEndBytes = addSCEVNoOverflow(
-      OffsetAtLastIter, SE.getNoopOrZeroExtend(EltSize, WiderTy), SE);
+      OffsetAtLastIter, SE.getNoopOrZeroExtend(EltSize, WiderTy), SE, CtxI);
   if (!OffsetEndBytes)
     return false;
 
@@ -273,7 +276,8 @@ evaluatePtrAddRecAtMaxBTCWillNotWrap(const SCEVAddRecExpr *AR,
     // For positive steps, check if
     // (AR->getStart() - StartPtr) + (MaxBTC * Step) + EltSize <= DerefBytes,
     // while making sure none of the computations unsigned wrap themselves.
-    const SCEV *EndBytes = addSCEVNoOverflow(StartOffset, OffsetEndBytes, SE);
+    const SCEV *EndBytes =
+        addSCEVNoOverflow(StartOffset, OffsetEndBytes, SE, CtxI);
     if (!EndBytes)
       return false;
     return SE.isKnownPredicate(CmpInst::ICMP_ULE, EndBytes, DerefBytesSCEV);
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index e5e4dcfce3bd4..aa2bcf7917537 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -2338,15 +2338,23 @@ bool ScalarEvolution::willNotOverflow(Instruction::BinaryOps BinOp, bool Signed,
   // Can we use context to prove the fact we need?
   if (!CtxI)
     return false;
-  // TODO: Support mul.
-  if (BinOp == Instruction::Mul)
-    return false;
   auto *RHSC = dyn_cast<SCEVConstant>(RHS);
   // TODO: Lift this limitation.
   if (!RHSC)
     return false;
   APInt C = RHSC->getAPInt();
   unsigned NumBits = C.getBitWidth();
+  if (BinOp == Instruction::Mul) {
+    // Multiplying by 0 or 1 never overflows
+    if (C.isZero() || C.isOne())
+      return true;
+    if (Signed)
+      return false;
+    APInt Limit = APInt::getMaxValue(NumBits).udiv(C);
+    // To avoid overflow, we need to make sure that LHS <= MAX / C.
+    return isKnownPredicateAt(ICmpInst::ICMP_ULE, LHS, getConstant(Limit),
+                              CtxI);
+  }
   bool IsSub = (BinOp == Instruction::Sub);
   bool IsNegativeConst = (Signed && C.isNegative());
   // Compute the direction and magnitude by which we need to check overflow.
diff --git a/llvm/test/Transforms/LoopVectorize/vect.stats.ll b/llvm/test/Transforms/LoopVectorize/vect.stats.ll
index e3240c8181519..f3695e6712952 100644
--- a/llvm/test/Transforms/LoopVectorize/vect.stats.ll
+++ b/llvm/test/Transforms/LoopVectorize/vect.stats.ll
@@ -1,4 +1,4 @@
-; RUN: opt < %s -passes=loop-vectorize -force-vector-interleave=4 -force-vector-width=4 -debug-only=loop-vectorize -enable-early-exit-vectorization --disable-output -stats -S 2>&1 | FileCheck %s
+; RUN: opt < %s -passes=loop-vectorize -force-vector-interleave=4 -force-vector-width=4 -debug-only=loop-vectorize -enable-early-exit-vectorization -use-symbolic-maxbtc-deref-loop --disable-output -stats -S 2>&1 | FileCheck %s
 ; REQUIRES: asserts
 
 ; We have 3 loops, two of them are vectorizable (with one being early-exit