Skip to content

[SCEV] Apply loop guards to End in computeMaxBECountForLT #116187

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions llvm/include/llvm/Analysis/ScalarEvolution.h
Original file line number Diff line number Diff line change
Expand Up @@ -2218,8 +2218,8 @@ class ScalarEvolution {
/// actually doesn't, or we'd have to immediately execute UB)
/// We *don't* assert these preconditions so please be careful.
const SCEV *computeMaxBECountForLT(const SCEV *Start, const SCEV *Stride,
const SCEV *End, unsigned BitWidth,
bool IsSigned);
const SCEV *End, const Loop *L,
unsigned BitWidth, bool IsSigned);

/// Verify if an linear IV with positive stride can overflow when in a
/// less-than comparison, knowing the invariant term of the comparison,
Expand Down
19 changes: 10 additions & 9 deletions llvm/lib/Analysis/ScalarEvolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12857,11 +12857,10 @@ const SCEV *ScalarEvolution::getUDivCeilSCEV(const SCEV *N, const SCEV *D) {
return getAddExpr(MinNOne, getUDivExpr(NMinusOne, D));
}

const SCEV *ScalarEvolution::computeMaxBECountForLT(const SCEV *Start,
const SCEV *Stride,
const SCEV *End,
unsigned BitWidth,
bool IsSigned) {
const SCEV *
ScalarEvolution::computeMaxBECountForLT(const SCEV *Start, const SCEV *Stride,
const SCEV *End, const Loop *L,
unsigned BitWidth, bool IsSigned) {
// The logic in this function assumes we can represent a positive stride.
// If we can't, the backedge-taken count must be zero.
if (IsSigned && BitWidth == 1)
Expand Down Expand Up @@ -12895,8 +12894,10 @@ const SCEV *ScalarEvolution::computeMaxBECountForLT(const SCEV *Start,
// the case End = RHS of the loop termination condition. This is safe because
// in the other case (End - Start) is zero, leading to a zero maximum backedge
// taken count.
APInt MaxEnd = IsSigned ? APIntOps::smin(getSignedRangeMax(End), Limit)
: APIntOps::umin(getUnsignedRangeMax(End), Limit);
const SCEV *GuardedEnd = applyLoopGuards(End, L);
APInt MaxEnd = IsSigned
? APIntOps::smin(getSignedRangeMax(GuardedEnd), Limit)
: APIntOps::umin(getUnsignedRangeMax(GuardedEnd), Limit);

// MaxBECount = ceil((max(MaxEnd, MinStart) - MinStart) / Stride)
MaxEnd = IsSigned ? APIntOps::smax(MaxEnd, MinStart)
Expand Down Expand Up @@ -13150,7 +13151,7 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
// loop (RHS), and the fact that IV does not overflow (which is
// checked above).
const SCEV *MaxBECount = computeMaxBECountForLT(
Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
Start, Stride, RHS, L, getTypeSizeInBits(LHS->getType()), IsSigned);
return ExitLimit(getCouldNotCompute() /* ExactNotTaken */, MaxBECount,
MaxBECount, false /*MaxOrZero*/, Predicates);
}
Expand Down Expand Up @@ -13334,7 +13335,7 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
MaxOrZero = true;
} else {
ConstantMaxBECount = computeMaxBECountForLT(
Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
Start, Stride, RHS, L, getTypeSizeInBits(LHS->getType()), IsSigned);
}

if (isa<SCEVCouldNotCompute>(ConstantMaxBECount) &&
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
; RUN: -prefer-predicate-over-epilogue=predicate-dont-vectorize \
; RUN: -mtriple=riscv64 -mattr=+v -S < %s | FileCheck %s

; TODO: We know the IV will never overflow here so we can skip the overflow
; check
; We know the IV will never overflow here so we can skip the overflow check

define void @trip_count_max_1024(ptr %p, i64 %tc) vscale_range(2, 1024) {
; CHECK-LABEL: define void @trip_count_max_1024(
Expand All @@ -15,11 +14,7 @@ define void @trip_count_max_1024(ptr %p, i64 %tc) vscale_range(2, 1024) {
; CHECK-NEXT: br i1 [[GUARD]], label %[[EXIT:.*]], label %[[LOOP_PREHEADER:.*]]
; CHECK: [[LOOP_PREHEADER]]:
; CHECK-NEXT: [[UMAX:%.*]] = call i64 @llvm.umax.i64(i64 [[TC]], i64 1)
; CHECK-NEXT: [[TMP0:%.*]] = sub i64 -1, [[UMAX]]
; CHECK-NEXT: [[TMP1:%.*]] = call i64 @llvm.vscale.i64()
; CHECK-NEXT: [[TMP2:%.*]] = mul i64 [[TMP1]], 2
; CHECK-NEXT: [[TMP3:%.*]] = icmp ult i64 [[TMP0]], [[TMP2]]
; CHECK-NEXT: br i1 [[TMP3]], label %[[SCALAR_PH:.*]], label %[[VECTOR_PH:.*]]
; CHECK-NEXT: br i1 false, label %[[SCALAR_PH:.*]], label %[[VECTOR_PH:.*]]
; CHECK: [[VECTOR_PH]]:
; CHECK-NEXT: [[TMP4:%.*]] = call i64 @llvm.vscale.i64()
; CHECK-NEXT: [[TMP5:%.*]] = mul i64 [[TMP4]], 2
Expand Down
Loading