Skip to content
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,11 @@ class SCEVExpander : public SCEVVisitor<SCEVExpander, Value *> {
/// "expanded" form.
bool LSRMode;

/// When true, rewrite any divisors of UDiv expressions that may be 0 to
/// umax(Divisor, 1) to avoid introducing UB. If the divisor may be poison,
/// freeze it first.
bool SafeUDivMode = false;

typedef IRBuilder<InstSimplifyFolder, IRBuilderCallbackInserter> BuilderType;
BuilderType Builder;

Expand Down Expand Up @@ -419,6 +424,9 @@ class SCEVExpander : public SCEVVisitor<SCEVExpander, Value *> {
BasicBlock::iterator findInsertPointAfter(Instruction *I,
Instruction *MustDominate) const;

static const SCEV *rewriteExpressionToRemoveUB(const SCEV *BTC, Loop *L,
ScalarEvolution &SE);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unused prototype.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Dropped, thanks!


private:
LLVMContext &getContext() const { return SE.getContext(); }

Expand Down
14 changes: 13 additions & 1 deletion llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -676,7 +676,15 @@ Value *SCEVExpander::visitUDivExpr(const SCEVUDivExpr *S) {
SCEV::FlagAnyWrap, /*IsSafeToHoist*/ true);
}

Value *RHS = expand(S->getRHS());
const SCEV *RHSExpr = S->getRHS();
Value *RHS = expand(RHSExpr);
if (SafeUDivMode &&
(!isa<SCEVConstant>(RHSExpr) || SE.isKnownNonZero(RHSExpr))) {
if (!isa<SCEVConstant>(S->getRHS()))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can now use the isGuaranteedNotToBePoison check instead.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, updated. This removes the freezes in a few cases, but that should be fine as in those cases the function would always trigger UB, so expanding the UDiv before the loop simply triggers UB earlier (e.g. multi_exit_4_exit_count_with_udiv_by_value_in_latch). This should be fine I think (https://alive2.llvm.org/ce/z/anDrW9).

Added additional variants where this is not the case (multi_exit_4_exit_count_with_udiv_by_value_in_latch_call_before_loop, multi_exit_4_exit_count_with_udiv_by_value_in_latch_loop_may_not_execute) and there's also multi_exit_4_exit_count_with_udiv_by_value_in_latch_different_bounds

RHS = Builder.CreateFreeze(RHS);
RHS = Builder.CreateIntrinsic(RHS->getType(), Intrinsic::umax,
{RHS, ConstantInt::get(RHS->getType(), 1)});
}
return InsertBinop(Instruction::UDiv, LHS, RHS, SCEV::FlagAnyWrap,
/*IsSafeToHoist*/ SE.isKnownNonZero(S->getRHS()));
}
Expand Down Expand Up @@ -1371,11 +1379,14 @@ Value *SCEVExpander::visitSignExtendExpr(const SCEVSignExtendExpr *S) {
Value *SCEVExpander::expandMinMaxExpr(const SCEVNAryExpr *S,
Intrinsic::ID IntrinID, Twine Name,
bool IsSequential) {
bool PrevSafeMode = SafeUDivMode;
SafeUDivMode |= IsSequential;
Value *LHS = expand(S->getOperand(S->getNumOperands() - 1));
Type *Ty = LHS->getType();
if (IsSequential)
LHS = Builder.CreateFreeze(LHS);
for (int i = S->getNumOperands() - 2; i >= 0; --i) {
SafeUDivMode = (IsSequential && i != 0) || PrevSafeMode;
Value *RHS = expand(S->getOperand(i));
if (IsSequential && i != 0)
RHS = Builder.CreateFreeze(RHS);
Expand All @@ -1390,6 +1401,7 @@ Value *SCEVExpander::expandMinMaxExpr(const SCEVNAryExpr *S,
}
LHS = Sel;
}
SafeUDivMode = PrevSafeMode;
return LHS;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,9 @@ define i64 @multi_exit_4_exit_count_with_udiv_by_value_in_latch(ptr %dst, i64 %N
; CHECK-LABEL: define i64 @multi_exit_4_exit_count_with_udiv_by_value_in_latch(
; CHECK-SAME: ptr [[DST:%.*]], i64 [[N:%.*]]) {
; CHECK-NEXT: entry:
; CHECK-NEXT: [[TMP0:%.*]] = udiv i64 42, [[N]]
; CHECK-NEXT: [[TMP9:%.*]] = freeze i64 [[N]]
; CHECK-NEXT: [[TMP10:%.*]] = call i64 @llvm.umax.i64(i64 [[TMP9]], i64 1)
; CHECK-NEXT: [[TMP0:%.*]] = udiv i64 42, [[TMP10]]
; CHECK-NEXT: [[TMP8:%.*]] = freeze i64 [[TMP0]]
; CHECK-NEXT: [[SMAX:%.*]] = call i64 @llvm.smax.i64(i64 [[N]], i64 0)
; CHECK-NEXT: [[UMIN:%.*]] = call i64 @llvm.umin.i64(i64 [[TMP8]], i64 [[SMAX]])
Expand Down Expand Up @@ -530,7 +532,9 @@ define i64 @multi_exit_4_exit_count_with_udiv_by_value_in_latch_different_bounds
; CHECK-LABEL: define i64 @multi_exit_4_exit_count_with_udiv_by_value_in_latch_different_bounds(
; CHECK-SAME: ptr [[DST:%.*]], i64 [[N:%.*]], i64 [[M:%.*]]) {
; CHECK-NEXT: entry:
; CHECK-NEXT: [[TMP2:%.*]] = udiv i64 42, [[M]]
; CHECK-NEXT: [[TMP0:%.*]] = freeze i64 [[M]]
; CHECK-NEXT: [[TMP1:%.*]] = call i64 @llvm.umax.i64(i64 [[TMP0]], i64 1)
; CHECK-NEXT: [[TMP2:%.*]] = udiv i64 42, [[TMP1]]
; CHECK-NEXT: [[TMP3:%.*]] = freeze i64 [[TMP2]]
; CHECK-NEXT: [[SMAX:%.*]] = call i64 @llvm.smax.i64(i64 [[N]], i64 0)
; CHECK-NEXT: [[UMIN:%.*]] = call i64 @llvm.umin.i64(i64 [[TMP3]], i64 [[SMAX]])
Expand Down Expand Up @@ -599,10 +603,12 @@ define i64 @multi_exit_4_exit_count_with_udiv_by_frozen_value_in_latch(ptr %dst,
; CHECK-SAME: ptr [[DST:%.*]], i64 [[N:%.*]]) {
; CHECK-NEXT: entry:
; CHECK-NEXT: [[FR_N:%.*]] = freeze i64 [[N]]
; CHECK-NEXT: [[TMP0:%.*]] = udiv i64 42, [[FR_N]]
; CHECK-NEXT: [[TMP1:%.*]] = freeze i64 [[TMP0]]
; CHECK-NEXT: [[TMP0:%.*]] = freeze i64 [[FR_N]]
; CHECK-NEXT: [[TMP1:%.*]] = call i64 @llvm.umax.i64(i64 [[TMP0]], i64 1)
; CHECK-NEXT: [[TMP2:%.*]] = udiv i64 42, [[TMP1]]
; CHECK-NEXT: [[TMP10:%.*]] = freeze i64 [[TMP2]]
; CHECK-NEXT: [[SMAX:%.*]] = call i64 @llvm.smax.i64(i64 [[N]], i64 0)
; CHECK-NEXT: [[UMIN:%.*]] = call i64 @llvm.umin.i64(i64 [[TMP1]], i64 [[SMAX]])
; CHECK-NEXT: [[UMIN:%.*]] = call i64 @llvm.umin.i64(i64 [[TMP10]], i64 [[SMAX]])
; CHECK-NEXT: [[TMP3:%.*]] = add nuw nsw i64 [[UMIN]], 1
; CHECK-NEXT: [[MIN_ITERS_CHECK:%.*]] = icmp ule i64 [[TMP3]], 4
; CHECK-NEXT: br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
Expand Down Expand Up @@ -788,7 +794,9 @@ define i64 @multi_exit_4_exit_count_with_urem_by_value_in_latch(ptr %dst, i64 %N
; CHECK-LABEL: define i64 @multi_exit_4_exit_count_with_urem_by_value_in_latch(
; CHECK-SAME: ptr [[DST:%.*]], i64 [[N:%.*]]) {
; CHECK-NEXT: entry:
; CHECK-NEXT: [[TMP0:%.*]] = udiv i64 42, [[N]]
; CHECK-NEXT: [[TMP11:%.*]] = freeze i64 [[N]]
; CHECK-NEXT: [[TMP12:%.*]] = call i64 @llvm.umax.i64(i64 [[TMP11]], i64 1)
; CHECK-NEXT: [[TMP0:%.*]] = udiv i64 42, [[TMP12]]
; CHECK-NEXT: [[TMP1:%.*]] = mul nuw i64 [[N]], [[TMP0]]
; CHECK-NEXT: [[TMP2:%.*]] = sub i64 42, [[TMP1]]
; CHECK-NEXT: [[SMAX1:%.*]] = call i64 @llvm.smax.i64(i64 [[TMP2]], i64 0)
Expand Down Expand Up @@ -1007,10 +1015,12 @@ define i64 @multi_exit_4_exit_count_with_udiv_by_value_in_latch1(ptr %dst, i64 %
; CHECK-LABEL: define i64 @multi_exit_4_exit_count_with_udiv_by_value_in_latch1(
; CHECK-SAME: ptr [[DST:%.*]], i64 [[N:%.*]]) {
; CHECK-NEXT: entry:
; CHECK-NEXT: [[TMP0:%.*]] = udiv i64 42, [[N]]
; CHECK-NEXT: [[TMP8:%.*]] = freeze i64 [[TMP0]]
; CHECK-NEXT: [[TMP0:%.*]] = freeze i64 [[N]]
; CHECK-NEXT: [[TMP8:%.*]] = call i64 @llvm.umax.i64(i64 [[TMP0]], i64 1)
; CHECK-NEXT: [[TMP9:%.*]] = udiv i64 42, [[TMP8]]
; CHECK-NEXT: [[TMP10:%.*]] = freeze i64 [[TMP9]]
; CHECK-NEXT: [[SMAX:%.*]] = call i64 @llvm.smax.i64(i64 [[N]], i64 0)
; CHECK-NEXT: [[UMIN:%.*]] = call i64 @llvm.umin.i64(i64 [[TMP8]], i64 [[SMAX]])
; CHECK-NEXT: [[UMIN:%.*]] = call i64 @llvm.umin.i64(i64 [[TMP10]], i64 [[SMAX]])
; CHECK-NEXT: [[TMP1:%.*]] = add nuw nsw i64 [[UMIN]], 1
; CHECK-NEXT: [[MIN_ITERS_CHECK:%.*]] = icmp ule i64 [[TMP1]], 4
; CHECK-NEXT: br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
Expand Down Expand Up @@ -1072,6 +1082,46 @@ exit:
ret i64 %p
}

define i64 @multi_exit_exit_count_with_udiv_by_0_in_latch(ptr %dst, i64 %N) {
; CHECK-LABEL: define i64 @multi_exit_exit_count_with_udiv_by_0_in_latch(
; CHECK-SAME: ptr [[DST:%.*]], i64 [[N:%.*]]) {
; CHECK-NEXT: entry:
; CHECK-NEXT: br label [[LOOP_HEADER:%.*]]
; CHECK: loop.header:
; CHECK-NEXT: [[IV:%.*]] = phi i64 [ 0, [[ENTRY:%.*]] ], [ [[IV_NEXT:%.*]], [[LOOP_LATCH:%.*]] ]
; CHECK-NEXT: [[GEP:%.*]] = getelementptr inbounds i32, ptr [[DST]], i64 [[IV]]
; CHECK-NEXT: store i32 1, ptr [[GEP]], align 4
; CHECK-NEXT: [[C_0:%.*]] = icmp slt i64 [[IV]], [[N]]
; CHECK-NEXT: br i1 [[C_0]], label [[LOOP_LATCH]], label [[EXIT:%.*]]
; CHECK: loop.latch:
; CHECK-NEXT: [[IV_NEXT]] = add i64 [[IV]], 1
; CHECK-NEXT: [[D:%.*]] = udiv i64 42, 0
; CHECK-NEXT: [[C_1:%.*]] = icmp slt i64 [[IV]], [[D]]
; CHECK-NEXT: br i1 [[C_1]], label [[LOOP_HEADER]], label [[EXIT]]
; CHECK: exit:
; CHECK-NEXT: [[P:%.*]] = phi i64 [ 1, [[LOOP_HEADER]] ], [ 0, [[LOOP_LATCH]] ]
; CHECK-NEXT: ret i64 [[P]]
;
entry:
br label %loop.header

loop.header:
%iv = phi i64 [ 0, %entry ], [ %iv.next, %loop.latch ]
%gep = getelementptr inbounds i32, ptr %dst, i64 %iv
store i32 1, ptr %gep
%c.0 = icmp slt i64 %iv, %N
br i1 %c.0, label %loop.latch, label %exit

loop.latch:
%iv.next = add i64 %iv, 1
%d = udiv i64 42, 0
%c.1 = icmp slt i64 %iv, %d
br i1 %c.1, label %loop.header, label %exit

exit:
%p = phi i64 [ 1, %loop.header ], [ 0, %loop.latch]
ret i64 %p
}

;.
; CHECK: [[LOOP0]] = distinct !{[[LOOP0]], [[META1:![0-9]+]], [[META2:![0-9]+]]}
Expand Down