77// ===----------------------------------------------------------------------===//
88
99#include " llvm/Transforms/Scalar/JumpTableToSwitch.h"
10+ #include " llvm/ADT/DenseSet.h"
11+ #include " llvm/ADT/STLExtras.h"
12+ #include " llvm/ADT/SmallSet.h"
1013#include " llvm/ADT/SmallVector.h"
1114#include " llvm/Analysis/ConstantFolding.h"
15+ #include " llvm/Analysis/CtxProfAnalysis.h"
1216#include " llvm/Analysis/DomTreeUpdater.h"
1317#include " llvm/Analysis/OptimizationRemarkEmitter.h"
1418#include " llvm/Analysis/PostDominators.h"
1519#include " llvm/IR/IRBuilder.h"
20+ #include " llvm/IR/LLVMContext.h"
21+ #include " llvm/IR/ProfDataUtils.h"
22+ #include " llvm/ProfileData/InstrProf.h"
1623#include " llvm/Support/CommandLine.h"
24+ #include " llvm/Support/Error.h"
25+ #include " llvm/Transforms/Instrumentation/PGOInstrumentation.h"
1726#include " llvm/Transforms/Utils/BasicBlockUtils.h"
27+ #include < limits>
1828
1929using namespace llvm ;
2030
@@ -33,6 +43,8 @@ static cl::opt<unsigned> FunctionSizeThreshold(
3343 " or equal than this threshold." ),
3444 cl::init(50 ));
3545
46+ extern cl::opt<bool > ProfcheckDisableMetadataFixes;
47+
3648#define DEBUG_TYPE " jump-table-to-switch"
3749
3850namespace {
@@ -90,9 +102,11 @@ static std::optional<JumpTableTy> parseJumpTable(GetElementPtrInst *GEP,
90102 return JumpTable;
91103}
92104
93- static BasicBlock *expandToSwitch (CallBase *CB, const JumpTableTy &JT,
94- DomTreeUpdater &DTU,
95- OptimizationRemarkEmitter &ORE) {
105+ static BasicBlock *
106+ expandToSwitch (CallBase *CB, const JumpTableTy &JT, DomTreeUpdater &DTU,
107+ OptimizationRemarkEmitter &ORE,
108+ llvm::function_ref<GlobalValue::GUID(const Function &)>
109+ GetGuidForFunction) {
96110 const bool IsVoid = CB->getType () == Type::getVoidTy (CB->getContext ());
97111
98112 SmallVector<DominatorTree::UpdateType, 8 > DTUpdates;
@@ -115,7 +129,31 @@ static BasicBlock *expandToSwitch(CallBase *CB, const JumpTableTy &JT,
115129 IRBuilder<> BuilderTail (CB);
116130 PHINode *PHI =
117131 IsVoid ? nullptr : BuilderTail.CreatePHI (CB->getType (), JT.Funcs .size ());
118-
132+ const auto *ProfMD = CB->getMetadata (LLVMContext::MD_prof);
133+
134+ SmallVector<uint64_t > BranchWeights;
135+ DenseMap<GlobalValue::GUID, uint64_t > GuidToCounter;
136+ const bool HadProfile = isValueProfileMD (ProfMD);
137+ if (HadProfile) {
138+ // The assumptions, coming in, are that the functions in JT.Funcs are
139+ // defined in this module (from parseJumpTable).
140+ assert (llvm::all_of (
141+ JT.Funcs , [](const Function *F) { return F && !F->isDeclaration (); }));
142+ BranchWeights.reserve (JT.Funcs .size () + 1 );
143+ // The first is the default target, which is the unreachable block created
144+ // above.
145+ BranchWeights.push_back (0U );
146+ uint64_t TotalCount = 0 ;
147+ auto Targets = getValueProfDataFromInst (
148+ *CB, InstrProfValueKind::IPVK_IndirectCallTarget,
149+ std::numeric_limits<uint32_t >::max (), TotalCount);
150+
151+ for (const auto &[G, C] : Targets) {
152+ auto It = GuidToCounter.insert ({G, C});
153+ assert (It.second );
154+ (void )It;
155+ }
156+ }
119157 for (auto [Index, Func] : llvm::enumerate (JT.Funcs )) {
120158 BasicBlock *B = BasicBlock::Create (Func->getContext (),
121159 " call." + Twine (Index), &F, Tail);
@@ -127,6 +165,11 @@ static BasicBlock *expandToSwitch(CallBase *CB, const JumpTableTy &JT,
127165 Call->insertInto (B, B->end ());
128166 Switch->addCase (
129167 cast<ConstantInt>(ConstantInt::get (JT.Index ->getType (), Index)), B);
168+ GlobalValue::GUID FctID = GetGuidForFunction (*Func);
169+ // It'd be OK to _not_ find target functions in GuidToCounter, e.g. suppose
170+ // just some of the jump targets are taken (for the given profile).
171+ BranchWeights.push_back (FctID == 0U ? 0U
172+ : GuidToCounter.lookup_or (FctID, 0U ));
130173 BranchInst::Create (Tail, B);
131174 if (PHI)
132175 PHI->addIncoming (Call, B);
@@ -136,6 +179,13 @@ static BasicBlock *expandToSwitch(CallBase *CB, const JumpTableTy &JT,
136179 return OptimizationRemark (DEBUG_TYPE, " ReplacedJumpTableWithSwitch" , CB)
137180 << " expanded indirect call into switch" ;
138181 });
182+ if (HadProfile && !ProfcheckDisableMetadataFixes) {
183+ // At least one of the targets must've been taken.
184+ assert (llvm::any_of (BranchWeights, [](uint64_t V) { return V != 0 ; }));
185+ setProfMetadata (F.getParent (), Switch, BranchWeights,
186+ *llvm::max_element (BranchWeights));
187+ } else
188+ setExplicitlyUnknownBranchWeights (*Switch);
139189 if (PHI)
140190 CB->replaceAllUsesWith (PHI);
141191 CB->eraseFromParent ();
@@ -150,6 +200,15 @@ PreservedAnalyses JumpTableToSwitchPass::run(Function &F,
150200 PostDominatorTree *PDT = AM.getCachedResult <PostDominatorTreeAnalysis>(F);
151201 DomTreeUpdater DTU (DT, PDT, DomTreeUpdater::UpdateStrategy::Lazy);
152202 bool Changed = false ;
203+ InstrProfSymtab Symtab;
204+ if (auto E = Symtab.create (*F.getParent ()))
205+ F.getContext ().emitError (
206+ " Could not create indirect call table, likely corrupted IR" +
207+ toString (std::move (E)));
208+ DenseMap<const Function *, GlobalValue::GUID> FToGuid;
209+ for (const auto &[G, FPtr] : Symtab.getIDToNameMap ())
210+ FToGuid.insert ({FPtr, G});
211+
153212 for (BasicBlock &BB : make_early_inc_range (F)) {
154213 BasicBlock *CurrentBB = &BB;
155214 while (CurrentBB) {
@@ -170,7 +229,12 @@ PreservedAnalyses JumpTableToSwitchPass::run(Function &F,
170229 std::optional<JumpTableTy> JumpTable = parseJumpTable (GEP, PtrTy);
171230 if (!JumpTable)
172231 continue ;
173- SplittedOutTail = expandToSwitch (Call, *JumpTable, DTU, ORE);
232+ SplittedOutTail = expandToSwitch (
233+ Call, *JumpTable, DTU, ORE, [&](const Function &Fct) {
234+ if (Fct.getMetadata (AssignGUIDPass::GUIDMetadataName))
235+ return AssignGUIDPass::getGUID (Fct);
236+ return FToGuid.lookup_or (&Fct, 0U );
237+ });
174238 Changed = true ;
175239 break ;
176240 }
0 commit comments