Skip to content

[NFC][PGO] Factor downscaling of branch weights out of Instrumentation into ProfileData #153735

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions llvm/include/llvm/IR/ProfDataUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,33 @@ LLVM_ABI bool extractProfTotalWeight(const Instruction &I,
LLVM_ABI void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights,
bool IsExpected);

/// downscale the given weights preserving the ratio. If the maximum value is
/// not already known and not provided via \param KnownMaxCount , it will be
/// obtained from \param Weights.
LLVM_ABI SmallVector<uint32_t>
downscaleWeights(ArrayRef<uint64_t> Weights,
std::optional<uint64_t> KnownMaxCount = std::nullopt);

/// Calculate what to divide by to scale counts.
///
/// Given the maximum count, calculate a divisor that will scale all the
/// weights to strictly less than std::numeric_limits<uint32_t>::max().
inline uint64_t calculateCountScale(uint64_t MaxCount) {
return MaxCount < std::numeric_limits<uint32_t>::max()
? 1
: MaxCount / std::numeric_limits<uint32_t>::max() + 1;
}

/// Scale an individual branch count.
///
/// Scale a 64-bit weight down to 32-bits using \c Scale.
///
inline uint32_t scaleBranchCount(uint64_t Count, uint64_t Scale) {
uint64_t Scaled = Count / Scale;
assert(Scaled <= std::numeric_limits<uint32_t>::max() && "overflow 32-bits");
return Scaled;
}

/// Specify that the branch weights for this terminator cannot be known at
/// compile time. This should only be called by passes, and never as a default
/// behavior in e.g. MDBuilder. The goal is to use this info to validate passes
Expand Down
20 changes: 0 additions & 20 deletions llvm/include/llvm/Transforms/Utils/Instrumentation.h
Original file line number Diff line number Diff line change
Expand Up @@ -169,26 +169,6 @@ struct SanitizerCoverageOptions {
SanitizerCoverageOptions() = default;
};

/// Calculate what to divide by to scale counts.
///
/// Given the maximum count, calculate a divisor that will scale all the
/// weights to strictly less than std::numeric_limits<uint32_t>::max().
static inline uint64_t calculateCountScale(uint64_t MaxCount) {
return MaxCount < std::numeric_limits<uint32_t>::max()
? 1
: MaxCount / std::numeric_limits<uint32_t>::max() + 1;
}

/// Scale an individual branch count.
///
/// Scale a 64-bit weight down to 32-bits using \c Scale.
///
static inline uint32_t scaleBranchCount(uint64_t Count, uint64_t Scale) {
uint64_t Scaled = Count / Scale;
assert(Scaled <= std::numeric_limits<uint32_t>::max() && "overflow 32-bits");
return Scaled;
}

// Use to ensure the inserted instrumentation has a DebugLocation; if none is
// attached to the source instruction, try to use a DILocation with offset 0
// scoped to surrounding function (if it has a DebugLocation).
Expand Down
12 changes: 12 additions & 0 deletions llvm/lib/IR/ProfDataUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,18 @@ void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights,
I.setMetadata(LLVMContext::MD_prof, BranchWeights);
}

SmallVector<uint32_t> downscaleWeights(ArrayRef<uint64_t> Weights,
std::optional<uint64_t> KnownMaxCount) {
uint64_t MaxCount = KnownMaxCount.has_value() ? KnownMaxCount.value()
: *llvm::max_element(Weights);
assert(MaxCount > 0 && "Bad max count");
uint64_t Scale = calculateCountScale(MaxCount);
SmallVector<unsigned, 4> DownscaledWeights;
for (const auto &ECI : Weights)
DownscaledWeights.push_back(scaleBranchCount(ECI, Scale));
return DownscaledWeights;
}

void scaleProfData(Instruction &I, uint64_t S, uint64_t T) {
assert(T != 0 && "Caller should guarantee");
auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
Expand Down
8 changes: 2 additions & 6 deletions llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2409,11 +2409,7 @@ static std::string getSimpleNodeName(const BasicBlock *Node) {

void llvm::setProfMetadata(Instruction *TI, ArrayRef<uint64_t> EdgeCounts,
uint64_t MaxCount) {
assert(MaxCount > 0 && "Bad max count");
uint64_t Scale = calculateCountScale(MaxCount);
SmallVector<unsigned, 4> Weights;
for (const auto &ECI : EdgeCounts)
Weights.push_back(scaleBranchCount(ECI, Scale));
auto Weights = downscaleWeights(EdgeCounts, MaxCount);

LLVM_DEBUG(dbgs() << "Weight is: "; for (const auto &W
: Weights) {
Expand All @@ -2434,7 +2430,7 @@ void llvm::setProfMetadata(Instruction *TI, ArrayRef<uint64_t> EdgeCounts,
uint64_t TotalCount =
std::accumulate(EdgeCounts.begin(), EdgeCounts.end(), (uint64_t)0,
[](uint64_t c1, uint64_t c2) { return c1 + c2; });
Scale = calculateCountScale(WSum);
uint64_t Scale = calculateCountScale(WSum);
BranchProbability BP(scaleBranchCount(Weights[0], Scale),
scaleBranchCount(WSum, Scale));
std::string BranchProbStr;
Expand Down
Loading