Skip to content

Commit c7fbe38

Browse files
authored
[SCEV] Pass loop pred branch as context instruction to getMinTrailingZ. (#160941)
When computing the backedge taken count, we know that the expression must be valid just before we enter the loop. Using the terminator of the loop predecessor as context instruction for getConstantMultiple, getMinTrailingZeros allows using information from things like alignment assumptions. When a context instruction is used, the result is not cached, as it is only valid at the specific context instruction. Compile-time looks neutral: http://llvm-compile-time-tracker.com/compare.php?from=9be276ec75c087595ebb62fe11b35c1a90371a49&to=745980f5e1c8094ea1293cd145d0ef1390f03029&stat=instructions:u No impact on llvm-opt-benchmark (dtcxzyw/llvm-opt-benchmark#2867), but leads to additonal unrolling in ~90 files across a C/C++ based corpus including LLVM on AArch64 using libc++ (which emits alignment assumptions for things like std::vector::begin). PR: #160941
1 parent 482cd5f commit c7fbe38

File tree

3 files changed

+57
-55
lines changed

3 files changed

+57
-55
lines changed

llvm/include/llvm/Analysis/ScalarEvolution.h

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1002,10 +1002,14 @@ class ScalarEvolution {
10021002
/// (at every loop iteration). It is, at the same time, the minimum number
10031003
/// of times S is divisible by 2. For example, given {4,+,8} it returns 2.
10041004
/// If S is guaranteed to be 0, it returns the bitwidth of S.
1005-
LLVM_ABI uint32_t getMinTrailingZeros(const SCEV *S);
1005+
/// If \p CtxI is not nullptr, return a constant multiple valid at \p CtxI.
1006+
LLVM_ABI uint32_t getMinTrailingZeros(const SCEV *S,
1007+
const Instruction *CtxI = nullptr);
10061008

1007-
/// Returns the max constant multiple of S.
1008-
LLVM_ABI APInt getConstantMultiple(const SCEV *S);
1009+
/// Returns the max constant multiple of S. If \p CtxI is not nullptr, return
1010+
/// a constant multiple valid at \p CtxI.
1011+
LLVM_ABI APInt getConstantMultiple(const SCEV *S,
1012+
const Instruction *CtxI = nullptr);
10091013

10101014
// Returns the max constant multiple of S. If S is exactly 0, return 1.
10111015
LLVM_ABI APInt getNonZeroConstantMultiple(const SCEV *S);
@@ -1525,8 +1529,10 @@ class ScalarEvolution {
15251529
/// Return the Value set from which the SCEV expr is generated.
15261530
ArrayRef<Value *> getSCEVValues(const SCEV *S);
15271531

1528-
/// Private helper method for the getConstantMultiple method.
1529-
APInt getConstantMultipleImpl(const SCEV *S);
1532+
/// Private helper method for the getConstantMultiple method. If \p CtxI is
1533+
/// not nullptr, return a constant multiple valid at \p CtxI.
1534+
APInt getConstantMultipleImpl(const SCEV *S,
1535+
const Instruction *Ctx = nullptr);
15301536

15311537
/// Information about the number of times a particular loop exit may be
15321538
/// reached before exiting the loop.

llvm/lib/Analysis/ScalarEvolution.cpp

Lines changed: 34 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -6351,61 +6351,62 @@ const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) {
63516351
return getGEPExpr(GEP, IndexExprs);
63526352
}
63536353

6354-
APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S) {
6354+
APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S,
6355+
const Instruction *CtxI) {
63556356
uint64_t BitWidth = getTypeSizeInBits(S->getType());
63566357
auto GetShiftedByZeros = [BitWidth](uint32_t TrailingZeros) {
63576358
return TrailingZeros >= BitWidth
63586359
? APInt::getZero(BitWidth)
63596360
: APInt::getOneBitSet(BitWidth, TrailingZeros);
63606361
};
6361-
auto GetGCDMultiple = [this](const SCEVNAryExpr *N) {
6362+
auto GetGCDMultiple = [this, CtxI](const SCEVNAryExpr *N) {
63626363
// The result is GCD of all operands results.
6363-
APInt Res = getConstantMultiple(N->getOperand(0));
6364+
APInt Res = getConstantMultiple(N->getOperand(0), CtxI);
63646365
for (unsigned I = 1, E = N->getNumOperands(); I < E && Res != 1; ++I)
63656366
Res = APIntOps::GreatestCommonDivisor(
6366-
Res, getConstantMultiple(N->getOperand(I)));
6367+
Res, getConstantMultiple(N->getOperand(I), CtxI));
63676368
return Res;
63686369
};
63696370

63706371
switch (S->getSCEVType()) {
63716372
case scConstant:
63726373
return cast<SCEVConstant>(S)->getAPInt();
63736374
case scPtrToInt:
6374-
return getConstantMultiple(cast<SCEVPtrToIntExpr>(S)->getOperand());
6375+
return getConstantMultiple(cast<SCEVPtrToIntExpr>(S)->getOperand(), CtxI);
63756376
case scUDivExpr:
63766377
case scVScale:
63776378
return APInt(BitWidth, 1);
63786379
case scTruncate: {
63796380
// Only multiples that are a power of 2 will hold after truncation.
63806381
const SCEVTruncateExpr *T = cast<SCEVTruncateExpr>(S);
6381-
uint32_t TZ = getMinTrailingZeros(T->getOperand());
6382+
uint32_t TZ = getMinTrailingZeros(T->getOperand(), CtxI);
63826383
return GetShiftedByZeros(TZ);
63836384
}
63846385
case scZeroExtend: {
63856386
const SCEVZeroExtendExpr *Z = cast<SCEVZeroExtendExpr>(S);
6386-
return getConstantMultiple(Z->getOperand()).zext(BitWidth);
6387+
return getConstantMultiple(Z->getOperand(), CtxI).zext(BitWidth);
63876388
}
63886389
case scSignExtend: {
63896390
// Only multiples that are a power of 2 will hold after sext.
63906391
const SCEVSignExtendExpr *E = cast<SCEVSignExtendExpr>(S);
6391-
uint32_t TZ = getMinTrailingZeros(E->getOperand());
6392+
uint32_t TZ = getMinTrailingZeros(E->getOperand(), CtxI);
63926393
return GetShiftedByZeros(TZ);
63936394
}
63946395
case scMulExpr: {
63956396
const SCEVMulExpr *M = cast<SCEVMulExpr>(S);
63966397
if (M->hasNoUnsignedWrap()) {
63976398
// The result is the product of all operand results.
6398-
APInt Res = getConstantMultiple(M->getOperand(0));
6399+
APInt Res = getConstantMultiple(M->getOperand(0), CtxI);
63996400
for (const SCEV *Operand : M->operands().drop_front())
6400-
Res = Res * getConstantMultiple(Operand);
6401+
Res = Res * getConstantMultiple(Operand, CtxI);
64016402
return Res;
64026403
}
64036404

64046405
// If there are no wrap guarentees, find the trailing zeros, which is the
64056406
// sum of trailing zeros for all its operands.
64066407
uint32_t TZ = 0;
64076408
for (const SCEV *Operand : M->operands())
6408-
TZ += getMinTrailingZeros(Operand);
6409+
TZ += getMinTrailingZeros(Operand, CtxI);
64096410
return GetShiftedByZeros(TZ);
64106411
}
64116412
case scAddExpr:
@@ -6414,9 +6415,9 @@ APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S) {
64146415
if (N->hasNoUnsignedWrap())
64156416
return GetGCDMultiple(N);
64166417
// Find the trailing bits, which is the minimum of its operands.
6417-
uint32_t TZ = getMinTrailingZeros(N->getOperand(0));
6418+
uint32_t TZ = getMinTrailingZeros(N->getOperand(0), CtxI);
64186419
for (const SCEV *Operand : N->operands().drop_front())
6419-
TZ = std::min(TZ, getMinTrailingZeros(Operand));
6420+
TZ = std::min(TZ, getMinTrailingZeros(Operand, CtxI));
64206421
return GetShiftedByZeros(TZ);
64216422
}
64226423
case scUMaxExpr:
@@ -6429,7 +6430,7 @@ APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S) {
64296430
// ask ValueTracking for known bits
64306431
const SCEVUnknown *U = cast<SCEVUnknown>(S);
64316432
unsigned Known =
6432-
computeKnownBits(U->getValue(), getDataLayout(), &AC, nullptr, &DT)
6433+
computeKnownBits(U->getValue(), getDataLayout(), &AC, CtxI, &DT)
64336434
.countMinTrailingZeros();
64346435
return GetShiftedByZeros(Known);
64356436
}
@@ -6439,12 +6440,18 @@ APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S) {
64396440
llvm_unreachable("Unknown SCEV kind!");
64406441
}
64416442

6442-
APInt ScalarEvolution::getConstantMultiple(const SCEV *S) {
6443+
APInt ScalarEvolution::getConstantMultiple(const SCEV *S,
6444+
const Instruction *CtxI) {
6445+
// Skip looking up and updating the cache if there is a context instruction,
6446+
// as the result will only be valid in the specified context.
6447+
if (CtxI)
6448+
return getConstantMultipleImpl(S, CtxI);
6449+
64436450
auto I = ConstantMultipleCache.find(S);
64446451
if (I != ConstantMultipleCache.end())
64456452
return I->second;
64466453

6447-
APInt Result = getConstantMultipleImpl(S);
6454+
APInt Result = getConstantMultipleImpl(S, CtxI);
64486455
auto InsertPair = ConstantMultipleCache.insert({S, Result});
64496456
assert(InsertPair.second && "Should insert a new key");
64506457
return InsertPair.first->second;
@@ -6455,8 +6462,9 @@ APInt ScalarEvolution::getNonZeroConstantMultiple(const SCEV *S) {
64556462
return Multiple == 0 ? APInt(Multiple.getBitWidth(), 1) : Multiple;
64566463
}
64576464

6458-
uint32_t ScalarEvolution::getMinTrailingZeros(const SCEV *S) {
6459-
return std::min(getConstantMultiple(S).countTrailingZeros(),
6465+
uint32_t ScalarEvolution::getMinTrailingZeros(const SCEV *S,
6466+
const Instruction *CtxI) {
6467+
return std::min(getConstantMultiple(S, CtxI).countTrailingZeros(),
64606468
(unsigned)getTypeSizeInBits(S->getType()));
64616469
}
64626470

@@ -10243,8 +10251,7 @@ const SCEV *ScalarEvolution::stripInjectiveFunctions(const SCEV *S) const {
1024310251
static const SCEV *
1024410252
SolveLinEquationWithOverflow(const APInt &A, const SCEV *B,
1024510253
SmallVectorImpl<const SCEVPredicate *> *Predicates,
10246-
10247-
ScalarEvolution &SE) {
10254+
ScalarEvolution &SE, const Loop *L) {
1024810255
uint32_t BW = A.getBitWidth();
1024910256
assert(BW == SE.getTypeSizeInBits(B->getType()));
1025010257
assert(A != 0 && "A must be non-zero.");
@@ -10260,7 +10267,12 @@ SolveLinEquationWithOverflow(const APInt &A, const SCEV *B,
1026010267
//
1026110268
// B is divisible by D if and only if the multiplicity of prime factor 2 for B
1026210269
// is not less than multiplicity of this prime factor for D.
10263-
if (SE.getMinTrailingZeros(B) < Mult2) {
10270+
unsigned MinTZ = SE.getMinTrailingZeros(B);
10271+
// Try again with the terminator of the loop predecessor for context-specific
10272+
// result, if MinTZ s too small.
10273+
if (MinTZ < Mult2 && L->getLoopPredecessor())
10274+
MinTZ = SE.getMinTrailingZeros(B, L->getLoopPredecessor()->getTerminator());
10275+
if (MinTZ < Mult2) {
1026410276
// Check if we can prove there's no remainder using URem.
1026510277
const SCEV *URem =
1026610278
SE.getURemExpr(B, SE.getConstant(APInt::getOneBitSet(BW, Mult2)));
@@ -10708,7 +10720,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
1070810720
return getCouldNotCompute();
1070910721
const SCEV *E = SolveLinEquationWithOverflow(
1071010722
StepC->getAPInt(), getNegativeSCEV(Start),
10711-
AllowPredicates ? &Predicates : nullptr, *this);
10723+
AllowPredicates ? &Predicates : nullptr, *this, L);
1071210724

1071310725
const SCEV *M = E;
1071410726
if (E != getCouldNotCompute()) {

llvm/test/Analysis/ScalarEvolution/trip-multiple-guard-info.ll

Lines changed: 12 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -615,22 +615,14 @@ define void @test_ptrs_aligned_by_4_via_assumption(ptr %start, ptr %end) {
615615
; CHECK-LABEL: 'test_ptrs_aligned_by_4_via_assumption'
616616
; CHECK-NEXT: Classifying expressions for: @test_ptrs_aligned_by_4_via_assumption
617617
; CHECK-NEXT: %iv = phi ptr [ %start, %entry ], [ %iv.next, %loop ]
618-
; CHECK-NEXT: --> {%start,+,4}<%loop> U: full-set S: full-set Exits: <<Unknown>> LoopDispositions: { %loop: Computable }
618+
; CHECK-NEXT: --> {%start,+,4}<%loop> U: full-set S: full-set Exits: ((4 * ((-4 + (-1 * (ptrtoint ptr %start to i64)) + (ptrtoint ptr %end to i64)) /u 4))<nuw> + %start) LoopDispositions: { %loop: Computable }
619619
; CHECK-NEXT: %iv.next = getelementptr i8, ptr %iv, i64 4
620-
; CHECK-NEXT: --> {(4 + %start),+,4}<%loop> U: full-set S: full-set Exits: <<Unknown>> LoopDispositions: { %loop: Computable }
620+
; CHECK-NEXT: --> {(4 + %start),+,4}<%loop> U: full-set S: full-set Exits: (4 + (4 * ((-4 + (-1 * (ptrtoint ptr %start to i64)) + (ptrtoint ptr %end to i64)) /u 4))<nuw> + %start) LoopDispositions: { %loop: Computable }
621621
; CHECK-NEXT: Determining loop execution counts for: @test_ptrs_aligned_by_4_via_assumption
622-
; CHECK-NEXT: Loop %loop: Unpredictable backedge-taken count.
623-
; CHECK-NEXT: Loop %loop: Unpredictable constant max backedge-taken count.
624-
; CHECK-NEXT: Loop %loop: Unpredictable symbolic max backedge-taken count.
625-
; CHECK-NEXT: Loop %loop: Predicated backedge-taken count is ((-4 + (-1 * (ptrtoint ptr %start to i64)) + (ptrtoint ptr %end to i64)) /u 4)
626-
; CHECK-NEXT: Predicates:
627-
; CHECK-NEXT: Equal predicate: (zext i2 ((trunc i64 (ptrtoint ptr %end to i64) to i2) + (-1 * (trunc i64 (ptrtoint ptr %start to i64) to i2))) to i64) == 0
628-
; CHECK-NEXT: Loop %loop: Predicated constant max backedge-taken count is i64 4611686018427387903
629-
; CHECK-NEXT: Predicates:
630-
; CHECK-NEXT: Equal predicate: (zext i2 ((trunc i64 (ptrtoint ptr %end to i64) to i2) + (-1 * (trunc i64 (ptrtoint ptr %start to i64) to i2))) to i64) == 0
631-
; CHECK-NEXT: Loop %loop: Predicated symbolic max backedge-taken count is ((-4 + (-1 * (ptrtoint ptr %start to i64)) + (ptrtoint ptr %end to i64)) /u 4)
632-
; CHECK-NEXT: Predicates:
633-
; CHECK-NEXT: Equal predicate: (zext i2 ((trunc i64 (ptrtoint ptr %end to i64) to i2) + (-1 * (trunc i64 (ptrtoint ptr %start to i64) to i2))) to i64) == 0
622+
; CHECK-NEXT: Loop %loop: backedge-taken count is ((-4 + (-1 * (ptrtoint ptr %start to i64)) + (ptrtoint ptr %end to i64)) /u 4)
623+
; CHECK-NEXT: Loop %loop: constant max backedge-taken count is i64 4611686018427387903
624+
; CHECK-NEXT: Loop %loop: symbolic max backedge-taken count is ((-4 + (-1 * (ptrtoint ptr %start to i64)) + (ptrtoint ptr %end to i64)) /u 4)
625+
; CHECK-NEXT: Loop %loop: Trip multiple is 1
634626
;
635627
entry:
636628
call void @llvm.assume(i1 true) [ "align"(ptr %start, i64 4) ]
@@ -652,22 +644,14 @@ define void @test_ptrs_aligned_by_8_via_assumption(ptr %start, ptr %end) {
652644
; CHECK-LABEL: 'test_ptrs_aligned_by_8_via_assumption'
653645
; CHECK-NEXT: Classifying expressions for: @test_ptrs_aligned_by_8_via_assumption
654646
; CHECK-NEXT: %iv = phi ptr [ %start, %entry ], [ %iv.next, %loop ]
655-
; CHECK-NEXT: --> {%start,+,4}<%loop> U: full-set S: full-set Exits: <<Unknown>> LoopDispositions: { %loop: Computable }
647+
; CHECK-NEXT: --> {%start,+,4}<%loop> U: full-set S: full-set Exits: ((4 * ((-4 + (-1 * (ptrtoint ptr %start to i64)) + (ptrtoint ptr %end to i64)) /u 4))<nuw> + %start) LoopDispositions: { %loop: Computable }
656648
; CHECK-NEXT: %iv.next = getelementptr i8, ptr %iv, i64 4
657-
; CHECK-NEXT: --> {(4 + %start),+,4}<%loop> U: full-set S: full-set Exits: <<Unknown>> LoopDispositions: { %loop: Computable }
649+
; CHECK-NEXT: --> {(4 + %start),+,4}<%loop> U: full-set S: full-set Exits: (4 + (4 * ((-4 + (-1 * (ptrtoint ptr %start to i64)) + (ptrtoint ptr %end to i64)) /u 4))<nuw> + %start) LoopDispositions: { %loop: Computable }
658650
; CHECK-NEXT: Determining loop execution counts for: @test_ptrs_aligned_by_8_via_assumption
659-
; CHECK-NEXT: Loop %loop: Unpredictable backedge-taken count.
660-
; CHECK-NEXT: Loop %loop: Unpredictable constant max backedge-taken count.
661-
; CHECK-NEXT: Loop %loop: Unpredictable symbolic max backedge-taken count.
662-
; CHECK-NEXT: Loop %loop: Predicated backedge-taken count is ((-4 + (-1 * (ptrtoint ptr %start to i64)) + (ptrtoint ptr %end to i64)) /u 4)
663-
; CHECK-NEXT: Predicates:
664-
; CHECK-NEXT: Equal predicate: (zext i2 ((trunc i64 (ptrtoint ptr %end to i64) to i2) + (-1 * (trunc i64 (ptrtoint ptr %start to i64) to i2))) to i64) == 0
665-
; CHECK-NEXT: Loop %loop: Predicated constant max backedge-taken count is i64 4611686018427387903
666-
; CHECK-NEXT: Predicates:
667-
; CHECK-NEXT: Equal predicate: (zext i2 ((trunc i64 (ptrtoint ptr %end to i64) to i2) + (-1 * (trunc i64 (ptrtoint ptr %start to i64) to i2))) to i64) == 0
668-
; CHECK-NEXT: Loop %loop: Predicated symbolic max backedge-taken count is ((-4 + (-1 * (ptrtoint ptr %start to i64)) + (ptrtoint ptr %end to i64)) /u 4)
669-
; CHECK-NEXT: Predicates:
670-
; CHECK-NEXT: Equal predicate: (zext i2 ((trunc i64 (ptrtoint ptr %end to i64) to i2) + (-1 * (trunc i64 (ptrtoint ptr %start to i64) to i2))) to i64) == 0
651+
; CHECK-NEXT: Loop %loop: backedge-taken count is ((-4 + (-1 * (ptrtoint ptr %start to i64)) + (ptrtoint ptr %end to i64)) /u 4)
652+
; CHECK-NEXT: Loop %loop: constant max backedge-taken count is i64 4611686018427387903
653+
; CHECK-NEXT: Loop %loop: symbolic max backedge-taken count is ((-4 + (-1 * (ptrtoint ptr %start to i64)) + (ptrtoint ptr %end to i64)) /u 4)
654+
; CHECK-NEXT: Loop %loop: Trip multiple is 1
671655
;
672656
entry:
673657
call void @llvm.assume(i1 true) [ "align"(ptr %start, i64 8) ]

0 commit comments

Comments
 (0)