Commit 0e3c556

[PGO] Add llvm.loop.estimated_trip_count metadata (#152775)
This patch implements the `llvm.loop.estimated_trip_count` metadata discussed in [[RFC] Fix Loop Transformations to Preserve Block Frequencies](https://discourse.llvm.org/t/rfc-fix-loop-transformations-to-preserve-block-frequencies/85785). As the RFC explains, that metadata enables future patches, such as PR #128785, to fix block frequency issues without losing estimated trip counts.
1 parent e08588d commit 0e3c556

12 files changed

+487
-178
lines changed


llvm/docs/LangRef.rst

Lines changed: 48 additions & 0 deletions
@@ -7840,6 +7840,54 @@ If a loop was successfully processed by the loop distribution pass,
78407840
this metadata is added (i.e., has been distributed). See
78417841
:ref:`Transformation Metadata <transformation-metadata>` for details.
78427842

7843+
'``llvm.loop.estimated_trip_count``' Metadata
7844+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
7845+
7846+
This metadata records an estimated trip count for the loop. The first operand
7847+
is the string ``llvm.loop.estimated_trip_count``. The second operand is an
7848+
integer constant of type ``i32`` or smaller specifying the estimate. For
7849+
example:
7850+
7851+
.. code-block:: llvm
7852+
7853+
!0 = !{!"llvm.loop.estimated_trip_count", i32 8}
7854+
7855+
Purpose
7856+
"""""""
7857+
7858+
A loop's estimated trip count is an estimate of the average number of loop
7859+
iterations (specifically, the number of times the loop's header executes) each
7860+
time execution reaches the loop. It is usually only an estimate based on, for
7861+
example, profile data. The actual number of iterations might vary widely.
7862+
7863+
The estimated trip count serves as a parameter for various loop transformations
7864+
and typically helps estimate transformation cost. For example, it can help
7865+
determine how many iterations to peel or how aggressively to unroll.
7866+
7867+
Initialization and Maintenance
7868+
""""""""""""""""""""""""""""""
7869+
7870+
Passes should always interact with estimated trip counts via
7871+
``llvm::getLoopEstimatedTripCount`` and ``llvm::setLoopEstimatedTripCount``.
7872+
7873+
When the ``llvm.loop.estimated_trip_count`` metadata is not present on a loop,
7874+
``llvm::getLoopEstimatedTripCount`` estimates the loop's trip count from the
7875+
loop's ``branch_weights`` metadata under the assumption that the latter still
7876+
accurately encodes the program's original profile data. However, as passes
7877+
transform existing loops and create new loops, they must be free to update and
7878+
create ``branch_weights`` metadata in a way that maintains accurate block
7879+
frequencies. Trip counts estimated from this new ``branch_weights`` metadata
7880+
are not necessarily useful to the passes that consume estimated trip counts.
7881+
7882+
For this reason, when a pass transforms or creates loops, the pass should
7883+
separately estimate new trip counts based on the estimated trip counts that
7884+
``llvm::getLoopEstimatedTripCount`` returns at the start of the pass, and the
7885+
pass should record the new estimates by calling
7886+
``llvm::setLoopEstimatedTripCount``, which creates or updates
7887+
``llvm.loop.estimated_trip_count`` metadata. Once this metadata is present on a
7888+
loop, ``llvm::getLoopEstimatedTripCount`` returns its value instead of
7889+
estimating the trip count from the loop's ``branch_weights`` metadata.
7890+
78437891
'``llvm.licm.disable``' Metadata
78447892
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
78457893
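The maintenance rule described in the LangRef text above can be sketched with a small standalone model. This is only an illustration under stated assumptions: `LoopModel`, the two free functions, and `hypotheticalHalvingPass` are invented names that mimic the documented behavior of `llvm::getLoopEstimatedTripCount` and `llvm::setLoopEstimatedTripCount`. They are not LLVM APIs. Halving the estimate is an arbitrary choice for the example, the loop-form checks are left out, and the fallback estimate omits rounding and saturation.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <optional>

// Hypothetical, simplified model of the profile state attached to one loop.
struct LoopModel {
  // Value of llvm.loop.estimated_trip_count, if that metadata is present.
  std::optional<unsigned> EstimatedTripCountMD;
  // Latch branch_weights: edge that stays in the loop / edge that leaves it.
  uint64_t BackedgeWeight = 0;
  uint64_t ExitWeight = 0;
};

// Models the documented lookup order: the metadata wins when present,
// otherwise the estimate falls back to the latch branch weights.
// Rounding and saturation are deliberately omitted in this sketch.
std::optional<unsigned> getEstimatedTripCount(const LoopModel &L) {
  if (L.EstimatedTripCountMD)
    return L.EstimatedTripCountMD;
  if (L.ExitWeight == 0)
    return std::nullopt;
  return static_cast<unsigned>(L.BackedgeWeight / L.ExitWeight + 1);
}

// Models recording a new estimate: creates or updates the metadata value.
void setEstimatedTripCount(LoopModel &L, unsigned TripCount) {
  L.EstimatedTripCountMD = TripCount;
}

// Hypothetical pass: reads the estimate at the start, rewrites the branch
// weights for its own purposes, then records a separately derived estimate.
void hypotheticalHalvingPass(LoopModel &L) {
  std::optional<unsigned> Old = getEstimatedTripCount(L);
  // The pass is free to change branch weights (values here are arbitrary).
  L.BackedgeWeight = 1;
  L.ExitWeight = 1;
  // Keep the estimate positive, as the documentation requires.
  if (Old)
    setEstimatedTripCount(L, std::max(1u, *Old / 2));
}
```

In this model, a loop with weights 7 and 1 starts with an estimate of 8. After the pass runs, the query returns the recorded 4 even though the rewritten weights alone would suggest 2, which is the decoupling the metadata is meant to provide.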

llvm/include/llvm/IR/Metadata.h

Lines changed: 2 additions & 2 deletions
@@ -919,8 +919,8 @@ class MDOperand {
919919

920920
// Check if MDOperand is of type MDString and equals `Str`.
921921
bool equalsStr(StringRef Str) const {
922-
return isa<MDString>(this->get()) &&
923-
cast<MDString>(this->get())->getString() == Str;
922+
return isa_and_nonnull<MDString>(get()) &&
923+
cast<MDString>(get())->getString() == Str;
924924
}
925925

926926
~MDOperand() { untrack(); }

llvm/include/llvm/IR/ProfDataUtils.h

Lines changed: 4 additions & 0 deletions
@@ -30,6 +30,10 @@ struct MDProfLabels {
3030
LLVM_ABI static const char *UnknownBranchWeightsMarker;
3131
};
3232

33+
/// Profile-based loop metadata that should be accessed only by using
34+
/// \c llvm::getLoopEstimatedTripCount and \c llvm::setLoopEstimatedTripCount.
35+
LLVM_ABI extern const char *LLVMLoopEstimatedTripCount;
36+
3337
/// Checks if an Instruction has MD_prof Metadata
3438
LLVM_ABI bool hasProfMD(const Instruction &I);
3539

llvm/include/llvm/Transforms/Utils/LoopUtils.h

Lines changed: 38 additions & 12 deletions
@@ -323,22 +323,48 @@ LLVM_ABI TransformationMode hasLICMVersioningTransformation(const Loop *L);
323323
LLVM_ABI void addStringMetadataToLoop(Loop *TheLoop, const char *MDString,
324324
unsigned V = 0);
325325

326-
/// Returns a loop's estimated trip count based on branch weight metadata.
327-
/// In addition if \p EstimatedLoopInvocationWeight is not null it is
328-
/// initialized with weight of loop's latch leading to the exit.
329-
/// Returns a valid positive trip count, saturated at UINT_MAX, or std::nullopt
330-
/// when a meaningful estimate cannot be made.
326+
/// Return either:
327+
/// - \c std::nullopt, if the implementation is unable to handle the loop form
328+
/// of \p L (e.g., \p L must have a latch block that controls the loop exit).
329+
/// - The value of \c llvm.loop.estimated_trip_count from the loop metadata of
330+
/// \p L, if that metadata is present.
331+
/// - Else, a new estimate of the trip count from the latch branch weights of
332+
/// \p L.
333+
///
334+
/// An estimated trip count is always a valid positive trip count, saturated at
335+
/// \c UINT_MAX.
336+
///
337+
/// In addition, if \p EstimatedLoopInvocationWeight, then either:
338+
/// - Set \c *EstimatedLoopInvocationWeight to the weight of the latch's branch
339+
/// to the loop exit.
340+
/// - Do not set it, and return \c std::nullopt, if the current implementation
341+
/// cannot compute that weight (e.g., if \p L does not have a latch block that
342+
/// controls the loop exit) or the weight is zero (because zero cannot be
343+
/// used to compute new branch weights that reflect the estimated trip count).
344+
///
345+
/// TODO: Eventually, once all passes have migrated away from setting branch
346+
/// weights to indicate estimated trip counts, this function will drop the
347+
/// \p EstimatedLoopInvocationWeight parameter.
331348
LLVM_ABI std::optional<unsigned>
332349
getLoopEstimatedTripCount(Loop *L,
333350
unsigned *EstimatedLoopInvocationWeight = nullptr);
334351

335-
/// Set a loop's branch weight metadata to reflect that loop has \p
336-
/// EstimatedTripCount iterations and \p EstimatedLoopInvocationWeight exits
337-
/// through latch. Returns true if metadata is successfully updated, false
338-
/// otherwise. Note that loop must have a latch block which controls loop exit
339-
/// in order to succeed.
340-
LLVM_ABI bool setLoopEstimatedTripCount(Loop *L, unsigned EstimatedTripCount,
341-
unsigned EstimatedLoopInvocationWeight);
352+
/// Set \c llvm.loop.estimated_trip_count with the value \p EstimatedTripCount
353+
/// in the loop metadata of \p L. Return false if the implementation is unable
354+
/// to handle the loop form of \p L (e.g., \p L must have a latch block that
355+
/// controls the loop exit). Otherwise, return true.
356+
///
357+
/// In addition, if \p EstimatedLoopInvocationWeight, set the branch weight
358+
/// metadata of \p L to reflect that \p L has an estimated
359+
/// \p EstimatedTripCount iterations and has \c *EstimatedLoopInvocationWeight
360+
/// exit weight through the loop's latch.
361+
///
362+
/// TODO: Eventually, once all passes have migrated away from setting branch
363+
/// weights to indicate estimated trip counts, this function will drop the
364+
/// \p EstimatedLoopInvocationWeight parameter.
365+
LLVM_ABI bool setLoopEstimatedTripCount(
366+
Loop *L, unsigned EstimatedTripCount,
367+
std::optional<unsigned> EstimatedLoopInvocationWeight = std::nullopt);
342368

343369
/// Check inner loop (L) backedge count is known to be invariant on all
344370
/// iterations of its outer loop. If the loop has no parent, this is trivially

llvm/lib/IR/ProfDataUtils.cpp

Lines changed: 1 addition & 0 deletions
@@ -95,6 +95,7 @@ const char *MDProfLabels::FunctionEntryCount = "function_entry_count";
9595
const char *MDProfLabels::SyntheticFunctionEntryCount =
9696
"synthetic_function_entry_count";
9797
const char *MDProfLabels::UnknownBranchWeightsMarker = "unknown";
98+
const char *LLVMLoopEstimatedTripCount = "llvm.loop.estimated_trip_count";
9899

99100
bool hasProfMD(const Instruction &I) {
100101
return I.hasMetadata(LLVMContext::MD_prof);

llvm/lib/IR/Verifier.cpp

Lines changed: 12 additions & 0 deletions
@@ -1076,6 +1076,18 @@ void Verifier::visitMDNode(const MDNode &MD, AreDebugLocsAllowed AllowLocs) {
10761076
}
10771077
}
10781078

1079+
// Check llvm.loop.estimated_trip_count.
1080+
if (MD.getNumOperands() > 0 &&
1081+
MD.getOperand(0).equalsStr(LLVMLoopEstimatedTripCount)) {
1082+
Check(MD.getNumOperands() == 2, "Expected two operands", &MD);
1083+
auto *Count = dyn_cast_or_null<ConstantAsMetadata>(MD.getOperand(1));
1084+
Check(Count && Count->getType()->isIntegerTy() &&
1085+
cast<IntegerType>(Count->getType())->getBitWidth() <= 32,
1086+
"Expected second operand to be an integer constant of type i32 or "
1087+
"smaller",
1088+
&MD);
1089+
}
1090+
10791091
// Check these last, so we diagnose problems in operands first.
10801092
Check(!MD.isTemporary(), "Expected no forward declarations!", &MD);
10811093
Check(MD.isResolved(), "All nodes should be resolved!", &MD);
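The verifier rule added in this hunk can be restated as a standalone check. The sketch below is a model only: `MDStr`, `IntConst`, `OtherOperand`, `NullOperand`, and `checkEstimatedTripCount` are hypothetical stand-ins for LLVM's metadata classes, and it assumes that a failed check stops further checking of the node, so only the first error is reported.

```cpp
#include <cassert>
#include <optional>
#include <string>
#include <variant>
#include <vector>

// Hypothetical, simplified operand kinds for a metadata node.
struct NullOperand {};                 // a null operand
struct MDStr { std::string Text; };    // a metadata string
struct IntConst { unsigned BitWidth; };// an integer constant of given width
struct OtherOperand {};                // anything else (e.g. a float constant)
using Operand = std::variant<NullOperand, MDStr, IntConst, OtherOperand>;

// Returns an error message if the node is malformed
// llvm.loop.estimated_trip_count metadata, or nullopt if the node is
// well-formed or is some other kind of metadata.
std::optional<std::string>
checkEstimatedTripCount(const std::vector<Operand> &Ops) {
  // Only nodes whose first operand is the expected string are checked.
  if (Ops.empty())
    return std::nullopt;
  const auto *Name = std::get_if<MDStr>(&Ops[0]);
  if (!Name || Name->Text != "llvm.loop.estimated_trip_count")
    return std::nullopt;

  if (Ops.size() != 2)
    return std::string("Expected two operands");

  // The second operand must be an integer constant no wider than 32 bits.
  const auto *Count = std::get_if<IntConst>(&Ops[1]);
  if (!Count || Count->BitWidth > 32)
    return std::string("Expected second operand to be an integer constant "
                       "of type i32 or smaller");
  return std::nullopt;
}
```

Under this model, a node holding the name and an `i32` or `i16` constant passes, while a missing operand, a null or non-integer operand, or an `i64` constant is rejected.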

llvm/lib/Transforms/Utils/LoopUtils.cpp

Lines changed: 114 additions & 30 deletions
@@ -804,26 +804,51 @@ static BranchInst *getExpectedExitLoopLatchBranch(Loop *L) {
804804
return LatchBR;
805805
}
806806

807-
/// Return the estimated trip count for any exiting branch which dominates
808-
/// the loop latch.
809-
static std::optional<unsigned> getEstimatedTripCount(BranchInst *ExitingBranch,
810-
Loop *L,
811-
uint64_t &OrigExitWeight) {
807+
struct DbgLoop {
808+
const Loop *L;
809+
explicit DbgLoop(const Loop *L) : L(L) {}
810+
};
811+
812+
#ifndef NDEBUG
813+
static inline raw_ostream &operator<<(raw_ostream &OS, DbgLoop D) {
814+
OS << "function ";
815+
D.L->getHeader()->getParent()->printAsOperand(OS, /*PrintType=*/false);
816+
return OS << " " << *D.L;
817+
}
818+
#endif // NDEBUG
819+
820+
static std::optional<unsigned> estimateLoopTripCount(Loop *L) {
821+
// Currently we take the estimate exit count only from the loop latch,
822+
// ignoring other exiting blocks. This can overestimate the trip count
823+
// if we exit through another exit, but can never underestimate it.
824+
// TODO: incorporate information from other exits
825+
BranchInst *ExitingBranch = getExpectedExitLoopLatchBranch(L);
826+
if (!ExitingBranch) {
827+
LLVM_DEBUG(dbgs() << "estimateLoopTripCount: Failed to find exiting "
828+
<< "latch branch of required form in " << DbgLoop(L)
829+
<< "\n");
830+
return std::nullopt;
831+
}
832+
812833
// To estimate the number of times the loop body was executed, we want to
813834
// know the number of times the backedge was taken, vs. the number of times
814835
// we exited the loop.
815836
uint64_t LoopWeight, ExitWeight;
816-
if (!extractBranchWeights(*ExitingBranch, LoopWeight, ExitWeight))
837+
if (!extractBranchWeights(*ExitingBranch, LoopWeight, ExitWeight)) {
838+
LLVM_DEBUG(dbgs() << "estimateLoopTripCount: Failed to extract branch "
839+
<< "weights for " << DbgLoop(L) << "\n");
817840
return std::nullopt;
841+
}
818842

819843
if (L->contains(ExitingBranch->getSuccessor(1)))
820844
std::swap(LoopWeight, ExitWeight);
821845

822-
if (!ExitWeight)
846+
if (!ExitWeight) {
823847
// Don't have a way to return predicated infinite
848+
LLVM_DEBUG(dbgs() << "estimateLoopTripCount: Failed because of zero exit "
849+
<< "probability for " << DbgLoop(L) << "\n");
824850
return std::nullopt;
825-
826-
OrigExitWeight = ExitWeight;
851+
}
827852

828853
// Estimated exit count is a ratio of the loop weight by the weight of the
829854
// edge exiting the loop, rounded to nearest.
@@ -834,43 +859,102 @@ static std::optional<unsigned> getEstimatedTripCount(BranchInst *ExitingBranch,
834859
return std::numeric_limits<unsigned>::max();
835860

836861
// Estimated trip count is one plus estimated exit count.
837-
return ExitCount + 1;
862+
uint64_t TC = ExitCount + 1;
863+
LLVM_DEBUG(dbgs() << "estimateLoopTripCount: Estimated trip count of " << TC
864+
<< " for " << DbgLoop(L) << "\n");
865+
return TC;
838866
}
839867

840868
std::optional<unsigned>
841869
llvm::getLoopEstimatedTripCount(Loop *L,
842870
unsigned *EstimatedLoopInvocationWeight) {
843-
// Currently we take the estimate exit count only from the loop latch,
844-
// ignoring other exiting blocks. This can overestimate the trip count
845-
// if we exit through another exit, but can never underestimate it.
846-
// TODO: incorporate information from other exits
847-
if (BranchInst *LatchBranch = getExpectedExitLoopLatchBranch(L)) {
848-
uint64_t ExitWeight;
849-
if (std::optional<uint64_t> EstTripCount =
850-
getEstimatedTripCount(LatchBranch, L, ExitWeight)) {
851-
if (EstimatedLoopInvocationWeight)
852-
*EstimatedLoopInvocationWeight = ExitWeight;
853-
return *EstTripCount;
854-
}
871+
// If EstimatedLoopInvocationWeight, we do not support this loop if
872+
// getExpectedExitLoopLatchBranch returns nullptr.
873+
//
874+
// FIXME: Also, this is a stop-gap solution for nested loops. It avoids
875+
// mistaking LLVMLoopEstimatedTripCount metadata to be for an outer loop when
876+
// it was created for an inner loop. The problem is that loop metadata is
877+
// attached to the branch instruction in the loop latch block, but that can be
878+
// shared by the loops. A solution is to attach loop metadata to loop headers
879+
// instead, but that would be a large change to LLVM.
880+
//
881+
// Until that happens, we work around the problem as follows.
882+
// getExpectedExitLoopLatchBranch (which also guards
883+
// setLoopEstimatedTripCount) returns nullptr for a loop unless the loop has
884+
// one latch and that latch has exactly two successors one of which is an exit
885+
// from the loop. If the latch is shared by nested loops, then that condition
886+
// might hold for the inner loop but cannot hold for the outer loop:
887+
// - Because the latch is shared, it must have at least two successors: the
888+
// inner loop header and the outer loop header, which is also an exit for
889+
// the inner loop. That satisfies the condition for the inner loop.
890+
// - To satisfy the condition for the outer loop, the latch must have a third
891+
// successor that is an exit for the outer loop. But that violates the
892+
// condition for both loops.
893+
BranchInst *ExitingBranch = getExpectedExitLoopLatchBranch(L);
894+
if (!ExitingBranch)
895+
return std::nullopt;
896+
897+
// If requested, either compute *EstimatedLoopInvocationWeight or return
898+
// nullopt if it cannot be computed.
899+
//
900+
// TODO: Eventually, once all passes have migrated away from setting branch
901+
// weights to indicate estimated trip counts, this function will drop the
902+
// EstimatedLoopInvocationWeight parameter.
903+
if (EstimatedLoopInvocationWeight) {
904+
uint64_t LoopWeight = 0, ExitWeight = 0; // Inits expected to be unused.
905+
if (!extractBranchWeights(*ExitingBranch, LoopWeight, ExitWeight))
906+
return std::nullopt;
907+
if (L->contains(ExitingBranch->getSuccessor(1)))
908+
std::swap(LoopWeight, ExitWeight);
909+
if (!ExitWeight)
910+
return std::nullopt;
911+
*EstimatedLoopInvocationWeight = ExitWeight;
855912
}
856-
return std::nullopt;
913+
914+
// Return the estimated trip count from metadata unless the metadata is
915+
// missing or has no value.
916+
if (auto TC = getOptionalIntLoopAttribute(L, LLVMLoopEstimatedTripCount)) {
917+
LLVM_DEBUG(dbgs() << "getLoopEstimatedTripCount: "
918+
<< LLVMLoopEstimatedTripCount << " metadata has trip "
919+
<< "count of " << *TC << " for " << DbgLoop(L) << "\n");
920+
return TC;
921+
}
922+
923+
// Estimate the trip count from latch branch weights.
924+
return estimateLoopTripCount(L);
857925
}
858926

859-
bool llvm::setLoopEstimatedTripCount(Loop *L, unsigned EstimatedTripCount,
860-
unsigned EstimatedloopInvocationWeight) {
861-
// At the moment, we currently support changing the estimate trip count of
862-
// the latch branch only. We could extend this API to manipulate estimated
863-
// trip counts for any exit.
927+
bool llvm::setLoopEstimatedTripCount(
928+
Loop *L, unsigned EstimatedTripCount,
929+
std::optional<unsigned> EstimatedloopInvocationWeight) {
930+
// If EstimatedLoopInvocationWeight, we do not support this loop if
931+
// getExpectedExitLoopLatchBranch returns nullptr.
932+
//
933+
// FIXME: See comments in getLoopEstimatedTripCount for why this is required
934+
// here regardless of EstimatedLoopInvocationWeight.
864935
BranchInst *LatchBranch = getExpectedExitLoopLatchBranch(L);
865936
if (!LatchBranch)
866937
return false;
867938

939+
// Set the metadata.
940+
addStringMetadataToLoop(L, LLVMLoopEstimatedTripCount, EstimatedTripCount);
941+
942+
// At the moment, we currently support changing the estimated trip count in
943+
// the latch branch's branch weights only. We could extend this API to
944+
// manipulate estimated trip counts for any exit.
945+
//
946+
// TODO: Eventually, once all passes have migrated away from setting branch
947+
// weights to indicate estimated trip counts, we will not set branch weights
948+
// here at all.
949+
if (!EstimatedloopInvocationWeight)
950+
return true;
951+
868952
// Calculate taken and exit weights.
869953
unsigned LatchExitWeight = 0;
870954
unsigned BackedgeTakenWeight = 0;
871955

872-
if (EstimatedTripCount > 0) {
873-
LatchExitWeight = EstimatedloopInvocationWeight;
956+
if (EstimatedTripCount != 0) {
957+
LatchExitWeight = *EstimatedloopInvocationWeight;
874958
BackedgeTakenWeight = (EstimatedTripCount - 1) * LatchExitWeight;
875959
}
876960
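The arithmetic in `estimateLoopTripCount` and in the branch-weight half of `setLoopEstimatedTripCount` can be checked in isolation with a small sketch. `LatchWeights`, `divideRoundNearest`, `estimateTripCount`, and `weightsForTripCount` are hypothetical names, not LLVM functions. The rounding helper and the exact saturation test are assumptions, because the diff elides those lines and shows only the comments "rounded to nearest" and the `UINT_MAX` return. The sketch also widens the backedge weight to 64 bits, whereas the code in the diff uses `unsigned`.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>
#include <optional>

// Hypothetical stand-in for the two branch weights on a loop's latch branch.
struct LatchWeights {
  uint64_t Backedge; // weight of the edge that re-enters the loop
  uint64_t Exit;     // weight of the edge that leaves the loop
};

// Unsigned division rounded to nearest (ties round up), written so that the
// intermediate values cannot overflow. D must be nonzero.
static uint64_t divideRoundNearest(uint64_t N, uint64_t D) {
  return N / D + (N % D >= D - D / 2 ? 1 : 0);
}

// Estimate a trip count from latch weights: one plus the rounded ratio of
// backedge weight to exit weight. A zero exit weight yields no estimate.
std::optional<unsigned> estimateTripCount(LatchWeights W) {
  if (W.Exit == 0)
    return std::nullopt;
  uint64_t ExitCount = divideRoundNearest(W.Backedge, W.Exit);
  // Assumed saturation rule: clamp so that adding one cannot wrap.
  if (ExitCount >= std::numeric_limits<unsigned>::max())
    return std::numeric_limits<unsigned>::max();
  return static_cast<unsigned>(ExitCount + 1);
}

// Inverse direction, following the visible lines of the diff: the exit
// weight is the invocation weight, and the backedge weight is
// (TripCount - 1) times that. A trip count of zero leaves both at zero.
LatchWeights weightsForTripCount(unsigned TripCount,
                                 unsigned InvocationWeight) {
  LatchWeights W{0, 0};
  if (TripCount != 0) {
    W.Exit = InvocationWeight;
    W.Backedge = static_cast<uint64_t>(TripCount - 1) * InvocationWeight;
  }
  return W;
}
```

For example, weights of 7 and 1 give an estimate of 8, and writing a trip count of 8 with an invocation weight of 3 produces weights 21 and 3, which estimate back to 8.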
