Commit b8ef25a

[PGO] Fix zeroed estimated trip count (#167792)
Before PR #152775, `llvm::getLoopEstimatedTripCount` never returned 0. If `llvm::setLoopEstimatedTripCount` were called with 0, it would zero branch weights, causing `llvm::getLoopEstimatedTripCount` to return `std::nullopt`.

PR #152775 changed that behavior: if `llvm::setLoopEstimatedTripCount` is called with 0, it sets `llvm.loop.estimated_trip_count` to 0, causing `llvm::getLoopEstimatedTripCount` to return 0. However, it kept documentation saying `llvm::getLoopEstimatedTripCount` returns a positive count.

Some passes continue to assume `llvm::getLoopEstimatedTripCount` never returns 0 and crash if it does, as reported in issue #164254. To restore the behavior they expect, this patch changes `llvm::getLoopEstimatedTripCount` to return `std::nullopt` when `llvm.loop.estimated_trip_count` is 0.
1 parent 02c9e89 commit b8ef25a
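
To make the contract in the commit message concrete, here is a minimal standalone sketch of it. It does not use LLVM headers: `StoredTripCount`, `estimatedTripCount`, and `costPerIteration` are hypothetical names invented for illustration, and the branch-weight fallback of the real function is deliberately left out. The caller is only an assumed example of code that relies on a positive count; the commit does not say how the affected passes actually fail.

```cpp
#include <cassert>
#include <optional>

// Hypothetical stand-in for the metadata lookup: the stored value of
// llvm.loop.estimated_trip_count, or std::nullopt if the metadata is absent.
using StoredTripCount = std::optional<unsigned>;

// Models the patched contract: a stored 0 is reported as std::nullopt, so
// callers observe either a positive count or no estimate at all.
// Simplification: the real function falls back to latch branch weights when
// the metadata is absent; this sketch just returns std::nullopt.
std::optional<unsigned> estimatedTripCount(StoredTripCount Stored) {
  if (!Stored || *Stored == 0)
    return std::nullopt;
  return *Stored;
}

// Hypothetical caller that divides by the trip count and therefore depends on
// never seeing zero.
std::optional<unsigned> costPerIteration(unsigned TotalCost,
                                         StoredTripCount Stored) {
  std::optional<unsigned> TC = estimatedTripCount(Stored);
  if (!TC)
    return std::nullopt; // No usable estimate: skip the heuristic.
  assert(*TC > 0 && "callers assume a positive trip count");
  return TotalCost / *TC;
}
```

With this shape, a caller written against the pre-#152775 contract needs no extra zero check: a stored 0 takes the same path as missing metadata.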

6 files changed: +169 -8 lines changed
llvm/docs/LangRef.rst

Lines changed: 15 additions & 0 deletions
@@ -8063,6 +8063,21 @@ pass should record the new estimates by calling
 loop, ``llvm::getLoopEstimatedTripCount`` returns its value instead of
 estimating the trip count from the loop's ``branch_weights`` metadata.
 
+Zero
+""""
+
+Some passes set ``llvm.loop.estimated_trip_count`` to 0. For example, after
+peeling 10 or more iterations from a loop with an estimated trip count of 10,
+``llvm.loop.estimated_trip_count`` becomes 0 on the remaining loop. It
+indicates that, each time execution reaches the peeled iterations, execution is
+estimated to exit them without reaching the remaining loop's header.
+
+Even if the probability of reaching a loop's header is low, if it is reached, it
+is the start of an iteration. Consequently, some passes historically assume
+that ``llvm::getLoopEstimatedTripCount`` always returns a positive count or
+``std::nullopt``. Thus, it returns ``std::nullopt`` when
+``llvm.loop.estimated_trip_count`` is 0.
+
 '``llvm.licm.disable``' Metadata
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
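
The peeling example in the added LangRef text can be illustrated with a small standalone sketch. The saturating subtraction below is an assumption chosen to match the documented outcome (peeling 10 or more iterations from an estimate of 10 leaves 0); it is not taken from the peeling pass itself, and both function names are hypothetical.

```cpp
#include <optional>

// Hypothetical helper: the estimate left on the remaining loop after peeling
// `Peeled` iterations. Assumes a saturating subtraction, so peeling at least
// as many iterations as the estimate leaves 0 rather than wrapping around.
unsigned remainingTripCountAfterPeel(unsigned Estimated, unsigned Peeled) {
  return Peeled >= Estimated ? 0u : Estimated - Peeled;
}

// Hypothetical model of what a query on the remaining loop then reports:
// a stored 0 becomes std::nullopt, any positive value is returned as-is.
std::optional<unsigned> reportedTripCount(unsigned Stored) {
  if (Stored == 0)
    return std::nullopt;
  return Stored;
}
```

For instance, `reportedTripCount(remainingTripCountAfterPeel(10, 12))` yields `std::nullopt`, while peeling 3 iterations from the same loop yields 7.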

llvm/include/llvm/Transforms/Utils/LoopUtils.h

Lines changed: 9 additions & 5 deletions
@@ -326,7 +326,10 @@ LLVM_ABI void addStringMetadataToLoop(Loop *TheLoop, const char *MDString,
 /// - \c std::nullopt, if the implementation is unable to handle the loop form
 /// of \p L (e.g., \p L must have a latch block that controls the loop exit).
 /// - The value of \c llvm.loop.estimated_trip_count from the loop metadata of
-/// \p L, if that metadata is present.
+/// \p L, if that metadata is present. In the special case that the value is
+/// zero, return \c std::nullopt instead as that is historically what callers
+/// expect when a loop is estimated to execute no iterations (i.e., its header
+/// is not reached).
 /// - Else, a new estimate of the trip count from the latch branch weights of
 /// \p L.
 ///
@@ -353,10 +356,11 @@ getLoopEstimatedTripCount(Loop *L,
 /// to handle the loop form of \p L (e.g., \p L must have a latch block that
 /// controls the loop exit). Otherwise, return true.
 ///
-/// In addition, if \p EstimatedLoopInvocationWeight, set the branch weight
-/// metadata of \p L to reflect that \p L has an estimated
-/// \p EstimatedTripCount iterations and has \c *EstimatedLoopInvocationWeight
-/// exit weight through the loop's latch.
+/// In addition, if \p EstimatedLoopInvocationWeight:
+/// - Set the branch weight metadata of \p L to reflect that \p L has an
+/// estimated \p EstimatedTripCount iterations and has
+/// \c *EstimatedLoopInvocationWeight exit weight through the loop's latch.
+/// - If \p EstimatedTripCount is zero, zero the branch weights.
 ///
 /// TODO: Eventually, once all passes have migrated away from setting branch
 /// weights to indicate estimated trip counts, this function will drop the
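
The branch-weight rule documented for `setLoopEstimatedTripCount` can be sketched in isolation. Only the zero case (both weights become 0) comes from the header comment above. The formula for a positive trip count is an assumption for illustration: a loop entered W times that runs N iterations each time takes the backedge (N - 1) * W times and exits W times. `LatchWeights` and `latchWeightsFor` are hypothetical names.

```cpp
#include <cstdint>

// Hypothetical pair of latch branch weights: how often the backedge is taken
// versus how often the loop exits through the latch.
struct LatchWeights {
  uint32_t Backedge;
  uint32_t Exit;
};

// Sketch of the documented rule. A zero trip count zeroes both weights.
// Otherwise (assumed formula) the exit weight is the invocation weight and
// the backedge weight is (N - 1) times that. Overflow is ignored here.
LatchWeights latchWeightsFor(unsigned EstimatedTripCount,
                             uint32_t InvocationWeight) {
  if (EstimatedTripCount == 0)
    return {0, 0};
  return {(EstimatedTripCount - 1) * InvocationWeight, InvocationWeight};
}
```

Under these assumptions, `latchWeightsFor(0, 1)` gives `{0, 0}`, which is the state the new unit test checks for after calling `setLoopEstimatedTripCount(L, 0, /*EstimatedLoopInvocationWeight=*/1)`.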

llvm/lib/Transforms/Utils/LoopUtils.cpp

Lines changed: 17 additions & 2 deletions
@@ -913,11 +913,26 @@ llvm::getLoopEstimatedTripCount(Loop *L,
 
   // Return the estimated trip count from metadata unless the metadata is
   // missing or has no value.
+  //
+  // Some passes set llvm.loop.estimated_trip_count to 0. For example, after
+  // peeling 10 or more iterations from a loop with an estimated trip count of
+  // 10, llvm.loop.estimated_trip_count becomes 0 on the remaining loop. It
+  // indicates that, each time execution reaches the peeled iterations,
+  // execution is estimated to exit them without reaching the remaining loop's
+  // header.
+  //
+  // Even if the probability of reaching a loop's header is low, if it is
+  // reached, it is the start of an iteration. Consequently, some passes
+  // historically assume that llvm::getLoopEstimatedTripCount always returns a
+  // positive count or std::nullopt. Thus, return std::nullopt when
+  // llvm.loop.estimated_trip_count is 0.
   if (auto TC = getOptionalIntLoopAttribute(L, LLVMLoopEstimatedTripCount)) {
     LLVM_DEBUG(dbgs() << "getLoopEstimatedTripCount: "
                       << LLVMLoopEstimatedTripCount << " metadata has trip "
-                      << "count of " << *TC << " for " << DbgLoop(L) << "\n");
-    return TC;
+                      << "count of " << *TC
+                      << (*TC == 0 ? " (returning std::nullopt)" : "")
+                      << " for " << DbgLoop(L) << "\n");
+    return *TC == 0 ? std::nullopt : std::optional(*TC);
   }
 
   // Estimate the trip count from latch branch weights.
Lines changed: 34 additions & 0 deletions
@@ -0,0 +1,34 @@
+; Check that an estimated trip count of zero does not crash or otherwise break
+; LoopVectorize behavior while it tries to create runtime memory checks inside
+; an outer loop.
+
+; RUN: opt -passes=loop-vectorize -S %s | FileCheck %s
+
+target triple = "x86_64-unknown-linux-gnu"
+
+; Look for basic signs that vectorization ran and produced memory checks.
+; CHECK: @test(
+; CHECK: vector.memcheck:
+; CHECK: vector.body:
+; CHECK: inner:
+
+define void @test(ptr addrspace(1) %p, i32 %n) {
+entry:
+  br label %outer
+outer:
+  br label %inner
+inner:
+  %i = phi i32 [ %inc, %inner ], [ 0, %outer ]
+  store i32 0, ptr addrspace(1) %p
+  %load = load i32, ptr addrspace(1) null
+  %inc = add i32 %i, 1
+  %cmp = icmp slt i32 %i, %n
+  br i1 %cmp, label %inner, label %outer.latch
+outer.latch:
+  br i1 %cmp, label %outer, label %exit, !llvm.loop !0
+exit:
+  ret void
+}
+
+!0 = distinct !{!0, !1}
+!1 = !{!"llvm.loop.estimated_trip_count", i32 0}

llvm/test/Verifier/llvm.loop.estimated_trip_count.ll

Lines changed: 13 additions & 1 deletion
@@ -36,12 +36,24 @@ exit:
 ; RUN: echo '!1 = !{!"llvm.loop.estimated_trip_count", i16 5}' >> %t
 ; RUN: %{RUN} GOOD
 
-; i32 value.
+; i32 arbitrary value.
 ; RUN: cp %s %t
 ; RUN: chmod u+w %t
 ; RUN: echo '!1 = !{!"llvm.loop.estimated_trip_count", i32 5}' >> %t
 ; RUN: %{RUN} GOOD
 
+; i32 boundary value of 1.
+; RUN: cp %s %t
+; RUN: chmod u+w %t
+; RUN: echo '!1 = !{!"llvm.loop.estimated_trip_count", i32 1}' >> %t
+; RUN: %{RUN} GOOD
+
+; i32 boundary value of 0.
+; RUN: cp %s %t
+; RUN: chmod u+w %t
+; RUN: echo '!1 = !{!"llvm.loop.estimated_trip_count", i32 0}' >> %t
+; RUN: %{RUN} GOOD
+
 ; i64 value.
 ; RUN: cp %s %t
 ; RUN: chmod u+w %t

llvm/unittests/Transforms/Utils/LoopUtilsTest.cpp

Lines changed: 81 additions & 0 deletions
@@ -14,6 +14,7 @@
 #include "llvm/AsmParser/Parser.h"
 #include "llvm/IR/Dominators.h"
 #include "llvm/IR/Module.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/Support/SourceMgr.h"
 #include "gtest/gtest.h"
 
@@ -195,3 +196,83 @@ TEST(LoopUtils, nestedLoopSharedLatchEstimatedTripCount) {
         EXPECT_EQ(getLoopEstimatedTripCount(Outer), std::nullopt);
       });
 }
+
+// {get,set}LoopEstimatedTripCount implement special handling of zero.
+TEST(LoopUtils, zeroEstimatedTripCount) {
+  LLVMContext C;
+  const char *IR =
+      "define void @foo(i1 %c) {\n"
+      "entry:\n"
+      " br label %loop0\n"
+      "loop0:\n"
+      " br i1 %c, label %loop0, label %loop1\n"
+      "loop1:\n"
+      " br i1 %c, label %loop1, label %loop2, !llvm.loop !1\n"
+      "loop2:\n"
+      " br i1 %c, label %loop2, label %exit, !prof !5, !llvm.loop !2\n"
+      "exit:\n"
+      " ret void\n"
+      "}\n"
+      "!1 = distinct !{!1, !3}\n"
+      "!2 = distinct !{!2, !3, !4}\n"
+      "!3 = !{!\"foo\", i32 5}\n"
+      "!4 = !{!\"llvm.loop.estimated_trip_count\", i32 10}\n"
+      "!5 = !{!\"branch_weights\", i32 1, i32 9}\n"
+      "\n";
+
+  // With EstimatedLoopInvocationWeight, setLoopEstimatedTripCount sets branch
+  // weights and llvm.loop.estimated_trip_count all to 0, so
+  // getLoopEstimatedTripCount returns std::nullopt. It does not touch other
+  // loop metadata, if any.
+  std::unique_ptr<Module> M = parseIR(C, IR);
+  run(*M, "foo",
+      [&](Function &F, DominatorTree &DT, ScalarEvolution &SE, LoopInfo &LI) {
+        assert(LI.end() - LI.begin() == 3 && "Expected three loops");
+        for (Loop *L : LI) {
+          Instruction &LatchBranch = *L->getLoopLatch()->getTerminator();
+          std::optional<int> Foo = getOptionalIntLoopAttribute(L, "foo");
+
+          EXPECT_EQ(setLoopEstimatedTripCount(
+                        L, 0, /*EstimatedLoopInvocationWeight=*/1),
+                    true);
+
+          SmallVector<uint32_t, 2> Weights;
+          EXPECT_EQ(extractBranchWeights(LatchBranch, Weights), true);
+          EXPECT_EQ(Weights[0], 0u);
+          EXPECT_EQ(Weights[1], 0u);
+          EXPECT_EQ(getOptionalIntLoopAttribute(L, "foo"), Foo);
+          EXPECT_EQ(getOptionalIntLoopAttribute(L, LLVMLoopEstimatedTripCount),
+                    0);
+          EXPECT_EQ(getLoopEstimatedTripCount(L), std::nullopt);
+        }
+      });
+
+  // Without EstimatedLoopInvocationWeight, setLoopEstimatedTripCount sets
+  // llvm.loop.estimated_trip_count to 0, so getLoopEstimatedTripCount returns
+  // std::nullopt. It does not touch branch weights or other loop metadata, if
+  // any.
+  M = parseIR(C, IR);
+  run(*M, "foo",
+      [&](Function &F, DominatorTree &DT, ScalarEvolution &SE, LoopInfo &LI) {
+        assert(LI.end() - LI.begin() == 3 && "Expected three loops");
+        for (Loop *L : LI) {
+          Instruction &LatchBranch = *L->getLoopLatch()->getTerminator();
+          std::optional<int> Foo = getOptionalIntLoopAttribute(L, "foo");
+          SmallVector<uint32_t, 2> WeightsOld;
+          bool HasWeights = extractBranchWeights(LatchBranch, WeightsOld);
+
+          EXPECT_EQ(setLoopEstimatedTripCount(L, 0), true);
+
+          SmallVector<uint32_t, 2> WeightsNew;
+          EXPECT_EQ(extractBranchWeights(LatchBranch, WeightsNew), HasWeights);
+          if (HasWeights) {
+            EXPECT_EQ(WeightsNew[0], WeightsOld[0]);
+            EXPECT_EQ(WeightsNew[1], WeightsOld[1]);
+          }
+          EXPECT_EQ(getOptionalIntLoopAttribute(L, "foo"), Foo);
+          EXPECT_EQ(getOptionalIntLoopAttribute(L, LLVMLoopEstimatedTripCount),
+                    0);
+          EXPECT_EQ(getLoopEstimatedTripCount(L), std::nullopt);
+        }
+      });
+}
