Commit 679e95e
[SCEV] Infer loop max trip count from memory accesses
Data references in a loop are assumed not to access elements beyond the statically allocated size of the underlying object. We can therefore infer a loop max trip count from this would-be undefined behavior. This patch is refined from the original one (https://reviews.llvm.org/D155049) authored by @Peakulorain.
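As an illustration (a sketch for this commit message, not code taken from the patch), the source-level pattern this inference targets looks like the following C++ loop: the store cannot step past the end of the fixed-size local array without immediate UB, so the maximum trip count is bounded by the array size even though len is unknown at compile time.

// Hypothetical example of the pattern the new inference reasons about.
void fill(int len) {
  int a[256];                  // statically sized stack allocation
  for (int i = 0; i < len; ++i)
    a[i] = 0;                  // writing a[256] would be immediate UB, so the
                               // loop body can run at most 256 times
}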
1 parent 4f54d71 commit 679e95e

File tree

4 files changed, +493 -1 lines

4 files changed

+493
-1
lines changed

llvm/include/llvm/Analysis/ScalarEvolution.h

Lines changed: 6 additions & 0 deletions
@@ -854,6 +854,12 @@ class ScalarEvolution {
   unsigned getSmallConstantTripMultiple(const Loop *L,
                                         const BasicBlock *ExitingBlock);
 
+  /// Return an upper bound on the loop trip count inferred from memory
+  /// accesses: an access cannot touch bytes outside the statically allocated
+  /// size of its underlying object without immediate UB. Returns
+  /// SCEVCouldNotCompute if the trip count could not be inferred.
+  const SCEV *getConstantMaxTripCountFromMemAccess(const Loop *L);
+
   /// The terms "backedge taken count" and "exit count" are used
   /// interchangeably to refer to the number of times the backedge of a loop
   /// has executed before the loop is exited.
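Below is a minimal sketch of how a caller might consult the new hook; the helper name and the surrounding analysis setup are assumptions for illustration, not part of this patch. In-tree callers (getSmallConstantMaxTripCount and PrintLoopInfo in ScalarEvolution.cpp) only consult it when -scalar-evolution-infer-max-trip-count-from-memory-access is set.

#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Support/raw_ostream.h"
using namespace llvm;

// Hypothetical helper: print the memory-access-based bound for a loop, if any.
static void printInferredMaxTripCount(ScalarEvolution &SE, const Loop *L) {
  const SCEV *MaxTC = SE.getConstantMaxTripCountFromMemAccess(L);
  if (isa<SCEVCouldNotCompute>(MaxTC))
    return; // no bound could be inferred for this loop
  errs() << "Inferred max trip count: " << *MaxTC << "\n";
}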

llvm/lib/Analysis/ScalarEvolution.cpp

Lines changed: 186 additions & 1 deletion
@@ -249,6 +249,10 @@ static cl::opt<bool> UseContextForNoWrapFlagInference(
     cl::desc("Infer nuw/nsw flags using context where suitable"),
     cl::init(true));
 
+static cl::opt<bool> UseMemoryAccessUBForBEInference(
+    "scalar-evolution-infer-max-trip-count-from-memory-access", cl::Hidden,
+    cl::desc("Infer loop max trip count from memory access"), cl::init(false));
+
 //===----------------------------------------------------------------------===//
 // SCEV class definitions
 //===----------------------------------------------------------------------===//
@@ -8135,7 +8139,16 @@ ScalarEvolution::getSmallConstantTripCount(const Loop *L,
 unsigned ScalarEvolution::getSmallConstantMaxTripCount(const Loop *L) {
   const auto *MaxExitCount =
       dyn_cast<SCEVConstant>(getConstantMaxBackedgeTakenCount(L));
-  return getConstantTripCount(MaxExitCount);
+  unsigned MaxExitCountN = getConstantTripCount(MaxExitCount);
+  if (UseMemoryAccessUBForBEInference) {
+    auto *MaxInferCount = getConstantMaxTripCountFromMemAccess(L);
+    if (auto *InferCount = dyn_cast<SCEVConstant>(MaxInferCount)) {
+      unsigned InferValue = InferCount->getValue()->getZExtValue();
+      MaxExitCountN =
+          MaxExitCountN == 0 ? InferValue : std::min(MaxExitCountN, InferValue);
+    }
+  }
+  return MaxExitCountN;
 }
 
 unsigned ScalarEvolution::getSmallConstantTripMultiple(const Loop *L) {
@@ -8190,6 +8203,167 @@ ScalarEvolution::getSmallConstantTripMultiple(const Loop *L,
   return getSmallConstantTripMultiple(L, ExitCount);
 }
 
+/// Collect all load/store instructions that must be executed in every
+/// iteration of loop \p L.
+static void
+collectExecLoadStoreInsideLoop(const Loop *L, DominatorTree &DT,
+                               SmallVector<Instruction *, 4> &MemInsts) {
+  // It is difficult to tell if the load/store instruction is executed on every
+  // iteration inside an irregular loop.
+  if (!L->isLoopSimplifyForm() || !L->isInnermost())
+    return;
+
+  // FIXME: To keep the analyzed cases typical, we only handle loops that have
+  // one exiting block and require that block to be the latch. This makes it
+  // easier to capture memory accesses that execute on every iteration.
+  const BasicBlock *LoopLatch = L->getLoopLatch();
+  assert(LoopLatch && "normal form loop doesn't have a latch");
+  if (L->getExitingBlock() != LoopLatch)
+    return;
+
+  // We will not continue if a sanitizer is enabled.
+  const Function *F = LoopLatch->getParent();
+  if (F->hasFnAttribute(Attribute::SanitizeAddress) ||
+      F->hasFnAttribute(Attribute::SanitizeThread) ||
+      F->hasFnAttribute(Attribute::SanitizeMemory) ||
+      F->hasFnAttribute(Attribute::SanitizeHWAddress) ||
+      F->hasFnAttribute(Attribute::SanitizeMemTag))
+    return;
+
+  for (auto *BB : L->getBlocks()) {
+    // We need to make sure that MemAccessBB's max execution time in the loop
+    // matches the latch's max execution time. The BB below should be skipped:
+    //            Entry
+    //              │
+    //        ┌─────▼─────┐
+    //        │Loop Header◄─────┐
+    //        └──┬──────┬─┘     │
+    //           │      │       │
+    //  ┌────────▼──┐ ┌─▼─────┐ │
+    //  │MemAccessBB│ │OtherBB│ │
+    //  └────────┬──┘ └─┬─────┘ │
+    //           │      │       │
+    //         ┌─▼──────▼─┐     │
+    //         │Loop Latch├─────┘
+    //         └────┬─────┘
+    //              ▼
+    //             Exit
+    if (!DT.dominates(BB, LoopLatch))
+      continue;
+
+    for (Instruction &I : *BB) {
+      if (isa<LoadInst>(&I) || isa<StoreInst>(&I))
+        MemInsts.push_back(&I);
+    }
+  }
+}
+
+/// Return a SCEV representing the memory size of pointer \p V.
+static const SCEV *getCertainSizeOfMem(const SCEV *V, Type *RTy,
+                                       const DataLayout &DL,
+                                       const TargetLibraryInfo &TLI,
+                                       ScalarEvolution *SE) {
+  const SCEVUnknown *PtrBase = dyn_cast<SCEVUnknown>(V);
+  if (!PtrBase)
+    return nullptr;
+  Value *Ptr = PtrBase->getValue();
+  uint64_t Size = 0;
+  if (!llvm::getObjectSize(Ptr, Size, DL, &TLI))
+    return nullptr;
+  return SE->getConstant(RTy, Size);
+}
+
+/// Get the range of the given index represented by \p AddRec.
+static const SCEV *getIndexRange(const SCEVAddRecExpr *AddRec,
+                                 ScalarEvolution *SE) {
+  const SCEV *Range = SE->getConstant(SE->getUnsignedRangeMax(AddRec) -
+                                      SE->getUnsignedRangeMin(AddRec));
+  const SCEV *Step = AddRec->getStepRecurrence(*SE);
+  return SE->getUDivCeilSCEV(Range, Step);
+}
+
+/// Check whether the index can wrap and, if it can, whether we can still infer
+/// a max trip count given the max trip count inferred from memory access.
+static const SCEV *checkIndexWrap(Value *Ptr, ScalarEvolution *SE,
+                                  const SCEVConstant *MaxExecCount) {
+  SmallVector<const SCEV *> InferCountColl;
+  auto *PtrGEP = dyn_cast<GetElementPtrInst>(Ptr);
+  if (!PtrGEP)
+    return SE->getCouldNotCompute();
+  for (Value *Index : PtrGEP->indices()) {
+    Value *V = Index;
+    if (isa<ZExtInst>(V) || isa<SExtInst>(V))
+      V = cast<Instruction>(Index)->getOperand(0);
+    auto *SCEV = SE->getSCEV(V);
+    if (isa<SCEVCouldNotCompute>(SCEV))
+      return SE->getCouldNotCompute();
+    auto *AddRec = dyn_cast<SCEVAddRecExpr>(SCEV);
+    if (!AddRec)
+      continue;
+    auto *IndexRange = getIndexRange(AddRec, SE);
+    if (AddRec->hasNoSelfWrap()) {
+      InferCountColl.push_back(
+          SE->getUMinFromMismatchedTypes(IndexRange, MaxExecCount));
+    } else {
+      auto *IndexRangeC = dyn_cast<SCEVConstant>(IndexRange);
+      if (!IndexRangeC)
+        continue;
+      if (MaxExecCount->getValue()->getZExtValue() >
+          IndexRangeC->getValue()->getZExtValue())
+        InferCountColl.push_back(IndexRange);
+      else
+        InferCountColl.push_back(MaxExecCount);
+    }
+  }
+
+  if (InferCountColl.empty())
+    return SE->getCouldNotCompute();
+
+  return SE->getUMinFromMismatchedTypes(InferCountColl);
+}
+
+const SCEV *
+ScalarEvolution::getConstantMaxTripCountFromMemAccess(const Loop *L) {
+  SmallVector<Instruction *, 4> MemInsts;
+  collectExecLoadStoreInsideLoop(L, DT, MemInsts);
+
+  SmallVector<const SCEV *> InferCountColl;
+  const DataLayout &DL = getDataLayout();
+
+  for (Instruction *I : MemInsts) {
+    Value *Ptr = getLoadStorePointerOperand(I);
+    assert(Ptr && "empty pointer operand");
+    auto *AddRec = dyn_cast<SCEVAddRecExpr>(getSCEV(Ptr));
+    if (!AddRec || !AddRec->isAffine())
+      continue;
+    const SCEV *PtrBase = getPointerBase(AddRec);
+    const SCEV *Step = AddRec->getStepRecurrence(*this);
+    const SCEV *MemSize =
+        getCertainSizeOfMem(PtrBase, Step->getType(), DL, TLI, this);
+    if (!MemSize)
+      continue;
+    // Now we can infer a max execution time by MemLength/StepLength.
+    auto *MaxExecCount = dyn_cast<SCEVConstant>(getUDivCeilSCEV(MemSize, Step));
+    if (!MaxExecCount || MaxExecCount->getAPInt().getActiveBits() > 32)
+      continue;
+    // Now we check the wrap. We can still explore the max trip count in the
+    // following two cases:
+    // 1. If the index can potentially wrap but the max trip count inferred
+    //    from memory access is within the range of the index.
+    // 2. If the index can't wrap, then the max trip count is:
+    //    min(range of index, max value inferred from memory access).
+    auto *Res = checkIndexWrap(Ptr, this, MaxExecCount);
+    if (isa<SCEVCouldNotCompute>(Res))
+      continue;
+    InferCountColl.push_back(Res);
+  }
+
+  if (InferCountColl.empty())
+    return getCouldNotCompute();
+
+  return getUMinFromMismatchedTypes(InferCountColl);
+}
+
 const SCEV *ScalarEvolution::getExitCount(const Loop *L,
                                           const BasicBlock *ExitingBlock,
                                           ExitCountKind Kind) {
@@ -13477,6 +13651,17 @@ static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE,
     OS << ": ";
     OS << "Trip multiple is " << SE->getSmallConstantTripMultiple(L) << "\n";
   }
+
+  if (UseMemoryAccessUBForBEInference) {
+    unsigned SmallMaxTrip = SE->getSmallConstantMaxTripCount(L);
+    OS << "Loop ";
+    L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
+    OS << ": ";
+    if (SmallMaxTrip)
+      OS << "Small constant max trip is " << SmallMaxTrip << "\n";
+    else
+      OS << "Small constant max trip couldn't be computed.\n";
+  }
 }
 
 namespace llvm {
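As a worked check of the arithmetic in getConstantMaxTripCountFromMemAccess, using the numbers from the first test below (a standalone illustration, not part of the patch): getObjectSize of the [256 x i32] alloca is 1024 bytes and the store's pointer AddRec steps by 4 bytes, so the udiv-ceil gives 256; the i8 GEP index has unsigned range [0, 255] with step 1, so its range-based bound is 255; the reported small constant max trip is the minimum of the two.

#include <algorithm>
#include <cstdint>
#include <iostream>

int main() {
  uint64_t MemSize = 256 * 4;  // object size of the [256 x i32] alloca, bytes
  uint64_t Step = 4;           // byte step of the store's pointer AddRec
  uint64_t MaxExecCount = (MemSize + Step - 1) / Step;  // udiv-ceil -> 256
  // The i8 index has unsigned range [0, 255]; udiv-ceil(255 - 0, 1) = 255.
  uint64_t IndexRange = 255;
  std::cout << std::min(MaxExecCount, IndexRange) << "\n";  // prints 255
  return 0;
}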
Lines changed: 110 additions & 0 deletions
@@ -0,0 +1,110 @@
+; NOTE: Assertions have been autogenerated by utils/update_analyze_test_checks.py
+; RUN: opt < %s -disable-output "-passes=print<scalar-evolution>" -scalar-evolution-classify-expressions=0 -scalar-evolution-infer-max-trip-count-from-memory-access 2>&1 | FileCheck %s
+
+define void @ComputeMaxTripCountFromArrayIdxWrap(i32 signext %len) {
+; CHECK-LABEL: 'ComputeMaxTripCountFromArrayIdxWrap'
+; CHECK-NEXT: Determining loop execution counts for: @ComputeMaxTripCountFromArrayIdxWrap
+; CHECK-NEXT: Loop %for.body: backedge-taken count is (-1 + %len)
+; CHECK-NEXT: Loop %for.body: constant max backedge-taken count is 2147483646
+; CHECK-NEXT: Loop %for.body: symbolic max backedge-taken count is (-1 + %len)
+; CHECK-NEXT: Loop %for.body: Predicated backedge-taken count is (-1 + %len)
+; CHECK-NEXT:  Predicates:
+; CHECK-NEXT: Loop %for.body: Trip multiple is 1
+; CHECK-NEXT: Loop %for.body: Small constant max trip is 255
+;
+entry:
+  %a = alloca [256 x i32], align 4
+  %cmp4 = icmp sgt i32 %len, 0
+  br i1 %cmp4, label %for.body.preheader, label %for.cond.cleanup
+
+for.body.preheader:
+  br label %for.body
+
+for.cond.cleanup.loopexit:
+  br label %for.cond.cleanup
+
+for.cond.cleanup:
+  ret void
+
+for.body:
+  %iv = phi i8 [ %inc, %for.body ], [ 0, %for.body.preheader ]
+  %idxprom = zext i8 %iv to i64
+  %arrayidx = getelementptr inbounds [256 x i32], [256 x i32]* %a, i64 0, i64 %idxprom
+  store i32 0, i32* %arrayidx, align 4
+  %inc = add nuw i8 %iv, 1
+  %inc_zext = zext i8 %inc to i32
+  %cmp = icmp slt i32 %inc_zext, %len
+  br i1 %cmp, label %for.body, label %for.cond.cleanup.loopexit
+}
+
+define void @ComputeMaxTripCountFromArrayIdxWrap2(i32 signext %len) {
+; CHECK-LABEL: 'ComputeMaxTripCountFromArrayIdxWrap2'
+; CHECK-NEXT: Determining loop execution counts for: @ComputeMaxTripCountFromArrayIdxWrap2
+; CHECK-NEXT: Loop %for.body: backedge-taken count is (-1 + %len)
+; CHECK-NEXT: Loop %for.body: constant max backedge-taken count is 2147483646
+; CHECK-NEXT: Loop %for.body: symbolic max backedge-taken count is (-1 + %len)
+; CHECK-NEXT: Loop %for.body: Predicated backedge-taken count is (-1 + %len)
+; CHECK-NEXT:  Predicates:
+; CHECK-NEXT: Loop %for.body: Trip multiple is 1
+; CHECK-NEXT: Loop %for.body: Small constant max trip is 127
+;
+entry:
+  %a = alloca [127 x i32], align 4
+  %cmp4 = icmp sgt i32 %len, 0
+  br i1 %cmp4, label %for.body.preheader, label %for.cond.cleanup
+
+for.body.preheader:
+  br label %for.body
+
+for.cond.cleanup.loopexit:
+  br label %for.cond.cleanup
+
+for.cond.cleanup:
+  ret void
+
+for.body:
+  %iv = phi i8 [ %inc, %for.body ], [ 0, %for.body.preheader ]
+  %idxprom = zext i8 %iv to i64
+  %arrayidx = getelementptr inbounds [127 x i32], [127 x i32]* %a, i64 0, i64 %idxprom
+  store i32 0, i32* %arrayidx, align 4
+  %inc = add nuw i8 %iv, 1
+  %inc_zext = zext i8 %inc to i32
+  %cmp = icmp slt i32 %inc_zext, %len
+  br i1 %cmp, label %for.body, label %for.cond.cleanup.loopexit
+}
+
+define void @ComputeMaxTripCountFromArrayIdxWrap3(i32 signext %len) {
+; CHECK-LABEL: 'ComputeMaxTripCountFromArrayIdxWrap3'
+; CHECK-NEXT: Determining loop execution counts for: @ComputeMaxTripCountFromArrayIdxWrap3
+; CHECK-NEXT: Loop %for.body: backedge-taken count is (-1 + %len)
+; CHECK-NEXT: Loop %for.body: constant max backedge-taken count is 2147483646
+; CHECK-NEXT: Loop %for.body: symbolic max backedge-taken count is (-1 + %len)
+; CHECK-NEXT: Loop %for.body: Predicated backedge-taken count is (-1 + %len)
+; CHECK-NEXT:  Predicates:
+; CHECK-NEXT: Loop %for.body: Trip multiple is 1
+; CHECK-NEXT: Loop %for.body: Small constant max trip is 20
+;
+entry:
+  %a = alloca [20 x i32], align 4
+  %cmp4 = icmp sgt i32 %len, 0
+  br i1 %cmp4, label %for.body.preheader, label %for.cond.cleanup
+
+for.body.preheader:
+  br label %for.body
+
+for.cond.cleanup.loopexit:
+  br label %for.cond.cleanup
+
+for.cond.cleanup:
+  ret void
+
+for.body:
+  %iv = phi i8 [ %inc, %for.body ], [ 0, %for.body.preheader ]
+  %idxprom = zext i8 %iv to i64
+  %arrayidx = getelementptr inbounds [20 x i32], [20 x i32]* %a, i64 0, i64 %idxprom
+  store i32 0, i32* %arrayidx, align 4
+  %inc = add nuw nsw i8 %iv, 1
+  %inc_zext = zext i8 %inc to i32
+  %cmp = icmp slt i32 %inc_zext, %len
+  br i1 %cmp, label %for.body, label %for.cond.cleanup.loopexit
+}
