Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,11 @@ class SCEVExpander : public SCEVVisitor<SCEVExpander, Value *> {
/// "expanded" form.
bool LSRMode;

/// When true, rewrite any divisors of UDiv expressions that may be 0 to
/// umax(Divisor, 1) to avoid introducing UB. If the divisor may be poison,
/// freeze it first.
bool SafeUDivMode = false;

typedef IRBuilder<InstSimplifyFolder, IRBuilderCallbackInserter> BuilderType;
BuilderType Builder;

Expand Down Expand Up @@ -300,6 +305,9 @@ class SCEVExpander : public SCEVVisitor<SCEVExpander, Value *> {
/// location and their operands are defined at this location.
bool isSafeToExpandAt(const SCEV *S, const Instruction *InsertionPoint) const;

static bool isSafeToExpand(const SCEV *S, bool CanonicalMode,
ScalarEvolution &SE);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see where this new method is used.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Left over from earlier versions, removed, thanks


/// Insert code to directly compute the specified SCEV expression into the
/// program. The code is inserted into the specified block.
Value *expandCodeFor(const SCEV *SH, Type *Ty, BasicBlock::iterator I);
Expand Down Expand Up @@ -418,6 +426,11 @@ class SCEVExpander : public SCEVVisitor<SCEVExpander, Value *> {
BasicBlock::iterator findInsertPointAfter(Instruction *I,
Instruction *MustDominate) const;

static const SCEV *rewriteExpressionToRemoveUB(const SCEV *BTC, Loop *L,
ScalarEvolution &SE);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unused prototype.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Dropped, thanks!


void setSafeUDivMode() { SafeUDivMode = true; }

private:
LLVMContext &getContext() const { return SE.getContext(); }

Expand Down
11 changes: 9 additions & 2 deletions llvm/lib/Analysis/ScalarEvolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4304,12 +4304,19 @@ ScalarEvolution::getSequentialMinMaxExpr(SCEVTypes Kind,
}

for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
bool MayBeUB = SCEVExprContains(Ops[i], [](const SCEV *S) {
auto *UDiv = dyn_cast<SCEVUDivExpr>(S);
return UDiv && !isa<SCEVConstant>(UDiv->getOperand(1));
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need to check for zero constant, I think.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated to check for !isKnownNonZero() and added comment, thanks!

});

if (MayBeUB)
continue;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should also change the doc comment on SCEVSequentialMinMaxExpr for the new semantics introduced here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added thanks!

// We can replace %x umin_seq %y with %x umin %y if either:
// * %y being poison implies %x is also poison.
// * %x cannot be the saturating value (e.g. zero for umin).
if (::impliesPoison(Ops[i], Ops[i - 1]) ||
isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_NE, Ops[i - 1],
SaturationPoint)) {
(isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_NE, Ops[i - 1],
SaturationPoint))) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unnecessary change?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Left over change, removed, thanks!

SmallVector<const SCEV *> SeqOps = {Ops[i - 1], Ops[i]};
Ops[i - 1] = getMinMaxExpr(
SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(Kind),
Expand Down
15 changes: 14 additions & 1 deletion llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -665,6 +665,12 @@ Value *SCEVExpander::visitUDivExpr(const SCEVUDivExpr *S) {
}

Value *RHS = expand(S->getRHS());
if (SafeUDivMode && !SE.isKnownNonZero(S->getRHS())) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The logic here doesn't look quite right wrt the poison handling. This code does nothing if isKnownNonZero (but maybe poison). We should at least leave a TODO here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated to check to only consider non-zero constants safe, thanks.
Maybe worth introducing a isKnownNonPoison() helper?

if (!isa<SCEVConstant>(S->getRHS()))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can now use the isGuaranteedNotToBePoison check instead.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, updated. This removes the freezes in a few cases, but that should be fine as in those cases the function would always trigger UB, so expanding the UDiv before the loop simply triggers UB earlier (e.g. multi_exit_4_exit_count_with_udiv_by_value_in_latch). This should be fine I think (https://alive2.llvm.org/ce/z/anDrW9).

Added additional variants where this is not the case (multi_exit_4_exit_count_with_udiv_by_value_in_latch_call_before_loop, multi_exit_4_exit_count_with_udiv_by_value_in_latch_loop_may_not_execute) and there's also multi_exit_4_exit_count_with_udiv_by_value_in_latch_different_bounds

RHS = Builder.CreateFreeze(RHS);
RHS = Builder.CreateIntrinsic(RHS->getType(), Intrinsic::umax,
{RHS, ConstantInt::get(RHS->getType(), 1)});
}
return InsertBinop(Instruction::UDiv, LHS, RHS, SCEV::FlagAnyWrap,
/*IsSafeToHoist*/ SE.isKnownNonZero(S->getRHS()));
}
Expand Down Expand Up @@ -1358,11 +1364,13 @@ Value *SCEVExpander::visitSignExtendExpr(const SCEVSignExtendExpr *S) {
Value *SCEVExpander::expandMinMaxExpr(const SCEVNAryExpr *S,
Intrinsic::ID IntrinID, Twine Name,
bool IsSequential) {
SafeUDivMode = true;
Value *LHS = expand(S->getOperand(S->getNumOperands() - 1));
Type *Ty = LHS->getType();
if (IsSequential)
LHS = Builder.CreateFreeze(LHS);
for (int i = S->getNumOperands() - 2; i >= 0; --i) {
SafeUDivMode = i != 0;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you need to save/restore the mode, in case you have nested min/max.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated to save the initial mode, and updated the code to not set it to false if it was already set.

And restored the original value on exit.

The earlier version was also enabling it if IsSequential was false. Also fixed

Value *RHS = expand(S->getOperand(i));
if (IsSequential && i != 0)
RHS = Builder.CreateFreeze(RHS);
Expand Down Expand Up @@ -2315,12 +2323,17 @@ struct SCEVFindUnsafe {
};
} // namespace

bool SCEVExpander::isSafeToExpand(const SCEV *S) const {
bool SCEVExpander::isSafeToExpand(const SCEV *S, bool CanonicalMode,
ScalarEvolution &SE) {
SCEVFindUnsafe Search(SE, CanonicalMode);
visitAll(S, Search);
return !Search.IsUnsafe;
}

bool SCEVExpander::isSafeToExpand(const SCEV *S) const {
return isSafeToExpand(S, CanonicalMode, SE);
}

bool SCEVExpander::isSafeToExpandAt(const SCEV *S,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed those changes

const Instruction *InsertionPoint) const {
if (!isSafeToExpand(S))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -463,9 +463,12 @@ define i64 @multi_exit_4_exit_count_with_udiv_by_value_in_latch(ptr %dst, i64 %N
; CHECK-LABEL: define i64 @multi_exit_4_exit_count_with_udiv_by_value_in_latch(
; CHECK-SAME: ptr [[DST:%.*]], i64 [[N:%.*]]) {
; CHECK-NEXT: entry:
; CHECK-NEXT: [[TMP9:%.*]] = freeze i64 [[N]]
; CHECK-NEXT: [[TMP10:%.*]] = call i64 @llvm.umax.i64(i64 [[TMP9]], i64 1)
; CHECK-NEXT: [[TMP0:%.*]] = udiv i64 42, [[TMP10]]
; CHECK-NEXT: [[TMP8:%.*]] = freeze i64 [[TMP0]]
; CHECK-NEXT: [[SMAX:%.*]] = call i64 @llvm.smax.i64(i64 [[N]], i64 0)
; CHECK-NEXT: [[TMP0:%.*]] = udiv i64 42, [[N]]
; CHECK-NEXT: [[UMIN:%.*]] = call i64 @llvm.umin.i64(i64 [[SMAX]], i64 [[TMP0]])
; CHECK-NEXT: [[UMIN:%.*]] = call i64 @llvm.umin.i64(i64 [[TMP8]], i64 [[SMAX]])
; CHECK-NEXT: [[TMP1:%.*]] = add nuw nsw i64 [[UMIN]], 1
; CHECK-NEXT: [[MIN_ITERS_CHECK:%.*]] = icmp ule i64 [[TMP1]], 4
; CHECK-NEXT: br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
Expand Down Expand Up @@ -529,7 +532,9 @@ define i64 @multi_exit_4_exit_count_with_udiv_by_value_in_latch_different_bounds
; CHECK-LABEL: define i64 @multi_exit_4_exit_count_with_udiv_by_value_in_latch_different_bounds(
; CHECK-SAME: ptr [[DST:%.*]], i64 [[N:%.*]], i64 [[M:%.*]]) {
; CHECK-NEXT: entry:
; CHECK-NEXT: [[TMP2:%.*]] = udiv i64 42, [[M]]
; CHECK-NEXT: [[TMP0:%.*]] = freeze i64 [[M]]
; CHECK-NEXT: [[TMP1:%.*]] = call i64 @llvm.umax.i64(i64 [[TMP0]], i64 1)
; CHECK-NEXT: [[TMP2:%.*]] = udiv i64 42, [[TMP1]]
; CHECK-NEXT: [[TMP3:%.*]] = freeze i64 [[TMP2]]
; CHECK-NEXT: [[SMAX:%.*]] = call i64 @llvm.smax.i64(i64 [[N]], i64 0)
; CHECK-NEXT: [[UMIN:%.*]] = call i64 @llvm.umin.i64(i64 [[TMP3]], i64 [[SMAX]])
Expand Down Expand Up @@ -598,9 +603,12 @@ define i64 @multi_exit_4_exit_count_with_udiv_by_frozen_value_in_latch(ptr %dst,
; CHECK-SAME: ptr [[DST:%.*]], i64 [[N:%.*]]) {
; CHECK-NEXT: entry:
; CHECK-NEXT: [[FR_N:%.*]] = freeze i64 [[N]]
; CHECK-NEXT: [[TMP2:%.*]] = call i64 @llvm.smax.i64(i64 [[N]], i64 0)
; CHECK-NEXT: [[TMP0:%.*]] = udiv i64 42, [[FR_N]]
; CHECK-NEXT: [[UMIN:%.*]] = call i64 @llvm.umin.i64(i64 [[TMP2]], i64 [[TMP0]])
; CHECK-NEXT: [[TMP0:%.*]] = freeze i64 [[FR_N]]
; CHECK-NEXT: [[TMP1:%.*]] = call i64 @llvm.umax.i64(i64 [[TMP0]], i64 1)
; CHECK-NEXT: [[TMP2:%.*]] = udiv i64 42, [[TMP1]]
; CHECK-NEXT: [[TMP10:%.*]] = freeze i64 [[TMP2]]
; CHECK-NEXT: [[SMAX:%.*]] = call i64 @llvm.smax.i64(i64 [[N]], i64 0)
; CHECK-NEXT: [[UMIN:%.*]] = call i64 @llvm.umin.i64(i64 [[TMP10]], i64 [[SMAX]])
; CHECK-NEXT: [[TMP3:%.*]] = add nuw nsw i64 [[UMIN]], 1
; CHECK-NEXT: [[MIN_ITERS_CHECK:%.*]] = icmp ule i64 [[TMP3]], 4
; CHECK-NEXT: br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
Expand Down Expand Up @@ -786,12 +794,15 @@ define i64 @multi_exit_4_exit_count_with_urem_by_value_in_latch(ptr %dst, i64 %N
; CHECK-LABEL: define i64 @multi_exit_4_exit_count_with_urem_by_value_in_latch(
; CHECK-SAME: ptr [[DST:%.*]], i64 [[N:%.*]]) {
; CHECK-NEXT: entry:
; CHECK-NEXT: [[SMAX:%.*]] = call i64 @llvm.smax.i64(i64 [[N]], i64 0)
; CHECK-NEXT: [[TMP0:%.*]] = udiv i64 42, [[N]]
; CHECK-NEXT: [[TMP11:%.*]] = freeze i64 [[N]]
; CHECK-NEXT: [[TMP12:%.*]] = call i64 @llvm.umax.i64(i64 [[TMP11]], i64 1)
; CHECK-NEXT: [[TMP0:%.*]] = udiv i64 42, [[TMP12]]
; CHECK-NEXT: [[TMP1:%.*]] = mul nuw i64 [[N]], [[TMP0]]
; CHECK-NEXT: [[TMP2:%.*]] = sub i64 42, [[TMP1]]
; CHECK-NEXT: [[SMAX1:%.*]] = call i64 @llvm.smax.i64(i64 [[TMP2]], i64 0)
; CHECK-NEXT: [[UMIN:%.*]] = call i64 @llvm.umin.i64(i64 [[SMAX]], i64 [[SMAX1]])
; CHECK-NEXT: [[TMP10:%.*]] = freeze i64 [[SMAX1]]
; CHECK-NEXT: [[SMAX2:%.*]] = call i64 @llvm.smax.i64(i64 [[N]], i64 0)
; CHECK-NEXT: [[UMIN:%.*]] = call i64 @llvm.umin.i64(i64 [[TMP10]], i64 [[SMAX2]])
; CHECK-NEXT: [[TMP3:%.*]] = add nuw i64 [[UMIN]], 1
; CHECK-NEXT: [[MIN_ITERS_CHECK:%.*]] = icmp ule i64 [[TMP3]], 4
; CHECK-NEXT: br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
Expand Down Expand Up @@ -1004,9 +1015,12 @@ define i64 @multi_exit_4_exit_count_with_udiv_by_value_in_latch1(ptr %dst, i64 %
; CHECK-LABEL: define i64 @multi_exit_4_exit_count_with_udiv_by_value_in_latch1(
; CHECK-SAME: ptr [[DST:%.*]], i64 [[N:%.*]]) {
; CHECK-NEXT: entry:
; CHECK-NEXT: [[TMP0:%.*]] = freeze i64 [[N]]
; CHECK-NEXT: [[TMP8:%.*]] = call i64 @llvm.umax.i64(i64 [[TMP0]], i64 1)
; CHECK-NEXT: [[TMP9:%.*]] = udiv i64 42, [[TMP8]]
; CHECK-NEXT: [[TMP10:%.*]] = freeze i64 [[TMP9]]
; CHECK-NEXT: [[SMAX:%.*]] = call i64 @llvm.smax.i64(i64 [[N]], i64 0)
; CHECK-NEXT: [[TMP0:%.*]] = udiv i64 42, [[N]]
; CHECK-NEXT: [[UMIN:%.*]] = call i64 @llvm.umin.i64(i64 [[SMAX]], i64 [[TMP0]])
; CHECK-NEXT: [[UMIN:%.*]] = call i64 @llvm.umin.i64(i64 [[TMP10]], i64 [[SMAX]])
; CHECK-NEXT: [[TMP1:%.*]] = add nuw nsw i64 [[UMIN]], 1
; CHECK-NEXT: [[MIN_ITERS_CHECK:%.*]] = icmp ule i64 [[TMP1]], 4
; CHECK-NEXT: br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
Expand Down