Skip to content

Commit 3b4775d

Browse files
authored
[NFC][PGO] Factor downscaling of branch weights out of Instrumentation into ProfileData (#153735)
The logic isn’t instrumentation-specific, and the refactoring allows users avoid a dependency on `Instrumentation` and just take one on `ProfileData`​ (which a fairly low-level dependency)
1 parent 0923aaf commit 3b4775d

File tree

4 files changed

+41
-26
lines changed

4 files changed

+41
-26
lines changed

llvm/include/llvm/IR/ProfDataUtils.h

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,33 @@ LLVM_ABI bool extractProfTotalWeight(const Instruction &I,
144144
LLVM_ABI void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights,
145145
bool IsExpected);
146146

147+
/// downscale the given weights preserving the ratio. If the maximum value is
148+
/// not already known and not provided via \param KnownMaxCount , it will be
149+
/// obtained from \param Weights.
150+
LLVM_ABI SmallVector<uint32_t>
151+
downscaleWeights(ArrayRef<uint64_t> Weights,
152+
std::optional<uint64_t> KnownMaxCount = std::nullopt);
153+
154+
/// Calculate what to divide by to scale counts.
155+
///
156+
/// Given the maximum count, calculate a divisor that will scale all the
157+
/// weights to strictly less than std::numeric_limits<uint32_t>::max().
158+
inline uint64_t calculateCountScale(uint64_t MaxCount) {
159+
return MaxCount < std::numeric_limits<uint32_t>::max()
160+
? 1
161+
: MaxCount / std::numeric_limits<uint32_t>::max() + 1;
162+
}
163+
164+
/// Scale an individual branch count.
165+
///
166+
/// Scale a 64-bit weight down to 32-bits using \c Scale.
167+
///
168+
inline uint32_t scaleBranchCount(uint64_t Count, uint64_t Scale) {
169+
uint64_t Scaled = Count / Scale;
170+
assert(Scaled <= std::numeric_limits<uint32_t>::max() && "overflow 32-bits");
171+
return Scaled;
172+
}
173+
147174
/// Specify that the branch weights for this terminator cannot be known at
148175
/// compile time. This should only be called by passes, and never as a default
149176
/// behavior in e.g. MDBuilder. The goal is to use this info to validate passes

llvm/include/llvm/Transforms/Utils/Instrumentation.h

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -169,26 +169,6 @@ struct SanitizerCoverageOptions {
169169
SanitizerCoverageOptions() = default;
170170
};
171171

172-
/// Calculate what to divide by to scale counts.
173-
///
174-
/// Given the maximum count, calculate a divisor that will scale all the
175-
/// weights to strictly less than std::numeric_limits<uint32_t>::max().
176-
static inline uint64_t calculateCountScale(uint64_t MaxCount) {
177-
return MaxCount < std::numeric_limits<uint32_t>::max()
178-
? 1
179-
: MaxCount / std::numeric_limits<uint32_t>::max() + 1;
180-
}
181-
182-
/// Scale an individual branch count.
183-
///
184-
/// Scale a 64-bit weight down to 32-bits using \c Scale.
185-
///
186-
static inline uint32_t scaleBranchCount(uint64_t Count, uint64_t Scale) {
187-
uint64_t Scaled = Count / Scale;
188-
assert(Scaled <= std::numeric_limits<uint32_t>::max() && "overflow 32-bits");
189-
return Scaled;
190-
}
191-
192172
// Use to ensure the inserted instrumentation has a DebugLocation; if none is
193173
// attached to the source instruction, try to use a DILocation with offset 0
194174
// scoped to surrounding function (if it has a DebugLocation).

llvm/lib/IR/ProfDataUtils.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,18 @@ void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights,
270270
I.setMetadata(LLVMContext::MD_prof, BranchWeights);
271271
}
272272

273+
SmallVector<uint32_t> downscaleWeights(ArrayRef<uint64_t> Weights,
274+
std::optional<uint64_t> KnownMaxCount) {
275+
uint64_t MaxCount = KnownMaxCount.has_value() ? KnownMaxCount.value()
276+
: *llvm::max_element(Weights);
277+
assert(MaxCount > 0 && "Bad max count");
278+
uint64_t Scale = calculateCountScale(MaxCount);
279+
SmallVector<unsigned, 4> DownscaledWeights;
280+
for (const auto &ECI : Weights)
281+
DownscaledWeights.push_back(scaleBranchCount(ECI, Scale));
282+
return DownscaledWeights;
283+
}
284+
273285
void scaleProfData(Instruction &I, uint64_t S, uint64_t T) {
274286
assert(T != 0 && "Caller should guarantee");
275287
auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);

llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2409,11 +2409,7 @@ static std::string getSimpleNodeName(const BasicBlock *Node) {
24092409

24102410
void llvm::setProfMetadata(Instruction *TI, ArrayRef<uint64_t> EdgeCounts,
24112411
uint64_t MaxCount) {
2412-
assert(MaxCount > 0 && "Bad max count");
2413-
uint64_t Scale = calculateCountScale(MaxCount);
2414-
SmallVector<unsigned, 4> Weights;
2415-
for (const auto &ECI : EdgeCounts)
2416-
Weights.push_back(scaleBranchCount(ECI, Scale));
2412+
auto Weights = downscaleWeights(EdgeCounts, MaxCount);
24172413

24182414
LLVM_DEBUG(dbgs() << "Weight is: "; for (const auto &W
24192415
: Weights) {
@@ -2434,7 +2430,7 @@ void llvm::setProfMetadata(Instruction *TI, ArrayRef<uint64_t> EdgeCounts,
24342430
uint64_t TotalCount =
24352431
std::accumulate(EdgeCounts.begin(), EdgeCounts.end(), (uint64_t)0,
24362432
[](uint64_t c1, uint64_t c2) { return c1 + c2; });
2437-
Scale = calculateCountScale(WSum);
2433+
uint64_t Scale = calculateCountScale(WSum);
24382434
BranchProbability BP(scaleBranchCount(Weights[0], Scale),
24392435
scaleBranchCount(WSum, Scale));
24402436
std::string BranchProbStr;

0 commit comments

Comments
 (0)