Skip to content

Commit 6a998c8

Browse files
committed
[JTS] Propagate profile info
1 parent 09505b1 commit 6a998c8

File tree

3 files changed

+68
-8
lines changed

3 files changed

+68
-8
lines changed

llvm/include/llvm/ProfileData/InstrProf.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -660,6 +660,10 @@ class InstrProfSymtab {
660660
return Error::success();
661661
}
662662

663+
const std::vector<std::pair<uint64_t, Function *>> &getIDToNameMap() const {
664+
return MD5FuncMap;
665+
}
666+
663667
const StringSet<> &getVTableNames() const { return VTableNames; }
664668

665669
/// Map a function address to its name's MD5 hash. This interface

llvm/lib/Transforms/Scalar/JumpTableToSwitch.cpp

Lines changed: 58 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,24 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "llvm/Transforms/Scalar/JumpTableToSwitch.h"
10+
#include "llvm/ADT/DenseSet.h"
11+
#include "llvm/ADT/STLExtras.h"
12+
#include "llvm/ADT/SmallSet.h"
1013
#include "llvm/ADT/SmallVector.h"
1114
#include "llvm/Analysis/ConstantFolding.h"
15+
#include "llvm/Analysis/CtxProfAnalysis.h"
1216
#include "llvm/Analysis/DomTreeUpdater.h"
1317
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
1418
#include "llvm/Analysis/PostDominators.h"
1519
#include "llvm/IR/IRBuilder.h"
20+
#include "llvm/IR/LLVMContext.h"
21+
#include "llvm/IR/ProfDataUtils.h"
22+
#include "llvm/ProfileData/InstrProf.h"
1623
#include "llvm/Support/CommandLine.h"
24+
#include "llvm/Support/Error.h"
25+
#include "llvm/Transforms/Instrumentation/PGOInstrumentation.h"
1726
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
27+
#include <limits>
1828

1929
using namespace llvm;
2030

@@ -90,9 +100,11 @@ static std::optional<JumpTableTy> parseJumpTable(GetElementPtrInst *GEP,
90100
return JumpTable;
91101
}
92102

93-
static BasicBlock *expandToSwitch(CallBase *CB, const JumpTableTy &JT,
94-
DomTreeUpdater &DTU,
95-
OptimizationRemarkEmitter &ORE) {
103+
static BasicBlock *
104+
expandToSwitch(CallBase *CB, const JumpTableTy &JT, DomTreeUpdater &DTU,
105+
OptimizationRemarkEmitter &ORE,
106+
llvm::function_ref<GlobalValue::GUID(const Function &)>
107+
GetGuidForFunction) {
96108
const bool IsVoid = CB->getType() == Type::getVoidTy(CB->getContext());
97109

98110
SmallVector<DominatorTree::UpdateType, 8> DTUpdates;
@@ -115,7 +127,26 @@ static BasicBlock *expandToSwitch(CallBase *CB, const JumpTableTy &JT,
115127
IRBuilder<> BuilderTail(CB);
116128
PHINode *PHI =
117129
IsVoid ? nullptr : BuilderTail.CreatePHI(CB->getType(), JT.Funcs.size());
118-
130+
const auto *ProfMD = CB->getMetadata(LLVMContext::MD_prof);
131+
132+
SmallVector<uint64_t> BranchWeights;
133+
DenseMap<GlobalValue::GUID, uint64_t> GuidToCounter;
134+
const bool HadProfile = isValueProfileMD(ProfMD);
135+
if (HadProfile) {
136+
BranchWeights.reserve(JT.Funcs.size() + 1);
137+
// The first is the default target, which is the unreachable block above.
138+
BranchWeights.push_back(0U);
139+
uint64_t TotalCount = 0;
140+
auto Targets = getValueProfDataFromInst(
141+
*CB, InstrProfValueKind::IPVK_IndirectCallTarget,
142+
std::numeric_limits<uint32_t>::max(), TotalCount);
143+
144+
for (const auto &[G, C] : Targets) {
145+
auto It = GuidToCounter.insert({G, C});
146+
if (!It.second)
147+
It.second += C; // Unexpected, but likely the right way forward
148+
}
149+
}
119150
for (auto [Index, Func] : llvm::enumerate(JT.Funcs)) {
120151
BasicBlock *B = BasicBlock::Create(Func->getContext(),
121152
"call." + Twine(Index), &F, Tail);
@@ -127,6 +158,9 @@ static BasicBlock *expandToSwitch(CallBase *CB, const JumpTableTy &JT,
127158
Call->insertInto(B, B->end());
128159
Switch->addCase(
129160
cast<ConstantInt>(ConstantInt::get(JT.Index->getType(), Index)), B);
161+
GlobalValue::GUID FctID = GetGuidForFunction(*Func);
162+
BranchWeights.push_back(FctID == 0U ? 0U
163+
: GuidToCounter.lookup_or(FctID, 0U));
130164
BranchInst::Create(Tail, B);
131165
if (PHI)
132166
PHI->addIncoming(Call, B);
@@ -136,6 +170,11 @@ static BasicBlock *expandToSwitch(CallBase *CB, const JumpTableTy &JT,
136170
return OptimizationRemark(DEBUG_TYPE, "ReplacedJumpTableWithSwitch", CB)
137171
<< "expanded indirect call into switch";
138172
});
173+
if (HadProfile)
174+
setProfMetadata(F.getParent(), Switch, BranchWeights,
175+
*llvm::max_element(BranchWeights));
176+
else
177+
setExplicitlyUnknownBranchWeights(*Switch);
139178
if (PHI)
140179
CB->replaceAllUsesWith(PHI);
141180
CB->eraseFromParent();
@@ -150,6 +189,15 @@ PreservedAnalyses JumpTableToSwitchPass::run(Function &F,
150189
PostDominatorTree *PDT = AM.getCachedResult<PostDominatorTreeAnalysis>(F);
151190
DomTreeUpdater DTU(DT, PDT, DomTreeUpdater::UpdateStrategy::Lazy);
152191
bool Changed = false;
192+
InstrProfSymtab Symtab;
193+
if (auto E = Symtab.create(*F.getParent()))
194+
F.getContext().emitError(
195+
"Could not create indirect call table, likely corrupted IR" +
196+
toString(std::move(E)));
197+
DenseMap<const Function *, GlobalValue::GUID> FToGuid;
198+
for (const auto &[G, FPtr] : Symtab.getIDToNameMap())
199+
FToGuid.insert({FPtr, G});
200+
153201
for (BasicBlock &BB : make_early_inc_range(F)) {
154202
BasicBlock *CurrentBB = &BB;
155203
while (CurrentBB) {
@@ -170,7 +218,12 @@ PreservedAnalyses JumpTableToSwitchPass::run(Function &F,
170218
std::optional<JumpTableTy> JumpTable = parseJumpTable(GEP, PtrTy);
171219
if (!JumpTable)
172220
continue;
173-
SplittedOutTail = expandToSwitch(Call, *JumpTable, DTU, ORE);
221+
SplittedOutTail = expandToSwitch(
222+
Call, *JumpTable, DTU, ORE, [&](const Function &Fct) {
223+
if (Fct.getMetadata(AssignGUIDPass::GUIDMetadataName))
224+
return AssignGUIDPass::getGUID(Fct);
225+
return FToGuid.lookup_or(&Fct, 0U);
226+
});
174227
Changed = true;
175228
break;
176229
}

llvm/test/Transforms/JumpTableToSwitch/basic.ll

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@
44

55
@func_array = constant [2 x ptr] [ptr @func0, ptr @func1]
66

7-
define i32 @func0() {
7+
define i32 @func0() !guid !0 {
88
ret i32 1
99
}
1010

11-
define i32 @func1() {
11+
define i32 @func1() !guid !1 {
1212
ret i32 2
1313
}
1414

@@ -42,7 +42,7 @@ define i32 @function_with_jump_table(i32 %index) {
4242
;
4343
%gep = getelementptr inbounds [2 x ptr], ptr @func_array, i32 0, i32 %index
4444
%func_ptr = load ptr, ptr %gep
45-
%result = call i32 %func_ptr()
45+
%result = call i32 %func_ptr(), !prof !2
4646
ret i32 %result
4747
}
4848

@@ -226,3 +226,6 @@ define i32 @function_with_jump_table_addrspace_42(i32 %index) addrspace(42) {
226226
ret i32 %result
227227
}
228228

229+
!0 = !{i64 5678}
230+
!1 = !{i64 5555}
231+
!2 = !{!"VP", i32 0, i64 25, i64 5678, i64 20, i64 5555, i64 5}

0 commit comments

Comments
 (0)