diff --git a/llvm/include/llvm/ProfileData/InstrProf.h b/llvm/include/llvm/ProfileData/InstrProf.h index bab1963dba22e..85a9efe73855b 100644 --- a/llvm/include/llvm/ProfileData/InstrProf.h +++ b/llvm/include/llvm/ProfileData/InstrProf.h @@ -665,6 +665,10 @@ class InstrProfSymtab { return Error::success(); } + const std::vector> &getIDToNameMap() const { + return MD5FuncMap; + } + const StringSet<> &getVTableNames() const { return VTableNames; } /// Map a function address to its name's MD5 hash. This interface diff --git a/llvm/lib/Transforms/Scalar/JumpTableToSwitch.cpp b/llvm/lib/Transforms/Scalar/JumpTableToSwitch.cpp index 7f99cd2060a9d..6719ce64b96b6 100644 --- a/llvm/lib/Transforms/Scalar/JumpTableToSwitch.cpp +++ b/llvm/lib/Transforms/Scalar/JumpTableToSwitch.cpp @@ -7,14 +7,24 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Scalar/JumpTableToSwitch.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Analysis/ConstantFolding.h" +#include "llvm/Analysis/CtxProfAnalysis.h" #include "llvm/Analysis/DomTreeUpdater.h" #include "llvm/Analysis/OptimizationRemarkEmitter.h" #include "llvm/Analysis/PostDominators.h" #include "llvm/IR/IRBuilder.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/ProfDataUtils.h" +#include "llvm/ProfileData/InstrProf.h" #include "llvm/Support/CommandLine.h" +#include "llvm/Support/Error.h" +#include "llvm/Transforms/Instrumentation/PGOInstrumentation.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include using namespace llvm; @@ -33,6 +43,8 @@ static cl::opt FunctionSizeThreshold( "or equal than this threshold."), cl::init(50)); +extern cl::opt ProfcheckDisableMetadataFixes; + #define DEBUG_TYPE "jump-table-to-switch" namespace { @@ -90,9 +102,11 @@ static std::optional parseJumpTable(GetElementPtrInst *GEP, return JumpTable; } -static BasicBlock *expandToSwitch(CallBase *CB, const JumpTableTy &JT, - DomTreeUpdater &DTU, - OptimizationRemarkEmitter &ORE) { +static BasicBlock * +expandToSwitch(CallBase *CB, const JumpTableTy &JT, DomTreeUpdater &DTU, + OptimizationRemarkEmitter &ORE, + llvm::function_ref + GetGuidForFunction) { const bool IsVoid = CB->getType() == Type::getVoidTy(CB->getContext()); SmallVector DTUpdates; @@ -115,7 +129,31 @@ static BasicBlock *expandToSwitch(CallBase *CB, const JumpTableTy &JT, IRBuilder<> BuilderTail(CB); PHINode *PHI = IsVoid ? nullptr : BuilderTail.CreatePHI(CB->getType(), JT.Funcs.size()); - + const auto *ProfMD = CB->getMetadata(LLVMContext::MD_prof); + + SmallVector BranchWeights; + DenseMap GuidToCounter; + const bool HadProfile = isValueProfileMD(ProfMD); + if (HadProfile) { + // The assumptions, coming in, are that the functions in JT.Funcs are + // defined in this module (from parseJumpTable). + assert(llvm::all_of( + JT.Funcs, [](const Function *F) { return F && !F->isDeclaration(); })); + BranchWeights.reserve(JT.Funcs.size() + 1); + // The first is the default target, which is the unreachable block created + // above. + BranchWeights.push_back(0U); + uint64_t TotalCount = 0; + auto Targets = getValueProfDataFromInst( + *CB, InstrProfValueKind::IPVK_IndirectCallTarget, + std::numeric_limits::max(), TotalCount); + + for (const auto &[G, C] : Targets) { + auto It = GuidToCounter.insert({G, C}); + assert(It.second); + (void)It; + } + } for (auto [Index, Func] : llvm::enumerate(JT.Funcs)) { BasicBlock *B = BasicBlock::Create(Func->getContext(), "call." + Twine(Index), &F, Tail); @@ -127,6 +165,11 @@ static BasicBlock *expandToSwitch(CallBase *CB, const JumpTableTy &JT, Call->insertInto(B, B->end()); Switch->addCase( cast(ConstantInt::get(JT.Index->getType(), Index)), B); + GlobalValue::GUID FctID = GetGuidForFunction(*Func); + // It'd be OK to _not_ find target functions in GuidToCounter, e.g. suppose + // just some of the jump targets are taken (for the given profile). + BranchWeights.push_back(FctID == 0U ? 0U + : GuidToCounter.lookup_or(FctID, 0U)); BranchInst::Create(Tail, B); if (PHI) PHI->addIncoming(Call, B); @@ -136,6 +179,13 @@ static BasicBlock *expandToSwitch(CallBase *CB, const JumpTableTy &JT, return OptimizationRemark(DEBUG_TYPE, "ReplacedJumpTableWithSwitch", CB) << "expanded indirect call into switch"; }); + if (HadProfile && !ProfcheckDisableMetadataFixes) { + // At least one of the targets must've been taken. + assert(llvm::any_of(BranchWeights, [](uint64_t V) { return V != 0; })); + setProfMetadata(F.getParent(), Switch, BranchWeights, + *llvm::max_element(BranchWeights)); + } else + setExplicitlyUnknownBranchWeights(*Switch); if (PHI) CB->replaceAllUsesWith(PHI); CB->eraseFromParent(); @@ -150,6 +200,15 @@ PreservedAnalyses JumpTableToSwitchPass::run(Function &F, PostDominatorTree *PDT = AM.getCachedResult(F); DomTreeUpdater DTU(DT, PDT, DomTreeUpdater::UpdateStrategy::Lazy); bool Changed = false; + InstrProfSymtab Symtab; + if (auto E = Symtab.create(*F.getParent())) + F.getContext().emitError( + "Could not create indirect call table, likely corrupted IR" + + toString(std::move(E))); + DenseMap FToGuid; + for (const auto &[G, FPtr] : Symtab.getIDToNameMap()) + FToGuid.insert({FPtr, G}); + for (BasicBlock &BB : make_early_inc_range(F)) { BasicBlock *CurrentBB = &BB; while (CurrentBB) { @@ -170,7 +229,12 @@ PreservedAnalyses JumpTableToSwitchPass::run(Function &F, std::optional JumpTable = parseJumpTable(GEP, PtrTy); if (!JumpTable) continue; - SplittedOutTail = expandToSwitch(Call, *JumpTable, DTU, ORE); + SplittedOutTail = expandToSwitch( + Call, *JumpTable, DTU, ORE, [&](const Function &Fct) { + if (Fct.getMetadata(AssignGUIDPass::GUIDMetadataName)) + return AssignGUIDPass::getGUID(Fct); + return FToGuid.lookup_or(&Fct, 0U); + }); Changed = true; break; } diff --git a/llvm/test/Transforms/JumpTableToSwitch/basic.ll b/llvm/test/Transforms/JumpTableToSwitch/basic.ll index 321f837077ab6..577c2adaf5afa 100644 --- a/llvm/test/Transforms/JumpTableToSwitch/basic.ll +++ b/llvm/test/Transforms/JumpTableToSwitch/basic.ll @@ -4,11 +4,11 @@ @func_array = constant [2 x ptr] [ptr @func0, ptr @func1] -define i32 @func0() { +define i32 @func0() !guid !0 { ret i32 1 } -define i32 @func1() { +define i32 @func1() !guid !1 { ret i32 2 } @@ -42,7 +42,7 @@ define i32 @function_with_jump_table(i32 %index) { ; %gep = getelementptr inbounds [2 x ptr], ptr @func_array, i32 0, i32 %index %func_ptr = load ptr, ptr %gep - %result = call i32 %func_ptr() + %result = call i32 %func_ptr(), !prof !2 ret i32 %result } @@ -226,3 +226,6 @@ define i32 @function_with_jump_table_addrspace_42(i32 %index) addrspace(42) { ret i32 %result } +!0 = !{i64 5678} +!1 = !{i64 5555} +!2 = !{!"VP", i32 0, i64 25, i64 5678, i64 20, i64 5555, i64 5} \ No newline at end of file