7
7
// ===----------------------------------------------------------------------===//
8
8
9
9
#include " llvm/Transforms/Scalar/JumpTableToSwitch.h"
10
+ #include " llvm/ADT/DenseSet.h"
11
+ #include " llvm/ADT/STLExtras.h"
12
+ #include " llvm/ADT/SmallSet.h"
10
13
#include " llvm/ADT/SmallVector.h"
11
14
#include " llvm/Analysis/ConstantFolding.h"
12
15
#include " llvm/Analysis/DomTreeUpdater.h"
13
16
#include " llvm/Analysis/OptimizationRemarkEmitter.h"
14
17
#include " llvm/Analysis/PostDominators.h"
15
18
#include " llvm/IR/IRBuilder.h"
19
+ #include " llvm/IR/LLVMContext.h"
20
+ #include " llvm/IR/ProfDataUtils.h"
21
+ #include " llvm/ProfileData/InstrProf.h"
16
22
#include " llvm/Support/CommandLine.h"
23
+ #include " llvm/Support/Error.h"
24
+ #include " llvm/Transforms/Instrumentation/PGOInstrumentation.h"
17
25
#include " llvm/Transforms/Utils/BasicBlockUtils.h"
26
+ #include < limits>
18
27
19
28
using namespace llvm ;
20
29
@@ -92,7 +101,8 @@ static std::optional<JumpTableTy> parseJumpTable(GetElementPtrInst *GEP,
92
101
93
102
static BasicBlock *expandToSwitch (CallBase *CB, const JumpTableTy &JT,
94
103
DomTreeUpdater &DTU,
95
- OptimizationRemarkEmitter &ORE) {
104
+ OptimizationRemarkEmitter &ORE,
105
+ const InstrProfSymtab &Symtab) {
96
106
const bool IsVoid = CB->getType () == Type::getVoidTy (CB->getContext ());
97
107
98
108
SmallVector<DominatorTree::UpdateType, 8 > DTUpdates;
@@ -115,7 +125,32 @@ static BasicBlock *expandToSwitch(CallBase *CB, const JumpTableTy &JT,
115
125
IRBuilder<> BuilderTail (CB);
116
126
PHINode *PHI =
117
127
IsVoid ? nullptr : BuilderTail.CreatePHI (CB->getType (), JT.Funcs .size ());
118
-
128
+ const auto *ProfMD = CB->getMetadata (LLVMContext::MD_prof);
129
+
130
+ if (isValueProfileMD (ProfMD)) {
131
+ SmallVector<uint64_t > BranchWeights;
132
+ BranchWeights.reserve (JT.Funcs .size ());
133
+ uint64_t TotalCount = 0 ;
134
+ auto Targets = getValueProfDataFromInst (
135
+ *CB, InstrProfValueKind::IPVK_IndirectCallTarget,
136
+ std::numeric_limits<uint32_t >::max (), TotalCount);
137
+ DenseMap<const Function *, uint64_t > FToC;
138
+ for (const auto &[G, C] : Targets)
139
+ if (const auto *F = Symtab.getFunction (G)) {
140
+ auto It = FToC.insert ({F, C});
141
+ if (!It.second )
142
+ It.second += C; // Unexpected, but likely the right way forward
143
+ }
144
+ // In the same order as the entries in JT.Funcs, which is the order in which
145
+ // we'll add cases to the switch statement, get the branch weight of the
146
+ // corresponding indirect call target, or 0 if we didn't have a profile for
147
+ // it.
148
+ for (const auto *F : JT.Funcs )
149
+ BranchWeights.push_back (FToC.lookup_or (F, 0U ));
150
+ if (!BranchWeights.empty ())
151
+ setProfMetadata (F.getParent (), Switch, BranchWeights,
152
+ *llvm::max_element (BranchWeights));
153
+ }
119
154
for (auto [Index, Func] : llvm::enumerate (JT.Funcs )) {
120
155
BasicBlock *B = BasicBlock::Create (Func->getContext (),
121
156
" call." + Twine (Index), &F, Tail);
@@ -150,6 +185,11 @@ PreservedAnalyses JumpTableToSwitchPass::run(Function &F,
150
185
PostDominatorTree *PDT = AM.getCachedResult <PostDominatorTreeAnalysis>(F);
151
186
DomTreeUpdater DTU (DT, PDT, DomTreeUpdater::UpdateStrategy::Lazy);
152
187
bool Changed = false ;
188
+ InstrProfSymtab Symtab;
189
+ if (auto E = Symtab.create (*F.getParent ()))
190
+ F.getContext ().emitError (
191
+ " Could not create indirect call table, likely corrupted IR" +
192
+ toString (std::move (E)));
153
193
for (BasicBlock &BB : make_early_inc_range (F)) {
154
194
BasicBlock *CurrentBB = &BB;
155
195
while (CurrentBB) {
@@ -170,7 +210,7 @@ PreservedAnalyses JumpTableToSwitchPass::run(Function &F,
170
210
std::optional<JumpTableTy> JumpTable = parseJumpTable (GEP, PtrTy);
171
211
if (!JumpTable)
172
212
continue ;
173
- SplittedOutTail = expandToSwitch (Call, *JumpTable, DTU, ORE);
213
+ SplittedOutTail = expandToSwitch (Call, *JumpTable, DTU, ORE, Symtab );
174
214
Changed = true ;
175
215
break ;
176
216
}
0 commit comments