7
7
// ===----------------------------------------------------------------------===//
8
8
9
9
#include " llvm/Transforms/Scalar/JumpTableToSwitch.h"
10
+ #include " llvm/ADT/DenseSet.h"
11
+ #include " llvm/ADT/STLExtras.h"
12
+ #include " llvm/ADT/SmallSet.h"
10
13
#include " llvm/ADT/SmallVector.h"
11
14
#include " llvm/Analysis/ConstantFolding.h"
15
+ #include " llvm/Analysis/CtxProfAnalysis.h"
12
16
#include " llvm/Analysis/DomTreeUpdater.h"
13
17
#include " llvm/Analysis/OptimizationRemarkEmitter.h"
14
18
#include " llvm/Analysis/PostDominators.h"
15
19
#include " llvm/IR/IRBuilder.h"
20
+ #include " llvm/IR/LLVMContext.h"
21
+ #include " llvm/IR/ProfDataUtils.h"
22
+ #include " llvm/ProfileData/InstrProf.h"
16
23
#include " llvm/Support/CommandLine.h"
24
+ #include " llvm/Support/Error.h"
25
+ #include " llvm/Transforms/Instrumentation/PGOInstrumentation.h"
17
26
#include " llvm/Transforms/Utils/BasicBlockUtils.h"
27
+ #include < limits>
18
28
19
29
using namespace llvm ;
20
30
@@ -90,9 +100,11 @@ static std::optional<JumpTableTy> parseJumpTable(GetElementPtrInst *GEP,
90
100
return JumpTable;
91
101
}
92
102
93
- static BasicBlock *expandToSwitch (CallBase *CB, const JumpTableTy &JT,
94
- DomTreeUpdater &DTU,
95
- OptimizationRemarkEmitter &ORE) {
103
+ static BasicBlock *
104
+ expandToSwitch (CallBase *CB, const JumpTableTy &JT, DomTreeUpdater &DTU,
105
+ OptimizationRemarkEmitter &ORE,
106
+ llvm::function_ref<GlobalValue::GUID(const Function &)>
107
+ GetGuidForFunction) {
96
108
const bool IsVoid = CB->getType () == Type::getVoidTy (CB->getContext ());
97
109
98
110
SmallVector<DominatorTree::UpdateType, 8 > DTUpdates;
@@ -115,7 +127,26 @@ static BasicBlock *expandToSwitch(CallBase *CB, const JumpTableTy &JT,
115
127
IRBuilder<> BuilderTail (CB);
116
128
PHINode *PHI =
117
129
IsVoid ? nullptr : BuilderTail.CreatePHI (CB->getType (), JT.Funcs .size ());
118
-
130
+ const auto *ProfMD = CB->getMetadata (LLVMContext::MD_prof);
131
+
132
+ SmallVector<uint64_t > BranchWeights;
133
+ DenseMap<GlobalValue::GUID, uint64_t > GuidToCounter;
134
+ const bool HadProfile = isValueProfileMD (ProfMD);
135
+ if (HadProfile) {
136
+ BranchWeights.reserve (JT.Funcs .size () + 1 );
137
+ // The first is the default target, which is the unreachable block above.
138
+ BranchWeights.push_back (0U );
139
+ uint64_t TotalCount = 0 ;
140
+ auto Targets = getValueProfDataFromInst (
141
+ *CB, InstrProfValueKind::IPVK_IndirectCallTarget,
142
+ std::numeric_limits<uint32_t >::max (), TotalCount);
143
+
144
+ for (const auto &[G, C] : Targets) {
145
+ auto It = GuidToCounter.insert ({G, C});
146
+ if (!It.second )
147
+ It.first->second += C; // Unexpected duplicate GUID: accumulate into the existing entry (It.second is only the "inserted" flag, It.first points at the stored counter).
148
+ }
149
+ }
119
150
for (auto [Index, Func] : llvm::enumerate (JT.Funcs )) {
120
151
BasicBlock *B = BasicBlock::Create (Func->getContext (),
121
152
" call." + Twine (Index), &F, Tail);
@@ -127,6 +158,9 @@ static BasicBlock *expandToSwitch(CallBase *CB, const JumpTableTy &JT,
127
158
Call->insertInto (B, B->end ());
128
159
Switch->addCase (
129
160
cast<ConstantInt>(ConstantInt::get (JT.Index ->getType (), Index)), B);
161
+ GlobalValue::GUID FctID = GetGuidForFunction (*Func);
162
+ BranchWeights.push_back (FctID == 0U ? 0U
163
+ : GuidToCounter.lookup_or (FctID, 0U ));
130
164
BranchInst::Create (Tail, B);
131
165
if (PHI)
132
166
PHI->addIncoming (Call, B);
@@ -136,6 +170,11 @@ static BasicBlock *expandToSwitch(CallBase *CB, const JumpTableTy &JT,
136
170
return OptimizationRemark (DEBUG_TYPE, " ReplacedJumpTableWithSwitch" , CB)
137
171
<< " expanded indirect call into switch" ;
138
172
});
173
+ if (HadProfile)
174
+ setProfMetadata (F.getParent (), Switch, BranchWeights,
175
+ *llvm::max_element (BranchWeights));
176
+ else
177
+ setExplicitlyUnknownBranchWeights (*Switch);
139
178
if (PHI)
140
179
CB->replaceAllUsesWith (PHI);
141
180
CB->eraseFromParent ();
@@ -150,6 +189,15 @@ PreservedAnalyses JumpTableToSwitchPass::run(Function &F,
150
189
PostDominatorTree *PDT = AM.getCachedResult <PostDominatorTreeAnalysis>(F);
151
190
DomTreeUpdater DTU (DT, PDT, DomTreeUpdater::UpdateStrategy::Lazy);
152
191
bool Changed = false ;
192
+ InstrProfSymtab Symtab;
193
+ if (auto E = Symtab.create (*F.getParent ()))
194
+ F.getContext ().emitError (
195
+ " Could not create indirect call table, likely corrupted IR" +
196
+ toString (std::move (E)));
197
+ DenseMap<const Function *, GlobalValue::GUID> FToGuid;
198
+ for (const auto &[G, FPtr] : Symtab.getIDToNameMap ())
199
+ FToGuid.insert ({FPtr, G});
200
+
153
201
for (BasicBlock &BB : make_early_inc_range (F)) {
154
202
BasicBlock *CurrentBB = &BB;
155
203
while (CurrentBB) {
@@ -170,7 +218,12 @@ PreservedAnalyses JumpTableToSwitchPass::run(Function &F,
170
218
std::optional<JumpTableTy> JumpTable = parseJumpTable (GEP, PtrTy);
171
219
if (!JumpTable)
172
220
continue ;
173
- SplittedOutTail = expandToSwitch (Call, *JumpTable, DTU, ORE);
221
+ SplittedOutTail = expandToSwitch (
222
+ Call, *JumpTable, DTU, ORE, [&](const Function &Fct) {
223
+ if (Fct.getMetadata (AssignGUIDPass::GUIDMetadataName))
224
+ return AssignGUIDPass::getGUID (Fct);
225
+ return FToGuid.lookup_or (&Fct, 0U );
226
+ });
174
227
Changed = true ;
175
228
break ;
176
229
}
0 commit comments