diff --git a/llvm/include/llvm/IR/ProfDataUtils.h b/llvm/include/llvm/IR/ProfDataUtils.h index ca56e4aa81575..404875285beae 100644 --- a/llvm/include/llvm/IR/ProfDataUtils.h +++ b/llvm/include/llvm/IR/ProfDataUtils.h @@ -144,6 +144,33 @@ LLVM_ABI bool extractProfTotalWeight(const Instruction &I, LLVM_ABI void setBranchWeights(Instruction &I, ArrayRef Weights, bool IsExpected); +/// Downscale the given weights, preserving their ratio: every weight is +/// divided by one common scale derived from the maximum weight. The maximum is +/// \p KnownMaxCount when provided, otherwise the largest entry of \p Weights. +LLVM_ABI SmallVector +downscaleWeights(ArrayRef Weights, + std::optional KnownMaxCount = std::nullopt); + +/// Calculate what to divide by to scale counts. +/// +/// Given the maximum count, calculate a divisor that will scale all the +/// weights to strictly less than std::numeric_limits::max(). +inline uint64_t calculateCountScale(uint64_t MaxCount) { + return MaxCount < std::numeric_limits::max() + ? 1 + : MaxCount / std::numeric_limits::max() + 1; +} + +/// Scale an individual branch count. +/// +/// Scale a 64-bit weight down to 32-bits using \c Scale. +/// +inline uint32_t scaleBranchCount(uint64_t Count, uint64_t Scale) { + uint64_t Scaled = Count / Scale; + assert(Scaled <= std::numeric_limits::max() && "overflow 32-bits"); + return Scaled; +} + /// Specify that the branch weights for this terminator cannot be known at /// compile time. This should only be called by passes, and never as a default /// behavior in e.g. MDBuilder. The goal is to use this info to validate passes diff --git a/llvm/include/llvm/Transforms/Utils/Instrumentation.h b/llvm/include/llvm/Transforms/Utils/Instrumentation.h index 962d9e734a40a..93ab8c693607f 100644 --- a/llvm/include/llvm/Transforms/Utils/Instrumentation.h +++ b/llvm/include/llvm/Transforms/Utils/Instrumentation.h @@ -169,26 +169,6 @@ struct SanitizerCoverageOptions { SanitizerCoverageOptions() = default; }; -/// Calculate what to divide by to scale counts. 
-/// -/// Given the maximum count, calculate a divisor that will scale all the -/// weights to strictly less than std::numeric_limits::max(). -static inline uint64_t calculateCountScale(uint64_t MaxCount) { - return MaxCount < std::numeric_limits::max() - ? 1 - : MaxCount / std::numeric_limits::max() + 1; -} - -/// Scale an individual branch count. -/// -/// Scale a 64-bit weight down to 32-bits using \c Scale. -/// -static inline uint32_t scaleBranchCount(uint64_t Count, uint64_t Scale) { - uint64_t Scaled = Count / Scale; - assert(Scaled <= std::numeric_limits::max() && "overflow 32-bits"); - return Scaled; -} - // Use to ensure the inserted instrumentation has a DebugLocation; if none is // attached to the source instruction, try to use a DILocation with offset 0 // scoped to surrounding function (if it has a DebugLocation). diff --git a/llvm/lib/IR/ProfDataUtils.cpp b/llvm/lib/IR/ProfDataUtils.cpp index b1b5f67689e6d..489fbfef00e4d 100644 --- a/llvm/lib/IR/ProfDataUtils.cpp +++ b/llvm/lib/IR/ProfDataUtils.cpp @@ -270,6 +270,18 @@ void setBranchWeights(Instruction &I, ArrayRef Weights, I.setMetadata(LLVMContext::MD_prof, BranchWeights); } +SmallVector downscaleWeights(ArrayRef Weights, + std::optional KnownMaxCount) { + uint64_t MaxCount = KnownMaxCount.has_value() ? 
KnownMaxCount.value() + : *llvm::max_element(Weights); + assert(MaxCount > 0 && "Bad max count"); + uint64_t Scale = calculateCountScale(MaxCount); + SmallVector DownscaledWeights; + for (const auto &ECI : Weights) + DownscaledWeights.push_back(scaleBranchCount(ECI, Scale)); + return DownscaledWeights; +} + void scaleProfData(Instruction &I, uint64_t S, uint64_t T) { assert(T != 0 && "Caller should guarantee"); auto *ProfileData = I.getMetadata(LLVMContext::MD_prof); diff --git a/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp b/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp index e0b22ef94d064..d9e850e7a2bf3 100644 --- a/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp +++ b/llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp @@ -2409,11 +2409,7 @@ static std::string getSimpleNodeName(const BasicBlock *Node) { void llvm::setProfMetadata(Instruction *TI, ArrayRef EdgeCounts, uint64_t MaxCount) { - assert(MaxCount > 0 && "Bad max count"); - uint64_t Scale = calculateCountScale(MaxCount); - SmallVector Weights; - for (const auto &ECI : EdgeCounts) - Weights.push_back(scaleBranchCount(ECI, Scale)); + auto Weights = downscaleWeights(EdgeCounts, MaxCount); LLVM_DEBUG(dbgs() << "Weight is: "; for (const auto &W : Weights) { @@ -2434,7 +2430,7 @@ void llvm::setProfMetadata(Instruction *TI, ArrayRef EdgeCounts, uint64_t TotalCount = std::accumulate(EdgeCounts.begin(), EdgeCounts.end(), (uint64_t)0, [](uint64_t c1, uint64_t c2) { return c1 + c2; }); - Scale = calculateCountScale(WSum); + uint64_t Scale = calculateCountScale(WSum); BranchProbability BP(scaleBranchCount(Weights[0], Scale), scaleBranchCount(WSum, Scale)); std::string BranchProbStr;