77// ===----------------------------------------------------------------------===//
88
99#include " llvm/Transforms/Scalar/JumpTableToSwitch.h"
10+ #include " llvm/ADT/DenseSet.h"
11+ #include " llvm/ADT/STLExtras.h"
12+ #include " llvm/ADT/SmallSet.h"
1013#include " llvm/ADT/SmallVector.h"
1114#include " llvm/Analysis/ConstantFolding.h"
15+ #include " llvm/Analysis/CtxProfAnalysis.h"
1216#include " llvm/Analysis/DomTreeUpdater.h"
1317#include " llvm/Analysis/OptimizationRemarkEmitter.h"
1418#include " llvm/Analysis/PostDominators.h"
1519#include " llvm/IR/IRBuilder.h"
20+ #include " llvm/IR/LLVMContext.h"
21+ #include " llvm/IR/ProfDataUtils.h"
22+ #include " llvm/ProfileData/InstrProf.h"
1623#include " llvm/Support/CommandLine.h"
24+ #include " llvm/Support/Error.h"
25+ #include " llvm/Transforms/Instrumentation/PGOInstrumentation.h"
1726#include " llvm/Transforms/Utils/BasicBlockUtils.h"
27+ #include < limits>
1828
1929using namespace llvm ;
2030
@@ -90,9 +100,11 @@ static std::optional<JumpTableTy> parseJumpTable(GetElementPtrInst *GEP,
90100 return JumpTable;
91101}
92102
93- static BasicBlock *expandToSwitch (CallBase *CB, const JumpTableTy &JT,
94- DomTreeUpdater &DTU,
95- OptimizationRemarkEmitter &ORE) {
103+ static BasicBlock *
104+ expandToSwitch (CallBase *CB, const JumpTableTy &JT, DomTreeUpdater &DTU,
105+ OptimizationRemarkEmitter &ORE,
106+ llvm::function_ref<GlobalValue::GUID(const Function &)>
107+ GetGuidForFunction) {
96108 const bool IsVoid = CB->getType () == Type::getVoidTy (CB->getContext ());
97109
98110 SmallVector<DominatorTree::UpdateType, 8 > DTUpdates;
@@ -115,7 +127,26 @@ static BasicBlock *expandToSwitch(CallBase *CB, const JumpTableTy &JT,
115127 IRBuilder<> BuilderTail (CB);
116128 PHINode *PHI =
117129 IsVoid ? nullptr : BuilderTail.CreatePHI (CB->getType (), JT.Funcs .size ());
118-
130+ const auto *ProfMD = CB->getMetadata (LLVMContext::MD_prof);
131+
132+ SmallVector<uint64_t > BranchWeights;
133+ DenseMap<GlobalValue::GUID, uint64_t > GuidToCounter;
134+ const bool HadProfile = isValueProfileMD (ProfMD);
135+ if (HadProfile) {
136+ BranchWeights.reserve (JT.Funcs .size () + 1 );
137+ // The first is the default target, which is the unreachable block above.
138+ BranchWeights.push_back (0U );
139+ uint64_t TotalCount = 0 ;
140+ auto Targets = getValueProfDataFromInst (
141+ *CB, InstrProfValueKind::IPVK_IndirectCallTarget,
142+ std::numeric_limits<uint32_t >::max (), TotalCount);
143+
144+ for (const auto &[G, C] : Targets) {
145+ auto It = GuidToCounter.insert ({G, C});
146+ if (!It.second )
147+ It.second += C; // Unexpected, but likely the right way forward
148+ }
149+ }
119150 for (auto [Index, Func] : llvm::enumerate (JT.Funcs )) {
120151 BasicBlock *B = BasicBlock::Create (Func->getContext (),
121152 " call." + Twine (Index), &F, Tail);
@@ -127,6 +158,9 @@ static BasicBlock *expandToSwitch(CallBase *CB, const JumpTableTy &JT,
127158 Call->insertInto (B, B->end ());
128159 Switch->addCase (
129160 cast<ConstantInt>(ConstantInt::get (JT.Index ->getType (), Index)), B);
161+ GlobalValue::GUID FctID = GetGuidForFunction (*Func);
162+ BranchWeights.push_back (FctID == 0U ? 0U
163+ : GuidToCounter.lookup_or (FctID, 0U ));
130164 BranchInst::Create (Tail, B);
131165 if (PHI)
132166 PHI->addIncoming (Call, B);
@@ -136,6 +170,11 @@ static BasicBlock *expandToSwitch(CallBase *CB, const JumpTableTy &JT,
136170 return OptimizationRemark (DEBUG_TYPE, " ReplacedJumpTableWithSwitch" , CB)
137171 << " expanded indirect call into switch" ;
138172 });
173+ if (HadProfile)
174+ setProfMetadata (F.getParent (), Switch, BranchWeights,
175+ *llvm::max_element (BranchWeights));
176+ else
177+ setExplicitlyUnknownBranchWeights (*Switch);
139178 if (PHI)
140179 CB->replaceAllUsesWith (PHI);
141180 CB->eraseFromParent ();
@@ -150,6 +189,15 @@ PreservedAnalyses JumpTableToSwitchPass::run(Function &F,
150189 PostDominatorTree *PDT = AM.getCachedResult <PostDominatorTreeAnalysis>(F);
151190 DomTreeUpdater DTU (DT, PDT, DomTreeUpdater::UpdateStrategy::Lazy);
152191 bool Changed = false ;
192+ InstrProfSymtab Symtab;
193+ if (auto E = Symtab.create (*F.getParent ()))
194+ F.getContext ().emitError (
195+ " Could not create indirect call table, likely corrupted IR" +
196+ toString (std::move (E)));
197+ DenseMap<const Function *, GlobalValue::GUID> FToGuid;
198+ for (const auto &[G, FPtr] : Symtab.getIDToNameMap ())
199+ FToGuid.insert ({FPtr, G});
200+
153201 for (BasicBlock &BB : make_early_inc_range (F)) {
154202 BasicBlock *CurrentBB = &BB;
155203 while (CurrentBB) {
@@ -170,7 +218,12 @@ PreservedAnalyses JumpTableToSwitchPass::run(Function &F,
170218 std::optional<JumpTableTy> JumpTable = parseJumpTable (GEP, PtrTy);
171219 if (!JumpTable)
172220 continue ;
173- SplittedOutTail = expandToSwitch (Call, *JumpTable, DTU, ORE);
221+ SplittedOutTail = expandToSwitch (
222+ Call, *JumpTable, DTU, ORE, [&](const Function &Fct) {
223+ if (Fct.getMetadata (AssignGUIDPass::GUIDMetadataName))
224+ return AssignGUIDPass::getGUID (Fct);
225+ return FToGuid.lookup_or (&Fct, 0U );
226+ });
174227 Changed = true ;
175228 break ;
176229 }
0 commit comments