Skip to content

Commit a5ef2e0

Browse files
committed
Return "[SCEV] Prove implicaitons via AddRec start"
The initial version of the patch was reverted because it missed the check that the predicate being proved is actually guarded by this check on 1st iteration. If it was not executed on 1st iteration (but possibly executes after that), then it is incorrect to use reasoning about IV start to prove it. Added the test where the miscompile was seen. Unfortunately, my attempts to reduce it with bugpoint did not succeed; it can further be reduced when we understand how to do it without losing the initial bug's notion. Returning assuming the miscompiles are now gone. Differential Revision: https://reviews.llvm.org/D88208
1 parent 6dcbea8 commit a5ef2e0

File tree

4 files changed

+539
-32
lines changed

4 files changed

+539
-32
lines changed

llvm/include/llvm/Analysis/ScalarEvolution.h

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1677,23 +1677,30 @@ class ScalarEvolution {
16771677
getPredecessorWithUniqueSuccessorForBB(const BasicBlock *BB) const;
16781678

16791679
/// Test whether the condition described by Pred, LHS, and RHS is true
1680-
/// whenever the given FoundCondValue value evaluates to true.
1680+
/// whenever the given FoundCondValue value evaluates to true in given
1681+
/// Context. If Context is nullptr, then the found predicate is true
1682+
/// everywhere.
16811683
bool isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,
1682-
const Value *FoundCondValue, bool Inverse);
1684+
const Value *FoundCondValue, bool Inverse,
1685+
const Instruction *Context = nullptr);
16831686

16841687
/// Test whether the condition described by Pred, LHS, and RHS is true
16851688
/// whenever the condition described by FoundPred, FoundLHS, FoundRHS is
1686-
/// true.
1689+
/// true in given Context. If Context is nullptr, then the found predicate is
1690+
/// true everywhere.
16871691
bool isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,
16881692
ICmpInst::Predicate FoundPred, const SCEV *FoundLHS,
1689-
const SCEV *FoundRHS);
1693+
const SCEV *FoundRHS,
1694+
const Instruction *Context = nullptr);
16901695

16911696
/// Test whether the condition described by Pred, LHS, and RHS is true
16921697
/// whenever the condition described by Pred, FoundLHS, and FoundRHS is
1693-
/// true.
1698+
/// true in given Context. If Context is nullptr, then the found predicate is
1699+
/// true everywhere.
16941700
bool isImpliedCondOperands(ICmpInst::Predicate Pred, const SCEV *LHS,
16951701
const SCEV *RHS, const SCEV *FoundLHS,
1696-
const SCEV *FoundRHS);
1702+
const SCEV *FoundRHS,
1703+
const Instruction *Context = nullptr);
16971704

16981705
/// Test whether the condition described by Pred, LHS, and RHS is true
16991706
/// whenever the condition described by Pred, FoundLHS, and FoundRHS is
@@ -1740,6 +1747,18 @@ class ScalarEvolution {
17401747
const SCEV *FoundLHS,
17411748
const SCEV *FoundRHS);
17421749

1750+
/// Test whether the condition described by Pred, LHS, and RHS is true
1751+
/// whenever the condition described by Pred, FoundLHS, and FoundRHS is
1752+
/// true.
1753+
///
1754+
/// This routine tries to weaken the known condition basing on fact that
1755+
/// FoundLHS is an AddRec.
1756+
bool isImpliedCondOperandsViaAddRecStart(ICmpInst::Predicate Pred,
1757+
const SCEV *LHS, const SCEV *RHS,
1758+
const SCEV *FoundLHS,
1759+
const SCEV *FoundRHS,
1760+
const Instruction *Context);
1761+
17431762
/// Test whether the condition described by Pred, LHS, and RHS is true
17441763
/// whenever the condition described by Pred, FoundLHS, and FoundRHS is
17451764
/// true.

llvm/lib/Analysis/ScalarEvolution.cpp

Lines changed: 84 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -9549,15 +9549,16 @@ bool ScalarEvolution::isBasicBlockEntryGuardedByCond(const BasicBlock *BB,
95499549

95509550
// Try to prove (Pred, LHS, RHS) using isImpliedCond.
95519551
auto ProveViaCond = [&](const Value *Condition, bool Inverse) {
9552-
if (isImpliedCond(Pred, LHS, RHS, Condition, Inverse))
9552+
const Instruction *Context = &BB->front();
9553+
if (isImpliedCond(Pred, LHS, RHS, Condition, Inverse, Context))
95539554
return true;
95549555
if (ProvingStrictComparison) {
95559556
if (!ProvedNonStrictComparison)
9556-
ProvedNonStrictComparison =
9557-
isImpliedCond(NonStrictPredicate, LHS, RHS, Condition, Inverse);
9557+
ProvedNonStrictComparison = isImpliedCond(NonStrictPredicate, LHS, RHS,
9558+
Condition, Inverse, Context);
95589559
if (!ProvedNonEquality)
9559-
ProvedNonEquality =
9560-
isImpliedCond(ICmpInst::ICMP_NE, LHS, RHS, Condition, Inverse);
9560+
ProvedNonEquality = isImpliedCond(ICmpInst::ICMP_NE, LHS, RHS,
9561+
Condition, Inverse, Context);
95619562
if (ProvedNonStrictComparison && ProvedNonEquality)
95629563
return true;
95639564
}
@@ -9623,7 +9624,8 @@ bool ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L,
96239624

96249625
bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS,
96259626
const SCEV *RHS,
9626-
const Value *FoundCondValue, bool Inverse) {
9627+
const Value *FoundCondValue, bool Inverse,
9628+
const Instruction *Context) {
96279629
if (!PendingLoopPredicates.insert(FoundCondValue).second)
96289630
return false;
96299631

@@ -9634,12 +9636,16 @@ bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS,
96349636
if (const BinaryOperator *BO = dyn_cast<BinaryOperator>(FoundCondValue)) {
96359637
if (BO->getOpcode() == Instruction::And) {
96369638
if (!Inverse)
9637-
return isImpliedCond(Pred, LHS, RHS, BO->getOperand(0), Inverse) ||
9638-
isImpliedCond(Pred, LHS, RHS, BO->getOperand(1), Inverse);
9639+
return isImpliedCond(Pred, LHS, RHS, BO->getOperand(0), Inverse,
9640+
Context) ||
9641+
isImpliedCond(Pred, LHS, RHS, BO->getOperand(1), Inverse,
9642+
Context);
96399643
} else if (BO->getOpcode() == Instruction::Or) {
96409644
if (Inverse)
9641-
return isImpliedCond(Pred, LHS, RHS, BO->getOperand(0), Inverse) ||
9642-
isImpliedCond(Pred, LHS, RHS, BO->getOperand(1), Inverse);
9645+
return isImpliedCond(Pred, LHS, RHS, BO->getOperand(0), Inverse,
9646+
Context) ||
9647+
isImpliedCond(Pred, LHS, RHS, BO->getOperand(1), Inverse,
9648+
Context);
96439649
}
96449650
}
96459651

@@ -9657,14 +9663,14 @@ bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS,
96579663
const SCEV *FoundLHS = getSCEV(ICI->getOperand(0));
96589664
const SCEV *FoundRHS = getSCEV(ICI->getOperand(1));
96599665

9660-
return isImpliedCond(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS);
9666+
return isImpliedCond(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS, Context);
96619667
}
96629668

96639669
bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS,
96649670
const SCEV *RHS,
96659671
ICmpInst::Predicate FoundPred,
9666-
const SCEV *FoundLHS,
9667-
const SCEV *FoundRHS) {
9672+
const SCEV *FoundLHS, const SCEV *FoundRHS,
9673+
const Instruction *Context) {
96689674
// Balance the types.
96699675
if (getTypeSizeInBits(LHS->getType()) <
96709676
getTypeSizeInBits(FoundLHS->getType())) {
@@ -9708,24 +9714,24 @@ bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS,
97089714

97099715
// Check whether the found predicate is the same as the desired predicate.
97109716
if (FoundPred == Pred)
9711-
return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS);
9717+
return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, Context);
97129718

97139719
// Check whether swapping the found predicate makes it the same as the
97149720
// desired predicate.
97159721
if (ICmpInst::getSwappedPredicate(FoundPred) == Pred) {
97169722
if (isa<SCEVConstant>(RHS))
9717-
return isImpliedCondOperands(Pred, LHS, RHS, FoundRHS, FoundLHS);
9723+
return isImpliedCondOperands(Pred, LHS, RHS, FoundRHS, FoundLHS, Context);
97189724
else
9719-
return isImpliedCondOperands(ICmpInst::getSwappedPredicate(Pred),
9720-
RHS, LHS, FoundLHS, FoundRHS);
9725+
return isImpliedCondOperands(ICmpInst::getSwappedPredicate(Pred), RHS,
9726+
LHS, FoundLHS, FoundRHS, Context);
97219727
}
97229728

97239729
// Unsigned comparison is the same as signed comparison when both the operands
97249730
// are non-negative.
97259731
if (CmpInst::isUnsigned(FoundPred) &&
97269732
CmpInst::getSignedPredicate(FoundPred) == Pred &&
97279733
isKnownNonNegative(FoundLHS) && isKnownNonNegative(FoundRHS))
9728-
return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS);
9734+
return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, Context);
97299735

97309736
// Check if we can make progress by sharpening ranges.
97319737
if (FoundPred == ICmpInst::ICMP_NE &&
@@ -9762,8 +9768,8 @@ bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS,
97629768
case ICmpInst::ICMP_UGE:
97639769
// We know V `Pred` SharperMin. If this implies LHS `Pred`
97649770
// RHS, we're done.
9765-
if (isImpliedCondOperands(Pred, LHS, RHS, V,
9766-
getConstant(SharperMin)))
9771+
if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(SharperMin),
9772+
Context))
97679773
return true;
97689774
LLVM_FALLTHROUGH;
97699775

@@ -9778,22 +9784,23 @@ bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS,
97789784
//
97799785
// If V `Pred` Min implies LHS `Pred` RHS, we're done.
97809786

9781-
if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(Min)))
9787+
if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(Min),
9788+
Context))
97829789
return true;
97839790
break;
97849791

97859792
// `LHS < RHS` and `LHS <= RHS` are handled in the same way as `RHS > LHS` and `RHS >= LHS` respectively.
97869793
case ICmpInst::ICMP_SLE:
97879794
case ICmpInst::ICMP_ULE:
97889795
if (isImpliedCondOperands(CmpInst::getSwappedPredicate(Pred), RHS,
9789-
LHS, V, getConstant(SharperMin)))
9796+
LHS, V, getConstant(SharperMin), Context))
97909797
return true;
97919798
LLVM_FALLTHROUGH;
97929799

97939800
case ICmpInst::ICMP_SLT:
97949801
case ICmpInst::ICMP_ULT:
97959802
if (isImpliedCondOperands(CmpInst::getSwappedPredicate(Pred), RHS,
9796-
LHS, V, getConstant(Min)))
9803+
LHS, V, getConstant(Min), Context))
97979804
return true;
97989805
break;
97999806

@@ -9807,11 +9814,12 @@ bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS,
98079814
// Check whether the actual condition is beyond sufficient.
98089815
if (FoundPred == ICmpInst::ICMP_EQ)
98099816
if (ICmpInst::isTrueWhenEqual(Pred))
9810-
if (isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS))
9817+
if (isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, Context))
98119818
return true;
98129819
if (Pred == ICmpInst::ICMP_NE)
98139820
if (!ICmpInst::isTrueWhenEqual(FoundPred))
9814-
if (isImpliedCondOperands(FoundPred, LHS, RHS, FoundLHS, FoundRHS))
9821+
if (isImpliedCondOperands(FoundPred, LHS, RHS, FoundLHS, FoundRHS,
9822+
Context))
98159823
return true;
98169824

98179825
// Otherwise assume the worst.
@@ -9890,6 +9898,51 @@ Optional<APInt> ScalarEvolution::computeConstantDifference(const SCEV *More,
98909898
return None;
98919899
}
98929900

9901+
bool ScalarEvolution::isImpliedCondOperandsViaAddRecStart(
9902+
ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,
9903+
const SCEV *FoundLHS, const SCEV *FoundRHS, const Instruction *Context) {
9904+
// Try to recognize the following pattern:
9905+
//
9906+
// FoundRHS = ...
9907+
// ...
9908+
// loop:
9909+
// FoundLHS = {Start,+,W}
9910+
// context_bb: // Basic block from the same loop
9911+
// known(Pred, FoundLHS, FoundRHS)
9912+
//
9913+
// If some predicate is known in the context of a loop, it is also known on
9914+
// each iteration of this loop, including the first iteration. Therefore, in
9915+
// this case, `FoundLHS Pred FoundRHS` implies `Start Pred FoundRHS`. Try to
9916+
// prove the original pred using this fact.
9917+
if (!Context)
9918+
return false;
9919+
const BasicBlock *ContextBB = Context->getParent();
9920+
// Make sure AR varies in the context block.
9921+
if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundLHS)) {
9922+
const Loop *L = AR->getLoop();
9923+
// Make sure that context belongs to the loop and executes on 1st iteration
9924+
// (if it ever executes at all).
9925+
if (!L->contains(ContextBB) || !DT.dominates(ContextBB, L->getLoopLatch()))
9926+
return false;
9927+
if (!isAvailableAtLoopEntry(FoundRHS, AR->getLoop()))
9928+
return false;
9929+
return isImpliedCondOperands(Pred, LHS, RHS, AR->getStart(), FoundRHS);
9930+
}
9931+
9932+
if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundRHS)) {
9933+
const Loop *L = AR->getLoop();
9934+
// Make sure that context belongs to the loop and executes on 1st iteration
9935+
// (if it ever executes at all).
9936+
if (!L->contains(ContextBB) || !DT.dominates(ContextBB, L->getLoopLatch()))
9937+
return false;
9938+
if (!isAvailableAtLoopEntry(FoundLHS, AR->getLoop()))
9939+
return false;
9940+
return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, AR->getStart());
9941+
}
9942+
9943+
return false;
9944+
}
9945+
98939946
bool ScalarEvolution::isImpliedCondOperandsViaNoOverflow(
98949947
ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,
98959948
const SCEV *FoundLHS, const SCEV *FoundRHS) {
@@ -10080,13 +10133,18 @@ bool ScalarEvolution::isImpliedViaMerge(ICmpInst::Predicate Pred,
1008010133
bool ScalarEvolution::isImpliedCondOperands(ICmpInst::Predicate Pred,
1008110134
const SCEV *LHS, const SCEV *RHS,
1008210135
const SCEV *FoundLHS,
10083-
const SCEV *FoundRHS) {
10136+
const SCEV *FoundRHS,
10137+
const Instruction *Context) {
1008410138
if (isImpliedCondOperandsViaRanges(Pred, LHS, RHS, FoundLHS, FoundRHS))
1008510139
return true;
1008610140

1008710141
if (isImpliedCondOperandsViaNoOverflow(Pred, LHS, RHS, FoundLHS, FoundRHS))
1008810142
return true;
1008910143

10144+
if (isImpliedCondOperandsViaAddRecStart(Pred, LHS, RHS, FoundLHS, FoundRHS,
10145+
Context))
10146+
return true;
10147+
1009010148
return isImpliedCondOperandsHelper(Pred, LHS, RHS,
1009110149
FoundLHS, FoundRHS) ||
1009210150
// ~x < ~y --> x > y

0 commit comments

Comments
 (0)