Skip to content

Commit a59e11d

Browse files
committed
[LV] Don't vectorize if trip count expansion may introduce UB.
Introduce a utility to check whether a SCEV expansion may introduce UB (I couldn't find a similar utility after a quick glance) and use it to avoid vectorizing when expanding the trip count would introduce UB. Fixes #89958.
1 parent 515e048 commit a59e11d

File tree

4 files changed

+219
-92
lines changed

4 files changed

+219
-92
lines changed

llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,13 @@ class SCEVExpander : public SCEVVisitor<SCEVExpander, Value *> {
418418
BasicBlock::iterator findInsertPointAfter(Instruction *I,
419419
Instruction *MustDominate) const;
420420

421+
/// If \p L contains exits which may execute conditionally and contain UDiv
422+
/// expressions with divisors that can be 0, expanding \p BTC may introduce
423+
/// new UB. In that case, rewrite UDiv(A, B) to UDiv(A, UMAX(1, B)). If B is
424+
/// 0, that exit cannot be taken.
425+
static const SCEV *rewriteExpressionToRemoveUB(const SCEV *BTC, Loop *L,
426+
ScalarEvolution &SE);
427+
421428
private:
422429
LLVMContext &getContext() const { return SE.getContext(); }
423430

llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,52 @@ SCEVExpander::findInsertPointAfter(Instruction *I,
161161
return IP;
162162
}
163163

164+
namespace {
165+
/// SCEV rewriter that replaces every UDiv(A, B) in an expression with
/// UDiv(A, UMAX(1, B)), so the rewritten divisor is never 0.
struct SCEVUDivRewriter : public SCEVRewriteVisitor<SCEVUDivRewriter> {
166+
SCEVUDivRewriter(ScalarEvolution &SE) : SCEVRewriteVisitor(SE) {}
167+
168+
/// Convenience entry point: builds a rewriter over \p SE and returns the
/// result of visiting \p Scev with it.
static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE) {
169+
SCEVUDivRewriter Rewriter(SE);
170+
return Rewriter.visit(Scev);
171+
}
172+
173+
/// Visits both operands of \p Expr first, then rebuilds the division with
/// the visited RHS clamped to at least 1 via UMAX(1, RHS).
const SCEV *visitUDivExpr(const SCEVUDivExpr *Expr) {
174+
auto *LHS = visit(Expr->getLHS());
175+
auto *RHS = visit(Expr->getRHS());
176+
return SE.getUDivExpr(LHS, SE.getUMaxExpr(SE.getOne(RHS->getType()), RHS));
177+
}
178+
};
179+
} // namespace
180+
181+
/// Returns \p Expr with every UDiv(A, B) rewritten to UDiv(A, UMAX(1, B)) if
/// \p L has more than one exiting block and the exit count of some non-header
/// exiting block is computable and contains a UDiv whose operand 1 is not
/// known to be non-zero. Otherwise returns \p Expr unchanged.
const SCEV *SCEVExpander::rewriteExpressionToRemoveUB(const SCEV *Expr, Loop *L,
182+
ScalarEvolution &SE) {
183+
SmallVector<BasicBlock *> Exiting;
184+
L->getExitingBlocks(Exiting);
185+
// Check if the exit count for any exit that may execute conditionally may
186+
// introduce UB. Note that we can skip checks in the header or if there's a
187+
// single exit, as in those cases we know that the exit count will be
188+
// evaluated in each loop iteration. There are other cases where the exiting
189+
// block executes on each loop iteration, but we don't have a cheap way to
190+
// check at the moment.
191+
192+
if (Exiting.size() > 1 && any_of(Exiting, [L, &SE](BasicBlock *E) {
193+
if (L->getHeader() == E)
194+
return false;
195+
const SCEV *EC = SE.getExitCount(L, E);
196+
return !isa<SCEVCouldNotCompute>(EC) &&
197+
SCEVExprContains(EC, [&SE](const SCEV *Op) {
198+
auto *UDiv = dyn_cast<SCEVUDivExpr>(Op);
199+
if (!UDiv)
200+
return false;
201+
202+
// Only a UDiv whose operand 1 is not provably non-zero is flagged.
return !SE.isKnownNonZero(UDiv->getOperand(1));
203+
});
204+
})) {
205+
return SCEVUDivRewriter::rewrite(Expr, SE);
206+
}
207+
return Expr;
208+
}
209+
164210
BasicBlock::iterator
165211
SCEVExpander::GetOptimalInsertionPointForCastOf(Value *V) const {
166212
// Cast the argument at the beginning of the entry block, after

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -984,6 +984,9 @@ const SCEV *createTripCountSCEV(Type *IdxTy, PredicatedScalarEvolution &PSE,
984984
assert(!isa<SCEVCouldNotCompute>(BackedgeTakenCount) && "Invalid loop count");
985985

986986
ScalarEvolution &SE = *PSE.getSE();
987+
if (OrigLoop)
988+
BackedgeTakenCount = SCEVExpander::rewriteExpressionToRemoveUB(
989+
BackedgeTakenCount, OrigLoop, SE);
987990
return SE.getTripCountFromExitCount(BackedgeTakenCount, IdxTy, OrigLoop);
988991
}
989992

0 commit comments

Comments
 (0)