@@ -9549,15 +9549,16 @@ bool ScalarEvolution::isBasicBlockEntryGuardedByCond(const BasicBlock *BB,
9549
9549
9550
9550
// Try to prove (Pred, LHS, RHS) using isImpliedCond.
9551
9551
auto ProveViaCond = [&](const Value *Condition, bool Inverse) {
9552
- if (isImpliedCond (Pred, LHS, RHS, Condition, Inverse))
9552
+ const Instruction *Context = &BB->front ();
9553
+ if (isImpliedCond (Pred, LHS, RHS, Condition, Inverse, Context))
9553
9554
return true ;
9554
9555
if (ProvingStrictComparison) {
9555
9556
if (!ProvedNonStrictComparison)
9556
- ProvedNonStrictComparison =
9557
- isImpliedCond (NonStrictPredicate, LHS, RHS, Condition, Inverse);
9557
+ ProvedNonStrictComparison = isImpliedCond (NonStrictPredicate, LHS, RHS,
9558
+ Condition, Inverse, Context );
9558
9559
if (!ProvedNonEquality)
9559
- ProvedNonEquality =
9560
- isImpliedCond (ICmpInst::ICMP_NE, LHS, RHS, Condition, Inverse);
9560
+ ProvedNonEquality = isImpliedCond (ICmpInst::ICMP_NE, LHS, RHS,
9561
+ Condition, Inverse, Context );
9561
9562
if (ProvedNonStrictComparison && ProvedNonEquality)
9562
9563
return true ;
9563
9564
}
@@ -9623,7 +9624,8 @@ bool ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L,
9623
9624
9624
9625
bool ScalarEvolution::isImpliedCond (ICmpInst::Predicate Pred, const SCEV *LHS,
9625
9626
const SCEV *RHS,
9626
- const Value *FoundCondValue, bool Inverse) {
9627
+ const Value *FoundCondValue, bool Inverse,
9628
+ const Instruction *Context) {
9627
9629
if (!PendingLoopPredicates.insert (FoundCondValue).second )
9628
9630
return false ;
9629
9631
@@ -9634,12 +9636,16 @@ bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS,
9634
9636
if (const BinaryOperator *BO = dyn_cast<BinaryOperator>(FoundCondValue)) {
9635
9637
if (BO->getOpcode () == Instruction::And) {
9636
9638
if (!Inverse)
9637
- return isImpliedCond (Pred, LHS, RHS, BO->getOperand (0 ), Inverse) ||
9638
- isImpliedCond (Pred, LHS, RHS, BO->getOperand (1 ), Inverse);
9639
+ return isImpliedCond (Pred, LHS, RHS, BO->getOperand (0 ), Inverse,
9640
+ Context) ||
9641
+ isImpliedCond (Pred, LHS, RHS, BO->getOperand (1 ), Inverse,
9642
+ Context);
9639
9643
} else if (BO->getOpcode () == Instruction::Or) {
9640
9644
if (Inverse)
9641
- return isImpliedCond (Pred, LHS, RHS, BO->getOperand (0 ), Inverse) ||
9642
- isImpliedCond (Pred, LHS, RHS, BO->getOperand (1 ), Inverse);
9645
+ return isImpliedCond (Pred, LHS, RHS, BO->getOperand (0 ), Inverse,
9646
+ Context) ||
9647
+ isImpliedCond (Pred, LHS, RHS, BO->getOperand (1 ), Inverse,
9648
+ Context);
9643
9649
}
9644
9650
}
9645
9651
@@ -9657,14 +9663,14 @@ bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS,
9657
9663
const SCEV *FoundLHS = getSCEV (ICI->getOperand (0 ));
9658
9664
const SCEV *FoundRHS = getSCEV (ICI->getOperand (1 ));
9659
9665
9660
- return isImpliedCond (Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS);
9666
+ return isImpliedCond (Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS, Context );
9661
9667
}
9662
9668
9663
9669
bool ScalarEvolution::isImpliedCond (ICmpInst::Predicate Pred, const SCEV *LHS,
9664
9670
const SCEV *RHS,
9665
9671
ICmpInst::Predicate FoundPred,
9666
- const SCEV *FoundLHS,
9667
- const SCEV *FoundRHS ) {
9672
+ const SCEV *FoundLHS, const SCEV *FoundRHS,
9673
+ const Instruction *Context ) {
9668
9674
// Balance the types.
9669
9675
if (getTypeSizeInBits (LHS->getType ()) <
9670
9676
getTypeSizeInBits (FoundLHS->getType ())) {
@@ -9708,24 +9714,24 @@ bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS,
9708
9714
9709
9715
// Check whether the found predicate is the same as the desired predicate.
9710
9716
if (FoundPred == Pred)
9711
- return isImpliedCondOperands (Pred, LHS, RHS, FoundLHS, FoundRHS);
9717
+ return isImpliedCondOperands (Pred, LHS, RHS, FoundLHS, FoundRHS, Context );
9712
9718
9713
9719
// Check whether swapping the found predicate makes it the same as the
9714
9720
// desired predicate.
9715
9721
if (ICmpInst::getSwappedPredicate (FoundPred) == Pred) {
9716
9722
if (isa<SCEVConstant>(RHS))
9717
- return isImpliedCondOperands (Pred, LHS, RHS, FoundRHS, FoundLHS);
9723
+ return isImpliedCondOperands (Pred, LHS, RHS, FoundRHS, FoundLHS, Context );
9718
9724
else
9719
- return isImpliedCondOperands (ICmpInst::getSwappedPredicate (Pred),
9720
- RHS, LHS, FoundLHS, FoundRHS);
9725
+ return isImpliedCondOperands (ICmpInst::getSwappedPredicate (Pred), RHS,
9726
+ LHS, FoundLHS, FoundRHS, Context );
9721
9727
}
9722
9728
9723
9729
// Unsigned comparison is the same as signed comparison when both the operands
9724
9730
// are non-negative.
9725
9731
if (CmpInst::isUnsigned (FoundPred) &&
9726
9732
CmpInst::getSignedPredicate (FoundPred) == Pred &&
9727
9733
isKnownNonNegative (FoundLHS) && isKnownNonNegative (FoundRHS))
9728
- return isImpliedCondOperands (Pred, LHS, RHS, FoundLHS, FoundRHS);
9734
+ return isImpliedCondOperands (Pred, LHS, RHS, FoundLHS, FoundRHS, Context );
9729
9735
9730
9736
// Check if we can make progress by sharpening ranges.
9731
9737
if (FoundPred == ICmpInst::ICMP_NE &&
@@ -9762,8 +9768,8 @@ bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS,
9762
9768
case ICmpInst::ICMP_UGE:
9763
9769
// We know V `Pred` SharperMin. If this implies LHS `Pred`
9764
9770
// RHS, we're done.
9765
- if (isImpliedCondOperands (Pred, LHS, RHS, V,
9766
- getConstant (SharperMin) ))
9771
+ if (isImpliedCondOperands (Pred, LHS, RHS, V, getConstant (SharperMin),
9772
+ Context ))
9767
9773
return true ;
9768
9774
LLVM_FALLTHROUGH;
9769
9775
@@ -9778,22 +9784,23 @@ bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS,
9778
9784
//
9779
9785
// If V `Pred` Min implies LHS `Pred` RHS, we're done.
9780
9786
9781
- if (isImpliedCondOperands (Pred, LHS, RHS, V, getConstant (Min)))
9787
+ if (isImpliedCondOperands (Pred, LHS, RHS, V, getConstant (Min),
9788
+ Context))
9782
9789
return true ;
9783
9790
break ;
9784
9791
9785
9792
// `LHS < RHS` and `LHS <= RHS` are handled in the same way as `RHS > LHS` and `RHS >= LHS` respectively.
9786
9793
case ICmpInst::ICMP_SLE:
9787
9794
case ICmpInst::ICMP_ULE:
9788
9795
if (isImpliedCondOperands (CmpInst::getSwappedPredicate (Pred), RHS,
9789
- LHS, V, getConstant (SharperMin)))
9796
+ LHS, V, getConstant (SharperMin), Context ))
9790
9797
return true ;
9791
9798
LLVM_FALLTHROUGH;
9792
9799
9793
9800
case ICmpInst::ICMP_SLT:
9794
9801
case ICmpInst::ICMP_ULT:
9795
9802
if (isImpliedCondOperands (CmpInst::getSwappedPredicate (Pred), RHS,
9796
- LHS, V, getConstant (Min)))
9803
+ LHS, V, getConstant (Min), Context ))
9797
9804
return true ;
9798
9805
break ;
9799
9806
@@ -9807,11 +9814,12 @@ bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS,
9807
9814
// Check whether the actual condition is beyond sufficient.
9808
9815
if (FoundPred == ICmpInst::ICMP_EQ)
9809
9816
if (ICmpInst::isTrueWhenEqual (Pred))
9810
- if (isImpliedCondOperands (Pred, LHS, RHS, FoundLHS, FoundRHS))
9817
+ if (isImpliedCondOperands (Pred, LHS, RHS, FoundLHS, FoundRHS, Context ))
9811
9818
return true ;
9812
9819
if (Pred == ICmpInst::ICMP_NE)
9813
9820
if (!ICmpInst::isTrueWhenEqual (FoundPred))
9814
- if (isImpliedCondOperands (FoundPred, LHS, RHS, FoundLHS, FoundRHS))
9821
+ if (isImpliedCondOperands (FoundPred, LHS, RHS, FoundLHS, FoundRHS,
9822
+ Context))
9815
9823
return true ;
9816
9824
9817
9825
// Otherwise assume the worst.
@@ -9890,6 +9898,51 @@ Optional<APInt> ScalarEvolution::computeConstantDifference(const SCEV *More,
9890
9898
return None;
9891
9899
}
9892
9900
9901
+ bool ScalarEvolution::isImpliedCondOperandsViaAddRecStart (
9902
+ ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,
9903
+ const SCEV *FoundLHS, const SCEV *FoundRHS, const Instruction *Context) {
9904
+ // Try to recognize the following pattern:
9905
+ //
9906
+ // FoundRHS = ...
9907
+ // ...
9908
+ // loop:
9909
+ // FoundLHS = {Start,+,W}
9910
+ // context_bb: // Basic block from the same loop
9911
+ // known(Pred, FoundLHS, FoundRHS)
9912
+ //
9913
+ // If some predicate is known in the context of a loop, it is also known on
9914
+ // each iteration of this loop, including the first iteration. Therefore, in
9915
+ // this case, `FoundLHS Pred FoundRHS` implies `Start Pred FoundRHS`. Try to
9916
+ // prove the original pred using this fact.
9917
+ if (!Context)
9918
+ return false ;
9919
+ const BasicBlock *ContextBB = Context->getParent ();
9920
+ // Make sure AR varies in the context block.
9921
+ if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundLHS)) {
9922
+ const Loop *L = AR->getLoop ();
9923
+ // Make sure that context belongs to the loop and executes on 1st iteration
9924
+ // (if it ever executes at all).
9925
+ if (!L->contains (ContextBB) || !DT.dominates (ContextBB, L->getLoopLatch ()))
9926
+ return false ;
9927
+ if (!isAvailableAtLoopEntry (FoundRHS, AR->getLoop ()))
9928
+ return false ;
9929
+ return isImpliedCondOperands (Pred, LHS, RHS, AR->getStart (), FoundRHS);
9930
+ }
9931
+
9932
+ if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundRHS)) {
9933
+ const Loop *L = AR->getLoop ();
9934
+ // Make sure that context belongs to the loop and executes on 1st iteration
9935
+ // (if it ever executes at all).
9936
+ if (!L->contains (ContextBB) || !DT.dominates (ContextBB, L->getLoopLatch ()))
9937
+ return false ;
9938
+ if (!isAvailableAtLoopEntry (FoundLHS, AR->getLoop ()))
9939
+ return false ;
9940
+ return isImpliedCondOperands (Pred, LHS, RHS, FoundLHS, AR->getStart ());
9941
+ }
9942
+
9943
+ return false ;
9944
+ }
9945
+
9893
9946
bool ScalarEvolution::isImpliedCondOperandsViaNoOverflow (
9894
9947
ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,
9895
9948
const SCEV *FoundLHS, const SCEV *FoundRHS) {
@@ -10080,13 +10133,18 @@ bool ScalarEvolution::isImpliedViaMerge(ICmpInst::Predicate Pred,
10080
10133
bool ScalarEvolution::isImpliedCondOperands (ICmpInst::Predicate Pred,
10081
10134
const SCEV *LHS, const SCEV *RHS,
10082
10135
const SCEV *FoundLHS,
10083
- const SCEV *FoundRHS) {
10136
+ const SCEV *FoundRHS,
10137
+ const Instruction *Context) {
10084
10138
if (isImpliedCondOperandsViaRanges (Pred, LHS, RHS, FoundLHS, FoundRHS))
10085
10139
return true ;
10086
10140
10087
10141
if (isImpliedCondOperandsViaNoOverflow (Pred, LHS, RHS, FoundLHS, FoundRHS))
10088
10142
return true ;
10089
10143
10144
+ if (isImpliedCondOperandsViaAddRecStart (Pred, LHS, RHS, FoundLHS, FoundRHS,
10145
+ Context))
10146
+ return true ;
10147
+
10090
10148
return isImpliedCondOperandsHelper (Pred, LHS, RHS,
10091
10149
FoundLHS, FoundRHS) ||
10092
10150
// ~x < ~y --> x > y
0 commit comments