Skip to content

Commit bff4e5a

Browse files
committed
[LV] Don't vectorize if trip count expansion may introduce UB.
Introduce a utility to check if a SCEV expansion may introduce UB (couldn't find a similar utility after a quick glance) and use to the avoid vectorizing when expanding the trip count introduces UB. Fixes #89958. !fixup introduce SafeUDivMode to SCEVExpander. Step
1 parent 766ec7d commit bff4e5a

File tree

4 files changed

+61
-14
lines changed

4 files changed

+61
-14
lines changed

llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,11 @@ class SCEVExpander : public SCEVVisitor<SCEVExpander, Value *> {
124124
/// "expanded" form.
125125
bool LSRMode;
126126

127+
/// When true, rewrite any divisors of UDiv expressions that may be 0 to
128+
/// umax(Divisor, 1) to avoid introducing UB. If the divisor may be poison,
129+
/// freeze it first.
130+
bool SafeUDivMode = false;
131+
127132
typedef IRBuilder<InstSimplifyFolder, IRBuilderCallbackInserter> BuilderType;
128133
BuilderType Builder;
129134

@@ -300,6 +305,9 @@ class SCEVExpander : public SCEVVisitor<SCEVExpander, Value *> {
300305
/// location and their operands are defined at this location.
301306
bool isSafeToExpandAt(const SCEV *S, const Instruction *InsertionPoint) const;
302307

308+
static bool isSafeToExpand(const SCEV *S, bool CanonicalMode,
309+
ScalarEvolution &SE);
310+
303311
/// Insert code to directly compute the specified SCEV expression into the
304312
/// program. The code is inserted into the specified block.
305313
Value *expandCodeFor(const SCEV *SH, Type *Ty, BasicBlock::iterator I);
@@ -418,6 +426,11 @@ class SCEVExpander : public SCEVVisitor<SCEVExpander, Value *> {
418426
BasicBlock::iterator findInsertPointAfter(Instruction *I,
419427
Instruction *MustDominate) const;
420428

429+
static const SCEV *rewriteExpressionToRemoveUB(const SCEV *BTC, Loop *L,
430+
ScalarEvolution &SE);
431+
432+
void setSafeUDivMode() { SafeUDivMode = true; }
433+
421434
private:
422435
LLVMContext &getContext() const { return SE.getContext(); }
423436

llvm/lib/Analysis/ScalarEvolution.cpp

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4304,12 +4304,19 @@ ScalarEvolution::getSequentialMinMaxExpr(SCEVTypes Kind,
43044304
}
43054305

43064306
for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
4307+
bool MayBeUB = SCEVExprContains(Ops[i], [](const SCEV *S) {
4308+
auto *UDiv = dyn_cast<SCEVUDivExpr>(S);
4309+
return UDiv && !isa<SCEVConstant>(UDiv->getOperand(1));
4310+
});
4311+
4312+
if (MayBeUB)
4313+
continue;
43074314
// We can replace %x umin_seq %y with %x umin %y if either:
43084315
// * %y being poison implies %x is also poison.
43094316
// * %x cannot be the saturating value (e.g. zero for umin).
43104317
if (::impliesPoison(Ops[i], Ops[i - 1]) ||
4311-
isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_NE, Ops[i - 1],
4312-
SaturationPoint)) {
4318+
(isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_NE, Ops[i - 1],
4319+
SaturationPoint))) {
43134320
SmallVector<const SCEV *> SeqOps = {Ops[i - 1], Ops[i]};
43144321
Ops[i - 1] = getMinMaxExpr(
43154322
SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(Kind),

llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -665,6 +665,12 @@ Value *SCEVExpander::visitUDivExpr(const SCEVUDivExpr *S) {
665665
}
666666

667667
Value *RHS = expand(S->getRHS());
668+
if (SafeUDivMode && !SE.isKnownNonZero(S->getRHS())) {
669+
if (!isa<SCEVConstant>(S->getRHS()))
670+
RHS = Builder.CreateFreeze(RHS);
671+
RHS = Builder.CreateIntrinsic(RHS->getType(), Intrinsic::umax,
672+
{RHS, ConstantInt::get(RHS->getType(), 1)});
673+
}
668674
return InsertBinop(Instruction::UDiv, LHS, RHS, SCEV::FlagAnyWrap,
669675
/*IsSafeToHoist*/ SE.isKnownNonZero(S->getRHS()));
670676
}
@@ -1358,11 +1364,13 @@ Value *SCEVExpander::visitSignExtendExpr(const SCEVSignExtendExpr *S) {
13581364
Value *SCEVExpander::expandMinMaxExpr(const SCEVNAryExpr *S,
13591365
Intrinsic::ID IntrinID, Twine Name,
13601366
bool IsSequential) {
1367+
SafeUDivMode = true;
13611368
Value *LHS = expand(S->getOperand(S->getNumOperands() - 1));
13621369
Type *Ty = LHS->getType();
13631370
if (IsSequential)
13641371
LHS = Builder.CreateFreeze(LHS);
13651372
for (int i = S->getNumOperands() - 2; i >= 0; --i) {
1373+
SafeUDivMode = i != 0;
13661374
Value *RHS = expand(S->getOperand(i));
13671375
if (IsSequential && i != 0)
13681376
RHS = Builder.CreateFreeze(RHS);
@@ -2315,12 +2323,17 @@ struct SCEVFindUnsafe {
23152323
};
23162324
} // namespace
23172325

2318-
bool SCEVExpander::isSafeToExpand(const SCEV *S) const {
2326+
bool SCEVExpander::isSafeToExpand(const SCEV *S, bool CanonicalMode,
2327+
ScalarEvolution &SE) {
23192328
SCEVFindUnsafe Search(SE, CanonicalMode);
23202329
visitAll(S, Search);
23212330
return !Search.IsUnsafe;
23222331
}
23232332

2333+
bool SCEVExpander::isSafeToExpand(const SCEV *S) const {
2334+
return isSafeToExpand(S, CanonicalMode, SE);
2335+
}
2336+
23242337
bool SCEVExpander::isSafeToExpandAt(const SCEV *S,
23252338
const Instruction *InsertionPoint) const {
23262339
if (!isSafeToExpand(S))

llvm/test/Transforms/LoopVectorize/trip-count-expansion-may-introduce-ub.ll

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -463,9 +463,12 @@ define i64 @multi_exit_4_exit_count_with_udiv_by_value_in_latch(ptr %dst, i64 %N
463463
; CHECK-LABEL: define i64 @multi_exit_4_exit_count_with_udiv_by_value_in_latch(
464464
; CHECK-SAME: ptr [[DST:%.*]], i64 [[N:%.*]]) {
465465
; CHECK-NEXT: entry:
466+
; CHECK-NEXT: [[TMP9:%.*]] = freeze i64 [[N]]
467+
; CHECK-NEXT: [[TMP10:%.*]] = call i64 @llvm.umax.i64(i64 [[TMP9]], i64 1)
468+
; CHECK-NEXT: [[TMP0:%.*]] = udiv i64 42, [[TMP10]]
469+
; CHECK-NEXT: [[TMP8:%.*]] = freeze i64 [[TMP0]]
466470
; CHECK-NEXT: [[SMAX:%.*]] = call i64 @llvm.smax.i64(i64 [[N]], i64 0)
467-
; CHECK-NEXT: [[TMP0:%.*]] = udiv i64 42, [[N]]
468-
; CHECK-NEXT: [[UMIN:%.*]] = call i64 @llvm.umin.i64(i64 [[SMAX]], i64 [[TMP0]])
471+
; CHECK-NEXT: [[UMIN:%.*]] = call i64 @llvm.umin.i64(i64 [[TMP8]], i64 [[SMAX]])
469472
; CHECK-NEXT: [[TMP1:%.*]] = add nuw nsw i64 [[UMIN]], 1
470473
; CHECK-NEXT: [[MIN_ITERS_CHECK:%.*]] = icmp ule i64 [[TMP1]], 4
471474
; CHECK-NEXT: br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
@@ -529,7 +532,9 @@ define i64 @multi_exit_4_exit_count_with_udiv_by_value_in_latch_different_bounds
529532
; CHECK-LABEL: define i64 @multi_exit_4_exit_count_with_udiv_by_value_in_latch_different_bounds(
530533
; CHECK-SAME: ptr [[DST:%.*]], i64 [[N:%.*]], i64 [[M:%.*]]) {
531534
; CHECK-NEXT: entry:
532-
; CHECK-NEXT: [[TMP2:%.*]] = udiv i64 42, [[M]]
535+
; CHECK-NEXT: [[TMP0:%.*]] = freeze i64 [[M]]
536+
; CHECK-NEXT: [[TMP1:%.*]] = call i64 @llvm.umax.i64(i64 [[TMP0]], i64 1)
537+
; CHECK-NEXT: [[TMP2:%.*]] = udiv i64 42, [[TMP1]]
533538
; CHECK-NEXT: [[TMP3:%.*]] = freeze i64 [[TMP2]]
534539
; CHECK-NEXT: [[SMAX:%.*]] = call i64 @llvm.smax.i64(i64 [[N]], i64 0)
535540
; CHECK-NEXT: [[UMIN:%.*]] = call i64 @llvm.umin.i64(i64 [[TMP3]], i64 [[SMAX]])
@@ -598,9 +603,12 @@ define i64 @multi_exit_4_exit_count_with_udiv_by_frozen_value_in_latch(ptr %dst,
598603
; CHECK-SAME: ptr [[DST:%.*]], i64 [[N:%.*]]) {
599604
; CHECK-NEXT: entry:
600605
; CHECK-NEXT: [[FR_N:%.*]] = freeze i64 [[N]]
601-
; CHECK-NEXT: [[TMP2:%.*]] = call i64 @llvm.smax.i64(i64 [[N]], i64 0)
602-
; CHECK-NEXT: [[TMP0:%.*]] = udiv i64 42, [[FR_N]]
603-
; CHECK-NEXT: [[UMIN:%.*]] = call i64 @llvm.umin.i64(i64 [[TMP2]], i64 [[TMP0]])
606+
; CHECK-NEXT: [[TMP0:%.*]] = freeze i64 [[FR_N]]
607+
; CHECK-NEXT: [[TMP1:%.*]] = call i64 @llvm.umax.i64(i64 [[TMP0]], i64 1)
608+
; CHECK-NEXT: [[TMP2:%.*]] = udiv i64 42, [[TMP1]]
609+
; CHECK-NEXT: [[TMP10:%.*]] = freeze i64 [[TMP2]]
610+
; CHECK-NEXT: [[SMAX:%.*]] = call i64 @llvm.smax.i64(i64 [[N]], i64 0)
611+
; CHECK-NEXT: [[UMIN:%.*]] = call i64 @llvm.umin.i64(i64 [[TMP10]], i64 [[SMAX]])
604612
; CHECK-NEXT: [[TMP3:%.*]] = add nuw nsw i64 [[UMIN]], 1
605613
; CHECK-NEXT: [[MIN_ITERS_CHECK:%.*]] = icmp ule i64 [[TMP3]], 4
606614
; CHECK-NEXT: br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
@@ -786,12 +794,15 @@ define i64 @multi_exit_4_exit_count_with_urem_by_value_in_latch(ptr %dst, i64 %N
786794
; CHECK-LABEL: define i64 @multi_exit_4_exit_count_with_urem_by_value_in_latch(
787795
; CHECK-SAME: ptr [[DST:%.*]], i64 [[N:%.*]]) {
788796
; CHECK-NEXT: entry:
789-
; CHECK-NEXT: [[SMAX:%.*]] = call i64 @llvm.smax.i64(i64 [[N]], i64 0)
790-
; CHECK-NEXT: [[TMP0:%.*]] = udiv i64 42, [[N]]
797+
; CHECK-NEXT: [[TMP11:%.*]] = freeze i64 [[N]]
798+
; CHECK-NEXT: [[TMP12:%.*]] = call i64 @llvm.umax.i64(i64 [[TMP11]], i64 1)
799+
; CHECK-NEXT: [[TMP0:%.*]] = udiv i64 42, [[TMP12]]
791800
; CHECK-NEXT: [[TMP1:%.*]] = mul nuw i64 [[N]], [[TMP0]]
792801
; CHECK-NEXT: [[TMP2:%.*]] = sub i64 42, [[TMP1]]
793802
; CHECK-NEXT: [[SMAX1:%.*]] = call i64 @llvm.smax.i64(i64 [[TMP2]], i64 0)
794-
; CHECK-NEXT: [[UMIN:%.*]] = call i64 @llvm.umin.i64(i64 [[SMAX]], i64 [[SMAX1]])
803+
; CHECK-NEXT: [[TMP10:%.*]] = freeze i64 [[SMAX1]]
804+
; CHECK-NEXT: [[SMAX2:%.*]] = call i64 @llvm.smax.i64(i64 [[N]], i64 0)
805+
; CHECK-NEXT: [[UMIN:%.*]] = call i64 @llvm.umin.i64(i64 [[TMP10]], i64 [[SMAX2]])
795806
; CHECK-NEXT: [[TMP3:%.*]] = add nuw i64 [[UMIN]], 1
796807
; CHECK-NEXT: [[MIN_ITERS_CHECK:%.*]] = icmp ule i64 [[TMP3]], 4
797808
; CHECK-NEXT: br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
@@ -1004,9 +1015,12 @@ define i64 @multi_exit_4_exit_count_with_udiv_by_value_in_latch1(ptr %dst, i64 %
10041015
; CHECK-LABEL: define i64 @multi_exit_4_exit_count_with_udiv_by_value_in_latch1(
10051016
; CHECK-SAME: ptr [[DST:%.*]], i64 [[N:%.*]]) {
10061017
; CHECK-NEXT: entry:
1018+
; CHECK-NEXT: [[TMP0:%.*]] = freeze i64 [[N]]
1019+
; CHECK-NEXT: [[TMP8:%.*]] = call i64 @llvm.umax.i64(i64 [[TMP0]], i64 1)
1020+
; CHECK-NEXT: [[TMP9:%.*]] = udiv i64 42, [[TMP8]]
1021+
; CHECK-NEXT: [[TMP10:%.*]] = freeze i64 [[TMP9]]
10071022
; CHECK-NEXT: [[SMAX:%.*]] = call i64 @llvm.smax.i64(i64 [[N]], i64 0)
1008-
; CHECK-NEXT: [[TMP0:%.*]] = udiv i64 42, [[N]]
1009-
; CHECK-NEXT: [[UMIN:%.*]] = call i64 @llvm.umin.i64(i64 [[SMAX]], i64 [[TMP0]])
1023+
; CHECK-NEXT: [[UMIN:%.*]] = call i64 @llvm.umin.i64(i64 [[TMP10]], i64 [[SMAX]])
10101024
; CHECK-NEXT: [[TMP1:%.*]] = add nuw nsw i64 [[UMIN]], 1
10111025
; CHECK-NEXT: [[MIN_ITERS_CHECK:%.*]] = icmp ule i64 [[TMP1]], 4
10121026
; CHECK-NEXT: br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]

0 commit comments

Comments
 (0)