7
7
// ===----------------------------------------------------------------------===//
8
8
9
9
#include " llvm/Transforms/Scalar/JumpTableToSwitch.h"
10
+ #include " llvm/ADT/DenseSet.h"
11
+ #include " llvm/ADT/STLExtras.h"
12
+ #include " llvm/ADT/SmallSet.h"
10
13
#include " llvm/ADT/SmallVector.h"
11
14
#include " llvm/Analysis/ConstantFolding.h"
15
+ #include " llvm/Analysis/CtxProfAnalysis.h"
12
16
#include " llvm/Analysis/DomTreeUpdater.h"
13
17
#include " llvm/Analysis/OptimizationRemarkEmitter.h"
14
18
#include " llvm/Analysis/PostDominators.h"
15
19
#include " llvm/IR/IRBuilder.h"
20
+ #include " llvm/IR/LLVMContext.h"
21
+ #include " llvm/IR/ProfDataUtils.h"
22
+ #include " llvm/ProfileData/InstrProf.h"
16
23
#include " llvm/Support/CommandLine.h"
24
+ #include " llvm/Support/Error.h"
25
+ #include " llvm/Transforms/Instrumentation/PGOInstrumentation.h"
17
26
#include " llvm/Transforms/Utils/BasicBlockUtils.h"
27
+ #include < limits>
18
28
19
29
using namespace llvm ;
20
30
@@ -33,6 +43,8 @@ static cl::opt<unsigned> FunctionSizeThreshold(
33
43
" or equal than this threshold." ),
34
44
cl::init(50 ));
35
45
46
+ extern cl::opt<bool > ProfcheckDisableMetadataFixes;
47
+
36
48
#define DEBUG_TYPE " jump-table-to-switch"
37
49
38
50
namespace {
@@ -90,9 +102,11 @@ static std::optional<JumpTableTy> parseJumpTable(GetElementPtrInst *GEP,
90
102
return JumpTable;
91
103
}
92
104
93
- static BasicBlock *expandToSwitch (CallBase *CB, const JumpTableTy &JT,
94
- DomTreeUpdater &DTU,
95
- OptimizationRemarkEmitter &ORE) {
105
+ static BasicBlock *
106
+ expandToSwitch (CallBase *CB, const JumpTableTy &JT, DomTreeUpdater &DTU,
107
+ OptimizationRemarkEmitter &ORE,
108
+ llvm::function_ref<GlobalValue::GUID(const Function &)>
109
+ GetGuidForFunction) {
96
110
const bool IsVoid = CB->getType () == Type::getVoidTy (CB->getContext ());
97
111
98
112
SmallVector<DominatorTree::UpdateType, 8 > DTUpdates;
@@ -115,7 +129,31 @@ static BasicBlock *expandToSwitch(CallBase *CB, const JumpTableTy &JT,
115
129
IRBuilder<> BuilderTail (CB);
116
130
PHINode *PHI =
117
131
IsVoid ? nullptr : BuilderTail.CreatePHI (CB->getType (), JT.Funcs .size ());
118
-
132
+ const auto *ProfMD = CB->getMetadata (LLVMContext::MD_prof);
133
+
134
+ SmallVector<uint64_t > BranchWeights;
135
+ DenseMap<GlobalValue::GUID, uint64_t > GuidToCounter;
136
+ const bool HadProfile = isValueProfileMD (ProfMD);
137
+ if (HadProfile) {
138
+ // The assumptions, coming in, are that the functions in JT.Funcs are
139
+ // defined in this module (from parseJumpTable).
140
+ assert (llvm::all_of (
141
+ JT.Funcs , [](const Function *F) { return F && !F->isDeclaration (); }));
142
+ BranchWeights.reserve (JT.Funcs .size () + 1 );
143
+ // The first is the default target, which is the unreachable block created
144
+ // above.
145
+ BranchWeights.push_back (0U );
146
+ uint64_t TotalCount = 0 ;
147
+ auto Targets = getValueProfDataFromInst (
148
+ *CB, InstrProfValueKind::IPVK_IndirectCallTarget,
149
+ std::numeric_limits<uint32_t >::max (), TotalCount);
150
+
151
+ for (const auto &[G, C] : Targets) {
152
+ auto It = GuidToCounter.insert ({G, C});
153
+ assert (It.second );
154
+ (void )It;
155
+ }
156
+ }
119
157
for (auto [Index, Func] : llvm::enumerate (JT.Funcs )) {
120
158
BasicBlock *B = BasicBlock::Create (Func->getContext (),
121
159
" call." + Twine (Index), &F, Tail);
@@ -127,6 +165,11 @@ static BasicBlock *expandToSwitch(CallBase *CB, const JumpTableTy &JT,
127
165
Call->insertInto (B, B->end ());
128
166
Switch->addCase (
129
167
cast<ConstantInt>(ConstantInt::get (JT.Index ->getType (), Index)), B);
168
+ GlobalValue::GUID FctID = GetGuidForFunction (*Func);
169
+ // It'd be OK to _not_ find target functions in GuidToCounter, e.g. suppose
170
+ // just some of the jump targets are taken (for the given profile).
171
+ BranchWeights.push_back (FctID == 0U ? 0U
172
+ : GuidToCounter.lookup_or (FctID, 0U ));
130
173
BranchInst::Create (Tail, B);
131
174
if (PHI)
132
175
PHI->addIncoming (Call, B);
@@ -136,6 +179,13 @@ static BasicBlock *expandToSwitch(CallBase *CB, const JumpTableTy &JT,
136
179
return OptimizationRemark (DEBUG_TYPE, " ReplacedJumpTableWithSwitch" , CB)
137
180
<< " expanded indirect call into switch" ;
138
181
});
182
+ if (HadProfile && !ProfcheckDisableMetadataFixes) {
183
+ // At least one of the targets must've been taken.
184
+ assert (llvm::any_of (BranchWeights, [](uint64_t V) { return V != 0 ; }));
185
+ setProfMetadata (F.getParent (), Switch, BranchWeights,
186
+ *llvm::max_element (BranchWeights));
187
+ } else
188
+ setExplicitlyUnknownBranchWeights (*Switch);
139
189
if (PHI)
140
190
CB->replaceAllUsesWith (PHI);
141
191
CB->eraseFromParent ();
@@ -150,6 +200,15 @@ PreservedAnalyses JumpTableToSwitchPass::run(Function &F,
150
200
PostDominatorTree *PDT = AM.getCachedResult <PostDominatorTreeAnalysis>(F);
151
201
DomTreeUpdater DTU (DT, PDT, DomTreeUpdater::UpdateStrategy::Lazy);
152
202
bool Changed = false ;
203
+ InstrProfSymtab Symtab;
204
+ if (auto E = Symtab.create (*F.getParent ()))
205
+ F.getContext ().emitError (
206
+ " Could not create indirect call table, likely corrupted IR" +
207
+ toString (std::move (E)));
208
+ DenseMap<const Function *, GlobalValue::GUID> FToGuid;
209
+ for (const auto &[G, FPtr] : Symtab.getIDToNameMap ())
210
+ FToGuid.insert ({FPtr, G});
211
+
153
212
for (BasicBlock &BB : make_early_inc_range (F)) {
154
213
BasicBlock *CurrentBB = &BB;
155
214
while (CurrentBB) {
@@ -170,7 +229,12 @@ PreservedAnalyses JumpTableToSwitchPass::run(Function &F,
170
229
std::optional<JumpTableTy> JumpTable = parseJumpTable (GEP, PtrTy);
171
230
if (!JumpTable)
172
231
continue ;
173
- SplittedOutTail = expandToSwitch (Call, *JumpTable, DTU, ORE);
232
+ SplittedOutTail = expandToSwitch (
233
+ Call, *JumpTable, DTU, ORE, [&](const Function &Fct) {
234
+ if (Fct.getMetadata (AssignGUIDPass::GUIDMetadataName))
235
+ return AssignGUIDPass::getGUID (Fct);
236
+ return FToGuid.lookup_or (&Fct, 0U );
237
+ });
174
238
Changed = true ;
175
239
break ;
176
240
}
0 commit comments