Skip to content

Commit 4233ab6

Browse files
committed
[JTS] Propagate profile info
1 parent 6f7ecae commit 4233ab6

File tree

1 file changed

+43
-3
lines changed

1 file changed

+43
-3
lines changed

llvm/lib/Transforms/Scalar/JumpTableToSwitch.cpp

Lines changed: 43 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,23 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "llvm/Transforms/Scalar/JumpTableToSwitch.h"
10+
#include "llvm/ADT/DenseSet.h"
11+
#include "llvm/ADT/STLExtras.h"
12+
#include "llvm/ADT/SmallSet.h"
1013
#include "llvm/ADT/SmallVector.h"
1114
#include "llvm/Analysis/ConstantFolding.h"
1215
#include "llvm/Analysis/DomTreeUpdater.h"
1316
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
1417
#include "llvm/Analysis/PostDominators.h"
1518
#include "llvm/IR/IRBuilder.h"
19+
#include "llvm/IR/LLVMContext.h"
20+
#include "llvm/IR/ProfDataUtils.h"
21+
#include "llvm/ProfileData/InstrProf.h"
1622
#include "llvm/Support/CommandLine.h"
23+
#include "llvm/Support/Error.h"
24+
#include "llvm/Transforms/Instrumentation/PGOInstrumentation.h"
1725
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
26+
#include <limits>
1827

1928
using namespace llvm;
2029

@@ -92,7 +101,8 @@ static std::optional<JumpTableTy> parseJumpTable(GetElementPtrInst *GEP,
92101

93102
static BasicBlock *expandToSwitch(CallBase *CB, const JumpTableTy &JT,
94103
DomTreeUpdater &DTU,
95-
OptimizationRemarkEmitter &ORE) {
104+
OptimizationRemarkEmitter &ORE,
105+
const InstrProfSymtab &Symtab) {
96106
const bool IsVoid = CB->getType() == Type::getVoidTy(CB->getContext());
97107

98108
SmallVector<DominatorTree::UpdateType, 8> DTUpdates;
@@ -115,7 +125,32 @@ static BasicBlock *expandToSwitch(CallBase *CB, const JumpTableTy &JT,
115125
IRBuilder<> BuilderTail(CB);
116126
PHINode *PHI =
117127
IsVoid ? nullptr : BuilderTail.CreatePHI(CB->getType(), JT.Funcs.size());
118-
128+
const auto *ProfMD = CB->getMetadata(LLVMContext::MD_prof);
129+
130+
if (isValueProfileMD(ProfMD)) {
131+
SmallVector<uint64_t> BranchWeights;
132+
BranchWeights.reserve(JT.Funcs.size());
133+
uint64_t TotalCount = 0;
134+
auto Targets = getValueProfDataFromInst(
135+
*CB, InstrProfValueKind::IPVK_IndirectCallTarget,
136+
std::numeric_limits<uint32_t>::max(), TotalCount);
137+
DenseMap<const Function *, uint64_t> FToC;
138+
for (const auto &[G, C] : Targets)
139+
if (const auto *F = Symtab.getFunction(G)) {
140+
auto It = FToC.insert({F, C});
141+
if (!It.second)
142+
It.second += C; // Unexpected, but likely the right way forward
143+
}
144+
// In the same order as the entries in JT.Funcs, which is the order in which
145+
// we'll add cases to the switch statement, get the branch weight of the
146+
// corresponding indirect call target, or 0 if we didn't have a profile for
147+
// it.
148+
for (const auto *F : JT.Funcs)
149+
BranchWeights.push_back(FToC.lookup_or(F, 0U));
150+
if (!BranchWeights.empty())
151+
setProfMetadata(F.getParent(), Switch, BranchWeights,
152+
*llvm::max_element(BranchWeights));
153+
}
119154
for (auto [Index, Func] : llvm::enumerate(JT.Funcs)) {
120155
BasicBlock *B = BasicBlock::Create(Func->getContext(),
121156
"call." + Twine(Index), &F, Tail);
@@ -150,6 +185,11 @@ PreservedAnalyses JumpTableToSwitchPass::run(Function &F,
150185
PostDominatorTree *PDT = AM.getCachedResult<PostDominatorTreeAnalysis>(F);
151186
DomTreeUpdater DTU(DT, PDT, DomTreeUpdater::UpdateStrategy::Lazy);
152187
bool Changed = false;
188+
InstrProfSymtab Symtab;
189+
if (auto E = Symtab.create(*F.getParent()))
190+
F.getContext().emitError(
191+
"Could not create indirect call table, likely corrupted IR" +
192+
toString(std::move(E)));
153193
for (BasicBlock &BB : make_early_inc_range(F)) {
154194
BasicBlock *CurrentBB = &BB;
155195
while (CurrentBB) {
@@ -170,7 +210,7 @@ PreservedAnalyses JumpTableToSwitchPass::run(Function &F,
170210
std::optional<JumpTableTy> JumpTable = parseJumpTable(GEP, PtrTy);
171211
if (!JumpTable)
172212
continue;
173-
SplittedOutTail = expandToSwitch(Call, *JumpTable, DTU, ORE);
213+
SplittedOutTail = expandToSwitch(Call, *JumpTable, DTU, ORE, Symtab);
174214
Changed = true;
175215
break;
176216
}

0 commit comments

Comments
 (0)