-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[SCEV] Add predicate in SolveLinEq to ensure B is a multiple of A. #108777
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 6 commits
c91b942
2832a8a
fcf36c4
f23bddc
c312060
d6522f9
7624c7d
f648176
e36de7d
ee755af
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -10129,8 +10129,11 @@ const SCEV *ScalarEvolution::stripInjectiveFunctions(const SCEV *S) const { | |
| /// A and B isn't important. | ||
| /// | ||
| /// If the equation does not have a solution, SCEVCouldNotCompute is returned. | ||
| static const SCEV *SolveLinEquationWithOverflow(const APInt &A, const SCEV *B, | ||
| ScalarEvolution &SE) { | ||
| static const SCEV * | ||
| SolveLinEquationWithOverflow(const APInt &A, const SCEV *B, | ||
| SmallPtrSetImpl<const SCEVPredicate *> *Predicates, | ||
|
|
||
| ScalarEvolution &SE) { | ||
| uint32_t BW = A.getBitWidth(); | ||
| assert(BW == SE.getTypeSizeInBits(B->getType())); | ||
| assert(A != 0 && "A must be non-zero."); | ||
|
|
@@ -10146,8 +10149,20 @@ static const SCEV *SolveLinEquationWithOverflow(const APInt &A, const SCEV *B, | |
| // | ||
| // B is divisible by D if and only if the multiplicity of prime factor 2 for B | ||
| // is not less than multiplicity of this prime factor for D. | ||
| if (SE.getMinTrailingZeros(B) < Mult2) | ||
| return SE.getCouldNotCompute(); | ||
| if (SE.getMinTrailingZeros(B) < Mult2) { | ||
| if (!Predicates) | ||
| return SE.getCouldNotCompute(); | ||
| // Try to add a predicate ensuring B is a multiple of 1 << Mult2. | ||
| const SCEV *URem = | ||
| SE.getURemExpr(B, SE.getConstant(APInt::getOneBitSet(BW, Mult2))); | ||
| const SCEV *Zero = SE.getZero(B->getType()); | ||
| assert(!SE.isKnownPredicate(CmpInst::ICMP_EQ, URem, Zero) && | ||
| "No remainder for 1 << Mult2 but missed by getTrailingBits?"); | ||
| // Avoid adding a predicate that is known to be false. | ||
| if (SE.isKnownPredicate(CmpInst::ICMP_NE, URem, Zero)) | ||
| return SE.getCouldNotCompute(); | ||
| Predicates->insert(SE.getEqualPredicate(URem, Zero)); | ||
|
||
| } | ||
|
|
||
| // 3. Compute I: the multiplicative inverse of (A / D) in arithmetic | ||
| // modulo (N / D). | ||
|
|
@@ -10577,8 +10592,9 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V, | |
| // Solve the general equation. | ||
| if (!StepC || StepC->getValue()->isZero()) | ||
| return getCouldNotCompute(); | ||
| const SCEV *E = SolveLinEquationWithOverflow(StepC->getAPInt(), | ||
| getNegativeSCEV(Start), *this); | ||
| const SCEV *E = SolveLinEquationWithOverflow( | ||
| StepC->getAPInt(), getNegativeSCEV(Start), | ||
| AllowPredicates ? &Predicates : nullptr, *this); | ||
|
|
||
| const SCEV *M = E; | ||
| if (E != getCouldNotCompute()) { | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1580,6 +1580,15 @@ define i32 @ptr_induction_ult_2(ptr %a, ptr %b) { | |
| ; CHECK-NEXT: Loop %loop: Unpredictable backedge-taken count. | ||
| ; CHECK-NEXT: Loop %loop: Unpredictable constant max backedge-taken count. | ||
| ; CHECK-NEXT: Loop %loop: Unpredictable symbolic max backedge-taken count. | ||
| ; CHECK-NEXT: Loop %loop: Predicated backedge-taken count is (((-1 * (ptrtoint ptr %a to i64)) + (ptrtoint ptr %b to i64)) /u 4) | ||
| ; CHECK-NEXT: Predicates: | ||
| ; CHECK-NEXT: Equal predicate: (zext i2 ((trunc i64 (ptrtoint ptr %b to i64) to i2) + (-1 * (trunc i64 (ptrtoint ptr %a to i64) to i2))) to i64) == 0 | ||
| ; CHECK-NEXT: Loop %loop: Predicated constant max backedge-taken count is i64 4611686018427387903 | ||
| ; CHECK-NEXT: Predicates: | ||
| ; CHECK-NEXT: Equal predicate: (zext i2 ((trunc i64 (ptrtoint ptr %b to i64) to i2) + (-1 * (trunc i64 (ptrtoint ptr %a to i64) to i2))) to i64) == 0 | ||
| ; CHECK-NEXT: Loop %loop: Predicated symbolic max backedge-taken count is (((-1 * (ptrtoint ptr %a to i64)) + (ptrtoint ptr %b to i64)) /u 4) | ||
| ; CHECK-NEXT: Predicates: | ||
| ; CHECK-NEXT: Equal predicate: (zext i2 ((trunc i64 (ptrtoint ptr %b to i64) to i2) + (-1 * (trunc i64 (ptrtoint ptr %a to i64) to i2))) to i64) == 0 | ||
| ; | ||
| entry: | ||
| %cmp.6 = icmp ult ptr %a, %b | ||
|
|
@@ -1606,6 +1615,15 @@ define i32 @ptr_induction_ult_3_step_6(ptr %a, ptr %b) { | |
| ; CHECK-NEXT: Loop %loop: Unpredictable backedge-taken count. | ||
| ; CHECK-NEXT: Loop %loop: Unpredictable constant max backedge-taken count. | ||
| ; CHECK-NEXT: Loop %loop: Unpredictable symbolic max backedge-taken count. | ||
| ; CHECK-NEXT: Loop %loop: Predicated backedge-taken count is (((3074457345618258603 * (ptrtoint ptr %b to i64)) + (-3074457345618258603 * (ptrtoint ptr %a to i64))) /u 2) | ||
| ; CHECK-NEXT: Predicates: | ||
| ; CHECK-NEXT: Equal predicate: (zext i1 (trunc i64 ((-1 * (ptrtoint ptr %a to i64)) + (ptrtoint ptr %b to i64)) to i1) to i64) == 0 | ||
| ; CHECK-NEXT: Loop %loop: Predicated constant max backedge-taken count is i64 9223372036854775807 | ||
| ; CHECK-NEXT: Predicates: | ||
| ; CHECK-NEXT: Equal predicate: (zext i1 (trunc i64 ((-1 * (ptrtoint ptr %a to i64)) + (ptrtoint ptr %b to i64)) to i1) to i64) == 0 | ||
| ; CHECK-NEXT: Loop %loop: Predicated symbolic max backedge-taken count is (((3074457345618258603 * (ptrtoint ptr %b to i64)) + (-3074457345618258603 * (ptrtoint ptr %a to i64))) /u 2) | ||
| ; CHECK-NEXT: Predicates: | ||
| ; CHECK-NEXT: Equal predicate: (zext i1 (trunc i64 ((-1 * (ptrtoint ptr %a to i64)) + (ptrtoint ptr %b to i64)) to i1) to i64) == 0 | ||
| ; | ||
| entry: | ||
| %cmp.6 = icmp ult ptr %a, %b | ||
|
|
@@ -1705,10 +1723,9 @@ exit: | |
| ret void | ||
| } | ||
|
|
||
| ; TODO: It feels like we should be able to calculate the symbolic max | ||
| ; exit count for the loop.inc block here, in the same way as | ||
| ; ptr_induction_eq_1. The problem seems to be in howFarToZero when the | ||
| ; ControlsOnlyExit is set to false. | ||
| ; %a and %b may not have the same alignment, so the loop may only exit via the early | ||
|
||
| ; exit when %ptr.iv > %b. The predicated exit count for the latch can be | ||
| ; computed by adding a predicate. | ||
| define void @ptr_induction_early_exit_eq_1(ptr %a, ptr %b, ptr %c) { | ||
| ; CHECK-LABEL: 'ptr_induction_early_exit_eq_1' | ||
| ; CHECK-NEXT: Classifying expressions for: @ptr_induction_early_exit_eq_1 | ||
|
|
@@ -1722,10 +1739,24 @@ define void @ptr_induction_early_exit_eq_1(ptr %a, ptr %b, ptr %c) { | |
| ; CHECK-NEXT: Loop %loop: <multiple exits> Unpredictable backedge-taken count. | ||
| ; CHECK-NEXT: exit count for loop: ***COULDNOTCOMPUTE*** | ||
| ; CHECK-NEXT: exit count for loop.inc: ***COULDNOTCOMPUTE*** | ||
| ; CHECK-NEXT: predicated exit count for loop.inc: ((-8 + (-1 * (ptrtoint ptr %a to i64)) + (ptrtoint ptr %b to i64)) /u 8) | ||
| ; CHECK-NEXT: Predicates: | ||
| ; CHECK-NEXT: Equal predicate: (zext i3 ((trunc i64 (ptrtoint ptr %b to i64) to i3) + (-1 * (trunc i64 (ptrtoint ptr %a to i64) to i3))) to i64) == 0 | ||
| ; CHECK-EMPTY: | ||
| ; CHECK-NEXT: Loop %loop: Unpredictable constant max backedge-taken count. | ||
| ; CHECK-NEXT: Loop %loop: Unpredictable symbolic max backedge-taken count. | ||
| ; CHECK-NEXT: symbolic max exit count for loop: ***COULDNOTCOMPUTE*** | ||
| ; CHECK-NEXT: symbolic max exit count for loop.inc: ***COULDNOTCOMPUTE*** | ||
| ; CHECK-NEXT: predicated symbolic max exit count for loop.inc: ((-8 + (-1 * (ptrtoint ptr %a to i64)) + (ptrtoint ptr %b to i64)) /u 8) | ||
| ; CHECK-NEXT: Predicates: | ||
| ; CHECK-NEXT: Equal predicate: (zext i3 ((trunc i64 (ptrtoint ptr %b to i64) to i3) + (-1 * (trunc i64 (ptrtoint ptr %a to i64) to i3))) to i64) == 0 | ||
| ; CHECK-EMPTY: | ||
| ; CHECK-NEXT: Loop %loop: Predicated constant max backedge-taken count is i64 2305843009213693951 | ||
| ; CHECK-NEXT: Predicates: | ||
| ; CHECK-NEXT: Equal predicate: (zext i3 ((trunc i64 (ptrtoint ptr %b to i64) to i3) + (-1 * (trunc i64 (ptrtoint ptr %a to i64) to i3))) to i64) == 0 | ||
| ; CHECK-NEXT: Loop %loop: Predicated symbolic max backedge-taken count is ((-8 + (-1 * (ptrtoint ptr %a to i64)) + (ptrtoint ptr %b to i64)) /u 8) | ||
| ; CHECK-NEXT: Predicates: | ||
| ; CHECK-NEXT: Equal predicate: (zext i3 ((trunc i64 (ptrtoint ptr %b to i64) to i3) + (-1 * (trunc i64 (ptrtoint ptr %a to i64) to i3))) to i64) == 0 | ||
| ; | ||
| entry: | ||
| %cmp = icmp eq ptr %a, %b | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we maybe try to prove that (urem B, A) == 0 before resorting to the predicate?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IIUC that check shouldn't be true, otherwise we are missing some logic in getMinTrailingZeros? Added an assert that the check isn't true.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
With the reframing of urem (B, 1 << Mult2) == 0, I agree with your comment, but I think we could reasonably see divergence with the original urem(B,A) check. Those are just different enough.
My suggestion is basically, can we prove the original fact without using the trailing bits proof strategy? Said differently. what if D = gcd(A, N) is something like 3^M?
Hm, though now I see that the multiplicative inverse code below would need to be updated too. I'd missed that originally.
(Totally fine to move forward here, this is purely a possible followup.)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I stand corrected, the assertion actually uncovered a case where getMinTrailingZeros returns a worse result and I think catching this in getMinTrailingZeros would require matching a specific pattern (https://alive2.llvm.org/ce/z/t3A5X2)
I updated the code to check if the URem is zero if the trailing bits check failed, as we need to build the expression there already. Test is in
llvm/test/Analysis/ScalarEvolution/trip-count-urem.ll