Skip to content

Commit 240b73e

Browse files
authored
[SimplifyCFG][PGO] Reuse existing setBranchWeights (#160629)
The main difference between SimplifyCFG's `setBranchWeights` and the one in ProfDataUtils is that the former doesn't propagate all-zero weights. That seems like a sensible thing to do, so this change updates the latter accordingly and adds a flag to control the behavior. It also moves the logic that fits 64-bit weights into 32 bits to ProfDataUtils. As a side effect, this fixes some profcheck failures.
1 parent 8907b6d commit 240b73e

File tree

8 files changed

+100
-118
lines changed

8 files changed

+100
-118
lines changed

llvm/include/llvm/IR/Instructions.h

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
#include "llvm/IR/Instruction.h"
3333
#include "llvm/IR/Intrinsics.h"
3434
#include "llvm/IR/OperandTraits.h"
35+
#include "llvm/IR/ProfDataUtils.h"
3536
#include "llvm/IR/Use.h"
3637
#include "llvm/IR/User.h"
3738
#include "llvm/Support/AtomicOrdering.h"
@@ -3536,8 +3537,6 @@ class SwitchInstProfUpdateWrapper {
35363537
bool Changed = false;
35373538

35383539
protected:
3539-
LLVM_ABI MDNode *buildProfBranchWeightsMD();
3540-
35413540
LLVM_ABI void init();
35423541

35433542
public:
@@ -3549,8 +3548,8 @@ class SwitchInstProfUpdateWrapper {
35493548
SwitchInstProfUpdateWrapper(SwitchInst &SI) : SI(SI) { init(); }
35503549

35513550
~SwitchInstProfUpdateWrapper() {
3552-
if (Changed)
3553-
SI.setMetadata(LLVMContext::MD_prof, buildProfBranchWeightsMD());
3551+
if (Changed && Weights.has_value() && Weights->size() >= 2)
3552+
setBranchWeights(SI, Weights.value(), /*IsExpected=*/false);
35543553
}
35553554

35563555
/// Delegate the call to the underlying SwitchInst::removeCase() and remove

llvm/include/llvm/IR/ProfDataUtils.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,13 @@ LLVM_ABI bool extractProfTotalWeight(const Instruction &I,
145145
/// \param Weights an array of weights to set on instruction I.
146146
/// \param IsExpected were these weights added from an llvm.expect* intrinsic.
147147
LLVM_ABI void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights,
148-
bool IsExpected);
148+
bool IsExpected, bool ElideAllZero = false);
149+
150+
/// Variant of `setBranchWeights` where the `Weights` will be fit first to
151+
/// uint32_t by shifting right.
152+
LLVM_ABI void setFittedBranchWeights(Instruction &I, ArrayRef<uint64_t> Weights,
153+
bool IsExpected,
154+
bool ElideAllZero = false);
149155

150156
/// downscale the given weights preserving the ratio. If the maximum value is
151157
/// not already known and not provided via \param KnownMaxCount , it will be

llvm/lib/IR/Instructions.cpp

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4141,23 +4141,6 @@ void SwitchInst::growOperands() {
41414141
growHungoffUses(ReservedSpace);
41424142
}
41434143

4144-
MDNode *SwitchInstProfUpdateWrapper::buildProfBranchWeightsMD() {
4145-
assert(Changed && "called only if metadata has changed");
4146-
4147-
if (!Weights)
4148-
return nullptr;
4149-
4150-
assert(SI.getNumSuccessors() == Weights->size() &&
4151-
"num of prof branch_weights must accord with num of successors");
4152-
4153-
bool AllZeroes = all_of(*Weights, [](uint32_t W) { return W == 0; });
4154-
4155-
if (AllZeroes || Weights->size() < 2)
4156-
return nullptr;
4157-
4158-
return MDBuilder(SI.getParent()->getContext()).createBranchWeights(*Weights);
4159-
}
4160-
41614144
void SwitchInstProfUpdateWrapper::init() {
41624145
MDNode *ProfileData = getBranchWeightMDNode(SI);
41634146
if (!ProfileData)

llvm/lib/IR/ProfDataUtils.cpp

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,15 @@
1212

1313
#include "llvm/IR/ProfDataUtils.h"
1414

15+
#include "llvm/ADT/STLExtras.h"
1516
#include "llvm/ADT/SmallVector.h"
1617
#include "llvm/IR/Constants.h"
1718
#include "llvm/IR/Function.h"
1819
#include "llvm/IR/Instructions.h"
1920
#include "llvm/IR/LLVMContext.h"
2021
#include "llvm/IR/MDBuilder.h"
2122
#include "llvm/IR/Metadata.h"
23+
#include "llvm/Support/CommandLine.h"
2224

2325
using namespace llvm;
2426

@@ -84,10 +86,31 @@ static void extractFromBranchWeightMD(const MDNode *ProfileData,
8486
}
8587
}
8688

89+
/// Push the weights right to fit in uint32_t.
90+
static SmallVector<uint32_t> fitWeights(ArrayRef<uint64_t> Weights) {
91+
SmallVector<uint32_t> Ret;
92+
Ret.reserve(Weights.size());
93+
uint64_t Max = *llvm::max_element(Weights);
94+
if (Max > UINT_MAX) {
95+
unsigned Offset = 32 - llvm::countl_zero(Max);
96+
for (const uint64_t &Value : Weights)
97+
Ret.push_back(static_cast<uint32_t>(Value >> Offset));
98+
} else {
99+
append_range(Ret, Weights);
100+
}
101+
return Ret;
102+
}
103+
87104
} // namespace
88105

89106
namespace llvm {
90-
107+
cl::opt<bool> ElideAllZeroBranchWeights("elide-all-zero-branch-weights",
108+
#if defined(LLVM_ENABLE_PROFCHECK)
109+
cl::init(false)
110+
#else
111+
cl::init(true)
112+
#endif
113+
);
91114
const char *MDProfLabels::BranchWeights = "branch_weights";
92115
const char *MDProfLabels::ExpectedBranchWeights = "expected";
93116
const char *MDProfLabels::ValueProfile = "VP";
@@ -282,12 +305,23 @@ bool hasExplicitlyUnknownBranchWeights(const Instruction &I) {
282305
}
283306

284307
void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights,
285-
bool IsExpected) {
308+
bool IsExpected, bool ElideAllZero) {
309+
if ((ElideAllZeroBranchWeights && ElideAllZero) &&
310+
llvm::all_of(Weights, [](uint32_t V) { return V == 0; })) {
311+
I.setMetadata(LLVMContext::MD_prof, nullptr);
312+
return;
313+
}
314+
286315
MDBuilder MDB(I.getContext());
287316
MDNode *BranchWeights = MDB.createBranchWeights(Weights, IsExpected);
288317
I.setMetadata(LLVMContext::MD_prof, BranchWeights);
289318
}
290319

320+
void setFittedBranchWeights(Instruction &I, ArrayRef<uint64_t> Weights,
321+
bool IsExpected, bool ElideAllZero) {
322+
setBranchWeights(I, fitWeights(Weights), IsExpected, ElideAllZero);
323+
}
324+
291325
SmallVector<uint32_t> downscaleWeights(ArrayRef<uint64_t> Weights,
292326
std::optional<uint64_t> KnownMaxCount) {
293327
uint64_t MaxCount = KnownMaxCount.has_value() ? KnownMaxCount.value()

llvm/lib/Transforms/IPO/SampleProfile.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1664,8 +1664,9 @@ void SampleProfileLoader::generateMDProfMetadata(Function &F) {
16641664
else if (OverwriteExistingWeights)
16651665
I.setMetadata(LLVMContext::MD_prof, nullptr);
16661666
} else if (!isa<IntrinsicInst>(&I)) {
1667-
setBranchWeights(I, {static_cast<uint32_t>(BlockWeights[BB])},
1668-
/*IsExpected=*/false);
1667+
setBranchWeights(
1668+
I, ArrayRef<uint32_t>{static_cast<uint32_t>(BlockWeights[BB])},
1669+
/*IsExpected=*/false);
16691670
}
16701671
}
16711672
} else if (OverwriteExistingWeights || ProfileSampleBlockAccurate) {
@@ -1676,7 +1677,8 @@ void SampleProfileLoader::generateMDProfMetadata(Function &F) {
16761677
if (cast<CallBase>(I).isIndirectCall()) {
16771678
I.setMetadata(LLVMContext::MD_prof, nullptr);
16781679
} else {
1679-
setBranchWeights(I, {uint32_t(0)}, /*IsExpected=*/false);
1680+
setBranchWeights(I, ArrayRef<uint32_t>{uint32_t(0)},
1681+
/*IsExpected=*/false);
16801682
}
16811683
}
16821684
}

llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -672,8 +672,8 @@ CallBase &llvm::pgo::promoteIndirectCall(CallBase &CB, Function *DirectCallee,
672672
createBranchWeights(CB.getContext(), Count, TotalCount - Count));
673673

674674
if (AttachProfToDirectCall)
675-
setBranchWeights(NewInst, {static_cast<uint32_t>(Count)},
676-
/*IsExpected=*/false);
675+
setFittedBranchWeights(NewInst, {Count},
676+
/*IsExpected=*/false);
677677

678678
using namespace ore;
679679

0 commit comments

Comments
 (0)