Skip to content

Commit 1862db9

Browse files
committed
[LV] Don't vectorize if trip count expansion may introduce UB.
Introduce a utility to check if a SCEV expansion may introduce UB (couldn't find a similar utility after a quick glance) and use to the avoid vectorizing when expanding the trip count introduces UB. Fixes #89958.
1 parent 515e048 commit 1862db9

File tree

4 files changed

+220
-93
lines changed

4 files changed

+220
-93
lines changed

llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,13 @@ class SCEVExpander : public SCEVVisitor<SCEVExpander, Value *> {
418418
BasicBlock::iterator findInsertPointAfter(Instruction *I,
419419
Instruction *MustDominate) const;
420420

421+
/// If \p L contains exits which may execute conditionally and contain UDiv
422+
/// expressions with divisors that can be 0, expanding \p BTC may introduce
423+
/// new UB. In that case, rewrite UDiv(A, B) to UDiv(A, UMAX(1, B)). If B is
424+
/// 0, that exit cannot be taken.
425+
static const SCEV *rewriteExpressionToRemoveUB(const SCEV *BTC, Loop *L,
426+
ScalarEvolution &SE);
427+
421428
private:
422429
LLVMContext &getContext() const { return SE.getContext(); }
423430

llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2341,12 +2341,17 @@ struct SCEVFindUnsafe {
23412341
};
23422342
} // namespace
23432343

2344-
bool SCEVExpander::isSafeToExpand(const SCEV *S) const {
2344+
static bool isSafeToExpand(const SCEV *S, bool CanonicalMode,
2345+
ScalarEvolution &SE) {
23452346
SCEVFindUnsafe Search(SE, CanonicalMode);
23462347
visitAll(S, Search);
23472348
return !Search.IsUnsafe;
23482349
}
23492350

2351+
bool SCEVExpander::isSafeToExpand(const SCEV *S) const {
2352+
return ::isSafeToExpand(S, CanonicalMode, SE);
2353+
}
2354+
23502355
bool SCEVExpander::isSafeToExpandAt(const SCEV *S,
23512356
const Instruction *InsertionPoint) const {
23522357
if (!isSafeToExpand(S))
@@ -2369,6 +2374,47 @@ bool SCEVExpander::isSafeToExpandAt(const SCEV *S,
23692374
return false;
23702375
}
23712376

2377+
namespace {
2378+
struct SCEVUDivRewriter : public SCEVRewriteVisitor<SCEVUDivRewriter> {
2379+
SCEVUDivRewriter(ScalarEvolution &SE) : SCEVRewriteVisitor(SE) {}
2380+
2381+
static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE) {
2382+
SCEVUDivRewriter Rewriter(SE);
2383+
return Rewriter.visit(Scev);
2384+
}
2385+
2386+
const SCEV *visitUDivExpr(const SCEVUDivExpr *Expr) {
2387+
auto *LHS = visit(Expr->getLHS());
2388+
auto *RHS = visit(Expr->getRHS());
2389+
return SE.getUDivExpr(LHS, SE.getUMaxExpr(SE.getOne(RHS->getType()), RHS));
2390+
}
2391+
};
2392+
} // namespace
2393+
2394+
const SCEV *SCEVExpander::rewriteExpressionToRemoveUB(const SCEV *Expr, Loop *L,
2395+
ScalarEvolution &SE) {
2396+
SmallVector<BasicBlock *> Exiting;
2397+
L->getExitingBlocks(Exiting);
2398+
// Check if exit count for any exit that may execute unconditionally may in
2399+
// introduce UB. Note that we can skip checks in the header or if there's a
2400+
// single exit, as in those cases we know that the exit count will be
2401+
// evaluated in each loop iteration. There are other cases where the exiting
2402+
// block executes on each loop iteration, but we don't have a cheap way to
2403+
// check at the moment.
2404+
2405+
if (Exiting.size() == 1 || all_of(Exiting, [L, &SE](BasicBlock *E) {
2406+
if (L->getHeader() == E)
2407+
return true;
2408+
const SCEV *EC = SE.getExitCount(L, E);
2409+
if (isa<SCEVCouldNotCompute>(EC))
2410+
return true;
2411+
return ::isSafeToExpand(EC, true, SE);
2412+
}))
2413+
return Expr;
2414+
2415+
return SCEVUDivRewriter::rewrite(Expr, SE);
2416+
}
2417+
23722418
void SCEVExpanderCleaner::cleanup() {
23732419
// Result is used, nothing to remove.
23742420
if (ResultUsed)

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -984,6 +984,9 @@ const SCEV *createTripCountSCEV(Type *IdxTy, PredicatedScalarEvolution &PSE,
984984
assert(!isa<SCEVCouldNotCompute>(BackedgeTakenCount) && "Invalid loop count");
985985

986986
ScalarEvolution &SE = *PSE.getSE();
987+
if (OrigLoop)
988+
BackedgeTakenCount = SCEVExpander::rewriteExpressionToRemoveUB(
989+
BackedgeTakenCount, OrigLoop, SE);
987990
return SE.getTripCountFromExitCount(BackedgeTakenCount, IdxTy, OrigLoop);
988991
}
989992

0 commit comments

Comments
 (0)