|
12 | 12 |
|
13 | 13 | #include "llvm/IR/ProfDataUtils.h"
|
14 | 14 |
|
| 15 | +#include "llvm/ADT/STLExtras.h" |
15 | 16 | #include "llvm/ADT/SmallVector.h"
|
16 | 17 | #include "llvm/IR/Constants.h"
|
17 | 18 | #include "llvm/IR/Function.h"
|
18 | 19 | #include "llvm/IR/Instructions.h"
|
19 | 20 | #include "llvm/IR/LLVMContext.h"
|
20 | 21 | #include "llvm/IR/MDBuilder.h"
|
21 | 22 | #include "llvm/IR/Metadata.h"
|
| 23 | +#include "llvm/Support/CommandLine.h" |
22 | 24 |
|
23 | 25 | using namespace llvm;
|
24 | 26 |
|
@@ -84,10 +86,31 @@ static void extractFromBranchWeightMD(const MDNode *ProfileData,
|
84 | 86 | }
|
85 | 87 | }
|
86 | 88 |
|
| 89 | +/// Push the weights right to fit in uint32_t. |
| 90 | +static SmallVector<uint32_t> fitWeights(ArrayRef<uint64_t> Weights) { |
| 91 | + SmallVector<uint32_t> Ret; |
| 92 | + Ret.reserve(Weights.size()); |
| 93 | + uint64_t Max = *llvm::max_element(Weights); |
| 94 | + if (Max > UINT_MAX) { |
| 95 | + unsigned Offset = 32 - llvm::countl_zero(Max); |
| 96 | + for (const uint64_t &Value : Weights) |
| 97 | + Ret.push_back(static_cast<uint32_t>(Value >> Offset)); |
| 98 | + } else { |
| 99 | + append_range(Ret, Weights); |
| 100 | + } |
| 101 | + return Ret; |
| 102 | +} |
| 103 | + |
87 | 104 | } // namespace
|
88 | 105 |
|
89 | 106 | namespace llvm {
|
90 |
| - |
| 107 | +cl::opt<bool> ElideAllZeroBranchWeights("elide-all-zero-branch-weights", |
| 108 | +#if defined(LLVM_ENABLE_PROFCHECK) |
| 109 | + cl::init(false) |
| 110 | +#else |
| 111 | + cl::init(true) |
| 112 | +#endif |
| 113 | +); |
91 | 114 | const char *MDProfLabels::BranchWeights = "branch_weights";
|
92 | 115 | const char *MDProfLabels::ExpectedBranchWeights = "expected";
|
93 | 116 | const char *MDProfLabels::ValueProfile = "VP";
|
@@ -282,12 +305,23 @@ bool hasExplicitlyUnknownBranchWeights(const Instruction &I) {
|
282 | 305 | }
|
283 | 306 |
|
284 | 307 | void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights,
|
285 |
| - bool IsExpected) { |
| 308 | + bool IsExpected, bool ElideAllZero) { |
| 309 | + if ((ElideAllZeroBranchWeights && ElideAllZero) && |
| 310 | + llvm::all_of(Weights, [](uint32_t V) { return V == 0; })) { |
| 311 | + I.setMetadata(LLVMContext::MD_prof, nullptr); |
| 312 | + return; |
| 313 | + } |
| 314 | + |
286 | 315 | MDBuilder MDB(I.getContext());
|
287 | 316 | MDNode *BranchWeights = MDB.createBranchWeights(Weights, IsExpected);
|
288 | 317 | I.setMetadata(LLVMContext::MD_prof, BranchWeights);
|
289 | 318 | }
|
290 | 319 |
|
| 320 | +void setFittedBranchWeights(Instruction &I, ArrayRef<uint64_t> Weights, |
| 321 | + bool IsExpected, bool ElideAllZero) { |
| 322 | + setBranchWeights(I, fitWeights(Weights), IsExpected, ElideAllZero); |
| 323 | +} |
| 324 | + |
291 | 325 | SmallVector<uint32_t> downscaleWeights(ArrayRef<uint64_t> Weights,
|
292 | 326 | std::optional<uint64_t> KnownMaxCount) {
|
293 | 327 | uint64_t MaxCount = KnownMaxCount.has_value() ? KnownMaxCount.value()
|
|
0 commit comments