Skip to content

Commit f5d2843

Browse files
authored
[JTS] Propagate profile info (#153305)
If the indirect call that is recognized as a jump-table dispatch carries profile info, we can accurately synthesize the branch weights of the switch that replaces it. Otherwise, we attach an explicitly "unknown" `MD_prof` to the switch to indicate that this is the best we can do here. Part of Issue #147390.
1 parent c202d2f commit f5d2843

File tree

3 files changed

+79
-8
lines changed

3 files changed

+79
-8
lines changed

llvm/include/llvm/ProfileData/InstrProf.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -665,6 +665,10 @@ class InstrProfSymtab {
665665
return Error::success();
666666
}
667667

668+
/// Read-only accessor: returns MD5FuncMap by const reference, without copying.
/// NOTE(review): despite the "IDToName" in the accessor's name, each element
/// pairs a uint64_t with a Function *, not with a name string.
const std::vector<std::pair<uint64_t, Function *>> &getIDToNameMap() const {
669+
return MD5FuncMap;
670+
}
671+
668672
/// Read-only accessor: returns the VTableNames string set by const reference.
const StringSet<> &getVTableNames() const { return VTableNames; }
669673

670674
/// Map a function address to its name's MD5 hash. This interface

llvm/lib/Transforms/Scalar/JumpTableToSwitch.cpp

Lines changed: 69 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,24 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "llvm/Transforms/Scalar/JumpTableToSwitch.h"
10+
#include "llvm/ADT/DenseSet.h"
11+
#include "llvm/ADT/STLExtras.h"
12+
#include "llvm/ADT/SmallSet.h"
1013
#include "llvm/ADT/SmallVector.h"
1114
#include "llvm/Analysis/ConstantFolding.h"
15+
#include "llvm/Analysis/CtxProfAnalysis.h"
1216
#include "llvm/Analysis/DomTreeUpdater.h"
1317
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
1418
#include "llvm/Analysis/PostDominators.h"
1519
#include "llvm/IR/IRBuilder.h"
20+
#include "llvm/IR/LLVMContext.h"
21+
#include "llvm/IR/ProfDataUtils.h"
22+
#include "llvm/ProfileData/InstrProf.h"
1623
#include "llvm/Support/CommandLine.h"
24+
#include "llvm/Support/Error.h"
25+
#include "llvm/Transforms/Instrumentation/PGOInstrumentation.h"
1726
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
27+
#include <limits>
1828

1929
using namespace llvm;
2030

@@ -33,6 +43,8 @@ static cl::opt<unsigned> FunctionSizeThreshold(
3343
"or equal than this threshold."),
3444
cl::init(50));
3545

46+
extern cl::opt<bool> ProfcheckDisableMetadataFixes;
47+
3648
#define DEBUG_TYPE "jump-table-to-switch"
3749

3850
namespace {
@@ -90,9 +102,11 @@ static std::optional<JumpTableTy> parseJumpTable(GetElementPtrInst *GEP,
90102
return JumpTable;
91103
}
92104

93-
static BasicBlock *expandToSwitch(CallBase *CB, const JumpTableTy &JT,
94-
DomTreeUpdater &DTU,
95-
OptimizationRemarkEmitter &ORE) {
105+
static BasicBlock *
106+
expandToSwitch(CallBase *CB, const JumpTableTy &JT, DomTreeUpdater &DTU,
107+
OptimizationRemarkEmitter &ORE,
108+
llvm::function_ref<GlobalValue::GUID(const Function &)>
109+
GetGuidForFunction) {
96110
const bool IsVoid = CB->getType() == Type::getVoidTy(CB->getContext());
97111

98112
SmallVector<DominatorTree::UpdateType, 8> DTUpdates;
@@ -115,7 +129,31 @@ static BasicBlock *expandToSwitch(CallBase *CB, const JumpTableTy &JT,
115129
IRBuilder<> BuilderTail(CB);
116130
PHINode *PHI =
117131
IsVoid ? nullptr : BuilderTail.CreatePHI(CB->getType(), JT.Funcs.size());
118-
132+
const auto *ProfMD = CB->getMetadata(LLVMContext::MD_prof);
133+
134+
SmallVector<uint64_t> BranchWeights;
135+
DenseMap<GlobalValue::GUID, uint64_t> GuidToCounter;
136+
const bool HadProfile = isValueProfileMD(ProfMD);
137+
if (HadProfile) {
138+
// The assumptions, coming in, are that the functions in JT.Funcs are
139+
// defined in this module (from parseJumpTable).
140+
assert(llvm::all_of(
141+
JT.Funcs, [](const Function *F) { return F && !F->isDeclaration(); }));
142+
BranchWeights.reserve(JT.Funcs.size() + 1);
143+
// The first is the default target, which is the unreachable block created
144+
// above.
145+
BranchWeights.push_back(0U);
146+
uint64_t TotalCount = 0;
147+
auto Targets = getValueProfDataFromInst(
148+
*CB, InstrProfValueKind::IPVK_IndirectCallTarget,
149+
std::numeric_limits<uint32_t>::max(), TotalCount);
150+
151+
for (const auto &[G, C] : Targets) {
152+
auto It = GuidToCounter.insert({G, C});
153+
assert(It.second);
154+
(void)It;
155+
}
156+
}
119157
for (auto [Index, Func] : llvm::enumerate(JT.Funcs)) {
120158
BasicBlock *B = BasicBlock::Create(Func->getContext(),
121159
"call." + Twine(Index), &F, Tail);
@@ -127,6 +165,11 @@ static BasicBlock *expandToSwitch(CallBase *CB, const JumpTableTy &JT,
127165
Call->insertInto(B, B->end());
128166
Switch->addCase(
129167
cast<ConstantInt>(ConstantInt::get(JT.Index->getType(), Index)), B);
168+
GlobalValue::GUID FctID = GetGuidForFunction(*Func);
169+
// It'd be OK to _not_ find target functions in GuidToCounter, e.g. suppose
170+
// just some of the jump targets are taken (for the given profile).
171+
BranchWeights.push_back(FctID == 0U ? 0U
172+
: GuidToCounter.lookup_or(FctID, 0U));
130173
BranchInst::Create(Tail, B);
131174
if (PHI)
132175
PHI->addIncoming(Call, B);
@@ -136,6 +179,13 @@ static BasicBlock *expandToSwitch(CallBase *CB, const JumpTableTy &JT,
136179
return OptimizationRemark(DEBUG_TYPE, "ReplacedJumpTableWithSwitch", CB)
137180
<< "expanded indirect call into switch";
138181
});
182+
if (HadProfile && !ProfcheckDisableMetadataFixes) {
183+
// At least one of the targets must've been taken.
184+
assert(llvm::any_of(BranchWeights, [](uint64_t V) { return V != 0; }));
185+
setProfMetadata(F.getParent(), Switch, BranchWeights,
186+
*llvm::max_element(BranchWeights));
187+
} else
188+
setExplicitlyUnknownBranchWeights(*Switch);
139189
if (PHI)
140190
CB->replaceAllUsesWith(PHI);
141191
CB->eraseFromParent();
@@ -150,6 +200,15 @@ PreservedAnalyses JumpTableToSwitchPass::run(Function &F,
150200
PostDominatorTree *PDT = AM.getCachedResult<PostDominatorTreeAnalysis>(F);
151201
DomTreeUpdater DTU(DT, PDT, DomTreeUpdater::UpdateStrategy::Lazy);
152202
bool Changed = false;
203+
InstrProfSymtab Symtab;
204+
if (auto E = Symtab.create(*F.getParent()))
205+
F.getContext().emitError(
206+
"Could not create indirect call table, likely corrupted IR" +
207+
toString(std::move(E)));
208+
DenseMap<const Function *, GlobalValue::GUID> FToGuid;
209+
for (const auto &[G, FPtr] : Symtab.getIDToNameMap())
210+
FToGuid.insert({FPtr, G});
211+
153212
for (BasicBlock &BB : make_early_inc_range(F)) {
154213
BasicBlock *CurrentBB = &BB;
155214
while (CurrentBB) {
@@ -170,7 +229,12 @@ PreservedAnalyses JumpTableToSwitchPass::run(Function &F,
170229
std::optional<JumpTableTy> JumpTable = parseJumpTable(GEP, PtrTy);
171230
if (!JumpTable)
172231
continue;
173-
SplittedOutTail = expandToSwitch(Call, *JumpTable, DTU, ORE);
232+
SplittedOutTail = expandToSwitch(
233+
Call, *JumpTable, DTU, ORE, [&](const Function &Fct) {
234+
if (Fct.getMetadata(AssignGUIDPass::GUIDMetadataName))
235+
return AssignGUIDPass::getGUID(Fct);
236+
return FToGuid.lookup_or(&Fct, 0U);
237+
});
174238
Changed = true;
175239
break;
176240
}

llvm/test/Transforms/JumpTableToSwitch/basic.ll

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@
44

55
@func_array = constant [2 x ptr] [ptr @func0, ptr @func1]
66

7-
define i32 @func0() {
7+
define i32 @func0() !guid !0 {
88
ret i32 1
99
}
1010

11-
define i32 @func1() {
11+
define i32 @func1() !guid !1 {
1212
ret i32 2
1313
}
1414

@@ -42,7 +42,7 @@ define i32 @function_with_jump_table(i32 %index) {
4242
;
4343
%gep = getelementptr inbounds [2 x ptr], ptr @func_array, i32 0, i32 %index
4444
%func_ptr = load ptr, ptr %gep
45-
%result = call i32 %func_ptr()
45+
%result = call i32 %func_ptr(), !prof !2
4646
ret i32 %result
4747
}
4848

@@ -226,3 +226,6 @@ define i32 @function_with_jump_table_addrspace_42(i32 %index) addrspace(42) {
226226
ret i32 %result
227227
}
228228

229+
!0 = !{i64 5678}
230+
!1 = !{i64 5555}
231+
!2 = !{!"VP", i32 0, i64 25, i64 5678, i64 20, i64 5555, i64 5}

0 commit comments

Comments
 (0)