Skip to content

Commit 6d93d4d

Browse files
committed
[DA] Add getNoOverflowPredicate() to ScalarEvolution for reuse
Refactor the overflow check runtime predicate generation to reuse the same pattern used by ScalarEvolution::willNotOverflow(). Instead of duplicating the ext(LHS op RHS) == ext(LHS) op ext(RHS) pattern in Delinearization.cpp, add a new getNoOverflowPredicate() method to ScalarEvolution that returns the predicate (or nullptr if no-overflow is already provable at compile time). This addresses review feedback to avoid code duplication between willNotOverflow() and the runtime predicate generation in delinearization.
1 parent 240badd commit 6d93d4d

26 files changed

+651
-626
lines changed

llvm/include/llvm/Analysis/ScalarEvolution.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -544,6 +544,14 @@ class ScalarEvolution {
544544
const SCEV *LHS, const SCEV *RHS,
545545
const Instruction *CtxI = nullptr);
546546

547+
/// Get the predicate that, if true at runtime, proves that the binary
548+
/// operation \p BinOp between \p LHS and \p RHS does not have
549+
/// signed/unsigned overflow (depending on \p Signed). Returns the
550+
/// predicate, or nullptr if no-overflow is already provable at compile time.
551+
LLVM_ABI const SCEVPredicate *
552+
getNoOverflowPredicate(Instruction::BinaryOps BinOp, bool Signed,
553+
const SCEV *LHS, const SCEV *RHS);
554+
547555
/// Parse NSW/NUW flags from add/sub/mul IR binary operation \p Op into
548556
/// SCEV no-wrap flags, and deduce flag[s] that aren't known yet.
549557
/// Does not mutate the original instruction. Returns std::nullopt if it could

llvm/lib/Analysis/Delinearization.cpp

Lines changed: 20 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -772,37 +772,19 @@ bool llvm::validateDelinearizationResult(
772772
SE.getNoopOrSignExtend(B, WiderType)};
773773
};
774774

775-
// Get a type with twice the bit width of T.
776-
auto GetWiderType = [&](Type *T) -> Type * {
777-
unsigned BitWidth = SE.getTypeSizeInBits(T);
778-
return IntegerType::get(T->getContext(), BitWidth * 2);
779-
};
780-
781775
// Check if the result of A + B (signed) does not overflow. If it can be
782776
// proven at compile-time, return the result. If it might overflow and Assume
783777
// is provided, add a runtime equality predicate and return the result.
784778
// Otherwise return nullptr.
785779
auto AddNoOverflow = [&](const SCEV *A, const SCEV *B) -> const SCEV * {
786780
std::tie(A, B) = UnifyTypes(A, B);
787-
if (SE.willNotOverflow(Instruction::Add, /*IsSigned=*/true, A, B))
788-
return SE.getAddExpr(A, B);
789-
if (!Assume)
790-
return nullptr;
791-
792-
// Compute the addition in a wider type to detect overflow.
793-
// If (sext A) + (sext B) == sext(A + B), then A + B does not overflow.
794-
Type *OrigTy = A->getType();
795-
Type *WiderTy = GetWiderType(OrigTy);
796-
const SCEV *AWide = SE.getSignExtendExpr(A, WiderTy);
797-
const SCEV *BWide = SE.getSignExtendExpr(B, WiderTy);
798-
const SCEV *SumWide = SE.getAddExpr(AWide, BWide);
799-
const SCEV *Sum = SE.getAddExpr(A, B);
800-
const SCEV *SumExtended = SE.getSignExtendExpr(Sum, WiderTy);
801-
// Add predicate: (sext A) + (sext B) == sext(A + B).
802-
if (SumWide != SumExtended &&
803-
!SE.isKnownPredicate(ICmpInst::ICMP_EQ, SumWide, SumExtended))
804-
Assume->push_back(SE.getEqualPredicate(SumWide, SumExtended));
805-
return Sum;
781+
if (const auto *Pred = SE.getNoOverflowPredicate(Instruction::Add,
782+
/*Signed=*/true, A, B)) {
783+
if (!Assume)
784+
return nullptr;
785+
Assume->push_back(Pred);
786+
}
787+
return SE.getAddExpr(A, B);
806788
};
807789

808790
// Check if the result of A * B (signed) does not overflow. If it can be
@@ -811,50 +793,26 @@ bool llvm::validateDelinearizationResult(
811793
// Otherwise return nullptr.
812794
auto MulNoOverflow = [&](const SCEV *A, const SCEV *B) -> const SCEV * {
813795
std::tie(A, B) = UnifyTypes(A, B);
814-
if (SE.willNotOverflow(Instruction::Mul, /*IsSigned=*/true, A, B))
815-
return SE.getMulExpr(A, B);
816-
if (!Assume)
817-
return nullptr;
818-
819-
// Compute the multiplication in a wider type to detect overflow.
820-
// If (sext A) * (sext B) == sext(A * B), then A * B does not overflow.
821-
Type *OrigTy = A->getType();
822-
Type *WiderTy = GetWiderType(OrigTy);
823-
const SCEV *AWide = SE.getSignExtendExpr(A, WiderTy);
824-
const SCEV *BWide = SE.getSignExtendExpr(B, WiderTy);
825-
const SCEV *ProdWide = SE.getMulExpr(AWide, BWide);
826-
const SCEV *Prod = SE.getMulExpr(A, B);
827-
const SCEV *ProdExtended = SE.getSignExtendExpr(Prod, WiderTy);
828-
// Add predicate: (sext A) * (sext B) == sext(A * B).
829-
if (ProdWide != ProdExtended &&
830-
!SE.isKnownPredicate(ICmpInst::ICMP_EQ, ProdWide, ProdExtended))
831-
Assume->push_back(SE.getEqualPredicate(ProdWide, ProdExtended));
832-
return Prod;
796+
if (const auto *Pred = SE.getNoOverflowPredicate(Instruction::Mul,
797+
/*Signed=*/true, A, B)) {
798+
if (!Assume)
799+
return nullptr;
800+
Assume->push_back(Pred);
801+
}
802+
return SE.getMulExpr(A, B);
833803
};
834804

835805
// Check if the result of A - B (signed) does not overflow. If it can be
836806
// proven at compile-time or if Assume is provided (adding a runtime
837807
// predicate), return true. Otherwise return false.
838808
auto SubNoOverflow = [&](const SCEV *A, const SCEV *B) -> bool {
839809
std::tie(A, B) = UnifyTypes(A, B);
840-
if (SE.willNotOverflow(Instruction::Sub, /*IsSigned=*/true, A, B))
841-
return true;
842-
if (!Assume)
843-
return false;
844-
845-
// Compute the subtraction in a wider type to detect overflow.
846-
// If (sext A) - (sext B) == sext(A - B), then A - B does not overflow.
847-
Type *OrigTy = A->getType();
848-
Type *WiderTy = GetWiderType(OrigTy);
849-
const SCEV *AWide = SE.getSignExtendExpr(A, WiderTy);
850-
const SCEV *BWide = SE.getSignExtendExpr(B, WiderTy);
851-
const SCEV *DiffWide = SE.getMinusSCEV(AWide, BWide);
852-
const SCEV *Diff = SE.getMinusSCEV(A, B);
853-
const SCEV *DiffExtended = SE.getSignExtendExpr(Diff, WiderTy);
854-
// Add predicate: (sext A) - (sext B) == sext(A - B).
855-
if (DiffWide != DiffExtended &&
856-
!SE.isKnownPredicate(ICmpInst::ICMP_EQ, DiffWide, DiffExtended))
857-
Assume->push_back(SE.getEqualPredicate(DiffWide, DiffExtended));
810+
if (const auto *Pred = SE.getNoOverflowPredicate(Instruction::Sub,
811+
/*Signed=*/true, A, B)) {
812+
if (!Assume)
813+
return false;
814+
Assume->push_back(Pred);
815+
}
858816
return true;
859817
};
860818

llvm/lib/Analysis/ScalarEvolution.cpp

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2376,6 +2376,49 @@ bool ScalarEvolution::willNotOverflow(Instruction::BinaryOps BinOp, bool Signed,
23762376
}
23772377
}
23782378

2379+
const SCEVPredicate *
2380+
ScalarEvolution::getNoOverflowPredicate(Instruction::BinaryOps BinOp,
2381+
bool Signed, const SCEV *LHS,
2382+
const SCEV *RHS) {
2383+
// First check if no-overflow can be proven at compile time.
2384+
if (willNotOverflow(BinOp, Signed, LHS, RHS))
2385+
return nullptr;
2386+
2387+
const SCEV *(ScalarEvolution::*Operation)(const SCEV *, const SCEV *,
2388+
SCEV::NoWrapFlags, unsigned);
2389+
switch (BinOp) {
2390+
default:
2391+
llvm_unreachable("Unsupported binary op");
2392+
case Instruction::Add:
2393+
Operation = &ScalarEvolution::getAddExpr;
2394+
break;
2395+
case Instruction::Sub:
2396+
Operation = &ScalarEvolution::getMinusSCEV;
2397+
break;
2398+
case Instruction::Mul:
2399+
Operation = &ScalarEvolution::getMulExpr;
2400+
break;
2401+
}
2402+
2403+
const SCEV *(ScalarEvolution::*Extension)(const SCEV *, Type *, unsigned) =
2404+
Signed ? &ScalarEvolution::getSignExtendExpr
2405+
: &ScalarEvolution::getZeroExtendExpr;
2406+
2407+
// Build predicate: ext(LHS op RHS) == ext(LHS) op ext(RHS)
2408+
auto *NarrowTy = cast<IntegerType>(LHS->getType());
2409+
auto *WideTy =
2410+
IntegerType::get(NarrowTy->getContext(), NarrowTy->getBitWidth() * 2);
2411+
2412+
const SCEV *A = (this->*Extension)(
2413+
(this->*Operation)(LHS, RHS, SCEV::FlagAnyWrap, 0), WideTy, 0);
2414+
const SCEV *LHSB = (this->*Extension)(LHS, WideTy, 0);
2415+
const SCEV *RHSB = (this->*Extension)(RHS, WideTy, 0);
2416+
const SCEV *B = (this->*Operation)(LHSB, RHSB, SCEV::FlagAnyWrap, 0);
2417+
2418+
// Return the equality predicate.
2419+
return getEqualPredicate(A, B);
2420+
}
2421+
23792422
std::optional<SCEV::NoWrapFlags>
23802423
ScalarEvolution::getStrengthenedNoWrapFlagsFromBinOp(
23812424
const OverflowingBinaryOperator *OBO) {

llvm/test/Analysis/DependenceAnalysis/Banerjee.ll

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -139,27 +139,27 @@ define void @banerjee1(ptr %A, ptr %B, i64 %m, i64 %n) nounwind uwtable ssp {
139139
; DELIN-NEXT: da analyze - consistent output [0 0]!
140140
; DELIN-NEXT: Runtime Assumptions:
141141
; DELIN-NEXT: Compare predicate: {1,+,1}<nuw><nsw><%for.body3> slt) 10
142-
; DELIN-NEXT: Equal predicate: {80,+,80}<nuw><nsw><%for.cond1.preheader> == (sext i64 {80,+,80}<%for.cond1.preheader> to i128)
143-
; DELIN-NEXT: Equal predicate: (80 + (sext i64 {80,+,80}<%for.cond1.preheader> to i128))<nsw> == (sext i64 {160,+,80}<%for.cond1.preheader> to i128)
144-
; DELIN-NEXT: Equal predicate: (-1 + (sext i64 {160,+,80}<%for.cond1.preheader> to i128))<nsw> == (15 + (sext i64 {144,+,80}<%for.cond1.preheader> to i128))<nuw><nsw>
142+
; DELIN-NEXT: Equal predicate: (sext i64 {80,+,80}<%for.cond1.preheader> to i128) == {80,+,80}<nuw><nsw><%for.cond1.preheader>
143+
; DELIN-NEXT: Equal predicate: (sext i64 {160,+,80}<%for.cond1.preheader> to i128) == (80 + (sext i64 {80,+,80}<%for.cond1.preheader> to i128))<nsw>
144+
; DELIN-NEXT: Equal predicate: (15 + (sext i64 {144,+,80}<%for.cond1.preheader> to i128))<nuw><nsw> == (-1 + (sext i64 {160,+,80}<%for.cond1.preheader> to i128))<nsw>
145145
; DELIN-NEXT: Compare predicate: {1,+,1}<nuw><nsw><%for.body3> slt) 10
146146
; DELIN-NEXT: Src: store i64 0, ptr %arrayidx, align 8 --> Dst: %2 = load i64, ptr %arrayidx6, align 8
147147
; DELIN-NEXT: da analyze - consistent flow [0 1]!
148148
; DELIN-NEXT: Runtime Assumptions:
149149
; DELIN-NEXT: Compare predicate: {1,+,1}<nuw><nsw><%for.body3> slt) 10
150-
; DELIN-NEXT: Equal predicate: {80,+,80}<nuw><nsw><%for.cond1.preheader> == (sext i64 {80,+,80}<%for.cond1.preheader> to i128)
151-
; DELIN-NEXT: Equal predicate: (80 + (sext i64 {80,+,80}<%for.cond1.preheader> to i128))<nsw> == (sext i64 {160,+,80}<%for.cond1.preheader> to i128)
152-
; DELIN-NEXT: Equal predicate: (-1 + (sext i64 {160,+,80}<%for.cond1.preheader> to i128))<nsw> == (15 + (sext i64 {144,+,80}<%for.cond1.preheader> to i128))<nuw><nsw>
150+
; DELIN-NEXT: Equal predicate: (sext i64 {80,+,80}<%for.cond1.preheader> to i128) == {80,+,80}<nuw><nsw><%for.cond1.preheader>
151+
; DELIN-NEXT: Equal predicate: (sext i64 {160,+,80}<%for.cond1.preheader> to i128) == (80 + (sext i64 {80,+,80}<%for.cond1.preheader> to i128))<nsw>
152+
; DELIN-NEXT: Equal predicate: (15 + (sext i64 {144,+,80}<%for.cond1.preheader> to i128))<nuw><nsw> == (-1 + (sext i64 {160,+,80}<%for.cond1.preheader> to i128))<nsw>
153153
; DELIN-NEXT: Compare predicate: {0,+,1}<nuw><nsw><%for.body3> slt) 10
154154
; DELIN-NEXT: Src: store i64 0, ptr %arrayidx, align 8 --> Dst: store i64 %2, ptr %B.addr.12, align 8
155155
; DELIN-NEXT: da analyze - confused!
156156
; DELIN-NEXT: Src: %2 = load i64, ptr %arrayidx6, align 8 --> Dst: %2 = load i64, ptr %arrayidx6, align 8
157157
; DELIN-NEXT: da analyze - consistent input [0 0]!
158158
; DELIN-NEXT: Runtime Assumptions:
159159
; DELIN-NEXT: Compare predicate: {0,+,1}<nuw><nsw><%for.body3> slt) 10
160-
; DELIN-NEXT: Equal predicate: {80,+,80}<nuw><nsw><%for.cond1.preheader> == (sext i64 {80,+,80}<%for.cond1.preheader> to i128)
161-
; DELIN-NEXT: Equal predicate: (80 + (sext i64 {80,+,80}<%for.cond1.preheader> to i128))<nsw> == (sext i64 {160,+,80}<%for.cond1.preheader> to i128)
162-
; DELIN-NEXT: Equal predicate: (-1 + (sext i64 {160,+,80}<%for.cond1.preheader> to i128))<nsw> == (15 + (sext i64 {144,+,80}<%for.cond1.preheader> to i128))<nuw><nsw>
160+
; DELIN-NEXT: Equal predicate: (sext i64 {80,+,80}<%for.cond1.preheader> to i128) == {80,+,80}<nuw><nsw><%for.cond1.preheader>
161+
; DELIN-NEXT: Equal predicate: (sext i64 {160,+,80}<%for.cond1.preheader> to i128) == (80 + (sext i64 {80,+,80}<%for.cond1.preheader> to i128))<nsw>
162+
; DELIN-NEXT: Equal predicate: (15 + (sext i64 {144,+,80}<%for.cond1.preheader> to i128))<nuw><nsw> == (-1 + (sext i64 {160,+,80}<%for.cond1.preheader> to i128))<nsw>
163163
; DELIN-NEXT: Compare predicate: {0,+,1}<nuw><nsw><%for.body3> slt) 10
164164
; DELIN-NEXT: Src: %2 = load i64, ptr %arrayidx6, align 8 --> Dst: store i64 %2, ptr %B.addr.12, align 8
165165
; DELIN-NEXT: da analyze - confused!

0 commit comments

Comments
 (0)