Skip to content

Commit 59546ec

Browse files
committed
[WPD] set the function entry count
1 parent a8680be commit 59546ec

File tree

5 files changed

+68
-27
lines changed

5 files changed

+68
-27
lines changed

llvm/include/llvm/IR/ProfDataUtils.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,7 @@ inline uint32_t scaleBranchCount(uint64_t Count, uint64_t Scale) {
179179
/// bugs where the pass forgets to transfer over or otherwise specify profile
180180
/// info.
181181
LLVM_ABI void setExplicitlyUnknownBranchWeights(Instruction &I);
182+
LLVM_ABI void setExplicitlyUnknownFunctionEntryCount(Function &I);
182183

183184
LLVM_ABI bool isExplicitlyUnknownBranchWeightsMetadata(const MDNode &MD);
184185
LLVM_ABI bool hasExplicitlyUnknownBranchWeights(const Instruction &I);

llvm/lib/IR/ProfDataUtils.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,14 @@ void setExplicitlyUnknownBranchWeights(Instruction &I) {
250250
MDB.createString(MDProfLabels::UnknownBranchWeightsMarker)));
251251
}
252252

253+
void setExplicitlyUnknownFunctionEntryCount(Function &F) {
254+
MDBuilder MDB(F.getContext());
255+
F.setMetadata(
256+
LLVMContext::MD_prof,
257+
MDNode::get(F.getContext(),
258+
MDB.createString(MDProfLabels::UnknownBranchWeightsMarker)));
259+
}
260+
253261
bool isExplicitlyUnknownBranchWeightsMetadata(const MDNode &MD) {
254262
if (MD.getNumOperands() != 1)
255263
return false;

llvm/lib/IR/Verifier.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2527,9 +2527,6 @@ void Verifier::verifyFunctionMetadata(
25272527
if (Pair.first == LLVMContext::MD_prof) {
25282528
MDNode *MD = Pair.second;
25292529
if (isExplicitlyUnknownBranchWeightsMetadata(*MD)) {
2530-
CheckFailed("'unknown' !prof metadata should appear only on "
2531-
"instructions supporting the 'branch_weights' metadata",
2532-
MD);
25332530
continue;
25342531
}
25352532
Check(MD->getNumOperands() >= 2,

llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp

Lines changed: 39 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
#include "llvm/ADT/Statistic.h"
6161
#include "llvm/Analysis/AssumptionCache.h"
6262
#include "llvm/Analysis/BasicAliasAnalysis.h"
63+
#include "llvm/Analysis/BlockFrequencyInfo.h"
6364
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
6465
#include "llvm/Analysis/TypeMetadataUtils.h"
6566
#include "llvm/Bitcode/BitcodeReader.h"
@@ -84,6 +85,7 @@
8485
#include "llvm/IR/Module.h"
8586
#include "llvm/IR/ModuleSummaryIndexYAML.h"
8687
#include "llvm/IR/PassManager.h"
88+
#include "llvm/IR/ProfDataUtils.h"
8789
#include "llvm/Support/Casting.h"
8890
#include "llvm/Support/CommandLine.h"
8991
#include "llvm/Support/Errc.h"
@@ -97,6 +99,7 @@
9799
#include "llvm/Transforms/Utils/CallPromotionUtils.h"
98100
#include "llvm/Transforms/Utils/Evaluator.h"
99101
#include <algorithm>
102+
#include <cmath>
100103
#include <cstddef>
101104
#include <map>
102105
#include <set>
@@ -169,6 +172,8 @@ static cl::list<std::string>
169172
cl::desc("Prevent function(s) from being devirtualized"),
170173
cl::Hidden, cl::CommaSeparated);
171174

175+
extern cl::opt<bool> ProfcheckDisableMetadataFixes;
176+
172177
/// With Clang, a pure virtual class's deleting destructor is emitted as a
173178
/// `llvm.trap` intrinsic followed by an unreachable IR instruction. In the
174179
/// context of whole program devirtualization, the deleting destructor of a pure
@@ -656,7 +661,7 @@ struct DevirtModule {
656661
VTableSlotInfo &SlotInfo,
657662
WholeProgramDevirtResolution *Res);
658663

659-
void applyICallBranchFunnel(VTableSlotInfo &SlotInfo, Constant *JT,
664+
void applyICallBranchFunnel(VTableSlotInfo &SlotInfo, Function &JT,
660665
bool &IsExported);
661666
void tryICallBranchFunnel(MutableArrayRef<VirtualCallTarget> TargetsForSlot,
662667
VTableSlotInfo &SlotInfo,
@@ -1453,7 +1458,7 @@ void DevirtModule::tryICallBranchFunnel(
14531458

14541459
FunctionType *FT =
14551460
FunctionType::get(Type::getVoidTy(M.getContext()), {Int8PtrTy}, true);
1456-
Function *JT;
1461+
Function *JT = nullptr;
14571462
if (isa<MDString>(Slot.TypeID)) {
14581463
JT = Function::Create(FT, Function::ExternalLinkage,
14591464
M.getDataLayout().getProgramAddressSpace(),
@@ -1482,13 +1487,18 @@ void DevirtModule::tryICallBranchFunnel(
14821487
ReturnInst::Create(M.getContext(), nullptr, BB);
14831488

14841489
bool IsExported = false;
1485-
applyICallBranchFunnel(SlotInfo, JT, IsExported);
1490+
applyICallBranchFunnel(SlotInfo, *JT, IsExported);
14861491
if (IsExported)
14871492
Res->TheKind = WholeProgramDevirtResolution::BranchFunnel;
1493+
1494+
if (!JT->getEntryCount().has_value()) {
1495+
setExplicitlyUnknownFunctionEntryCount(*JT);
1496+
}
14881497
}
14891498

14901499
void DevirtModule::applyICallBranchFunnel(VTableSlotInfo &SlotInfo,
1491-
Constant *JT, bool &IsExported) {
1500+
Function &JT, bool &IsExported) {
1501+
DenseMap<Function *, double> FunctionEntryCounts;
14921502
auto Apply = [&](CallSiteInfo &CSInfo) {
14931503
if (CSInfo.isExported())
14941504
IsExported = true;
@@ -1517,7 +1527,7 @@ void DevirtModule::applyICallBranchFunnel(VTableSlotInfo &SlotInfo,
15171527
NumBranchFunnel++;
15181528
if (RemarksEnabled)
15191529
VCallSite.emitRemark("branch-funnel",
1520-
JT->stripPointerCasts()->getName(), OREGetter);
1530+
JT.stripPointerCasts()->getName(), OREGetter);
15211531

15221532
// Pass the address of the vtable in the nest register, which is r10 on
15231533
// x86_64.
@@ -1533,11 +1543,26 @@ void DevirtModule::applyICallBranchFunnel(VTableSlotInfo &SlotInfo,
15331543
llvm::append_range(Args, CB.args());
15341544

15351545
CallBase *NewCS = nullptr;
1546+
if (!JT.isDeclaration() && !ProfcheckDisableMetadataFixes) {
1547+
auto &F = *CB.getCaller();
1548+
auto &BFI = FAM.getResult<BlockFrequencyAnalysis>(F);
1549+
auto EC = BFI.getBlockFreq(&F.getEntryBlock());
1550+
auto CC = F.getEntryCount(/*AllowSynthetic=*/true);
1551+
double CallCount = 0.0;
1552+
if (EC.getFrequency() != 0 && CC && CC->getCount() != 0) {
1553+
double CallFreq =
1554+
static_cast<double>(
1555+
BFI.getBlockFreq(CB.getParent()).getFrequency()) /
1556+
EC.getFrequency();
1557+
CallCount = CallFreq * CC->getCount();
1558+
}
1559+
FunctionEntryCounts[&JT] += CallCount;
1560+
}
15361561
if (isa<CallInst>(CB))
1537-
NewCS = IRB.CreateCall(NewFT, JT, Args);
1562+
NewCS = IRB.CreateCall(NewFT, &JT, Args);
15381563
else
15391564
NewCS =
1540-
IRB.CreateInvoke(NewFT, JT, cast<InvokeInst>(CB).getNormalDest(),
1565+
IRB.CreateInvoke(NewFT, &JT, cast<InvokeInst>(CB).getNormalDest(),
15411566
cast<InvokeInst>(CB).getUnwindDest(), Args);
15421567
NewCS->setCallingConv(CB.getCallingConv());
15431568

@@ -1571,6 +1596,11 @@ void DevirtModule::applyICallBranchFunnel(VTableSlotInfo &SlotInfo,
15711596
Apply(SlotInfo.CSInfo);
15721597
for (auto &P : SlotInfo.ConstCSInfo)
15731598
Apply(P.second);
1599+
for (auto &[F, C] : FunctionEntryCounts) {
1600+
assert(!F->getEntryCount(/*AllowSynthetic=*/true) &&
1601+
"Unexpected entry count for funnel that was freshly synthesized");
1602+
F->setEntryCount(static_cast<uint64_t>(std::round(C)));
1603+
}
15741604
}
15751605

15761606
bool DevirtModule::tryEvaluateFunctionsWithArgs(
@@ -2244,12 +2274,12 @@ void DevirtModule::importResolution(VTableSlot Slot, VTableSlotInfo &SlotInfo) {
22442274
if (Res.TheKind == WholeProgramDevirtResolution::BranchFunnel) {
22452275
// The type of the function is irrelevant, because it's bitcast at calls
22462276
// anyhow.
2247-
Constant *JT = cast<Constant>(
2277+
auto *JT = cast<Function>(
22482278
M.getOrInsertFunction(getGlobalName(Slot, {}, "branch_funnel"),
22492279
Type::getVoidTy(M.getContext()))
22502280
.getCallee());
22512281
bool IsExported = false;
2252-
applyICallBranchFunnel(SlotInfo, JT, IsExported);
2282+
applyICallBranchFunnel(SlotInfo, *JT, IsExported);
22532283
assert(!IsExported);
22542284
}
22552285
}

llvm/test/Transforms/WholeProgramDevirt/branch-funnel.ll

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
; RUN: opt -passes=wholeprogramdevirt -whole-program-visibility -wholeprogramdevirt-summary-action=export -wholeprogramdevirt-read-summary=%S/Inputs/export.yaml -wholeprogramdevirt-write-summary=%t -S -o - %s | FileCheck --check-prefixes=CHECK,RETP %s
55

6-
; RUN: opt -passes='wholeprogramdevirt,default<O3>' -whole-program-visibility -wholeprogramdevirt-summary-action=export -wholeprogramdevirt-read-summary=%S/Inputs/export.yaml -wholeprogramdevirt-write-summary=%t -S -o - %s | FileCheck --check-prefixes=CHECK %s
6+
; RUN: opt -passes='wholeprogramdevirt,default<O3>' -whole-program-visibility -wholeprogramdevirt-summary-action=export -wholeprogramdevirt-read-summary=%S/Inputs/export.yaml -wholeprogramdevirt-write-summary=%t -S -o - %s | FileCheck --check-prefixes=CHECK,O3 %s
77

88
; RUN: FileCheck --check-prefix=SUMMARY %s < %t
99

@@ -159,7 +159,7 @@ declare ptr @llvm.load.relative.i32(ptr, i32)
159159

160160
; CHECK-LABEL: define i32 @fn1
161161
; CHECK-NOT: call void (...) @llvm.icall.branch.funnel
162-
define i32 @fn1(ptr %obj) #0 {
162+
define i32 @fn1(ptr %obj) #0 !prof !10 {
163163
%vtable = load ptr, ptr %obj
164164
%p = call i1 @llvm.type.test(ptr %vtable, metadata !"typeid1")
165165
call void @llvm.assume(i1 %p)
@@ -172,7 +172,7 @@ define i32 @fn1(ptr %obj) #0 {
172172

173173
; CHECK-LABEL: define i32 @fn1_rv
174174
; CHECK-NOT: call void (...) @llvm.icall.branch.funnel
175-
define i32 @fn1_rv(ptr %obj) #0 {
175+
define i32 @fn1_rv(ptr %obj) #0 !prof !10 {
176176
%vtable = load ptr, ptr %obj
177177
%p = call i1 @llvm.type.test(ptr %vtable, metadata !"typeid1_rv")
178178
call void @llvm.assume(i1 %p)
@@ -185,7 +185,7 @@ define i32 @fn1_rv(ptr %obj) #0 {
185185

186186
; CHECK-LABEL: define i32 @fn2
187187
; CHECK-NOT: call void (...) @llvm.icall.branch.funnel
188-
define i32 @fn2(ptr %obj) #0 {
188+
define i32 @fn2(ptr %obj) #0 !prof !10 {
189189
%vtable = load ptr, ptr %obj
190190
%p = call i1 @llvm.type.test(ptr %vtable, metadata !"typeid2")
191191
call void @llvm.assume(i1 %p)
@@ -197,7 +197,7 @@ define i32 @fn2(ptr %obj) #0 {
197197

198198
; CHECK-LABEL: define i32 @fn2_rv
199199
; CHECK-NOT: call void (...) @llvm.icall.branch.funnel
200-
define i32 @fn2_rv(ptr %obj) #0 {
200+
define i32 @fn2_rv(ptr %obj) #0 !prof !10 {
201201
%vtable = load ptr, ptr %obj
202202
%p = call i1 @llvm.type.test(ptr %vtable, metadata !"typeid2_rv")
203203
call void @llvm.assume(i1 %p)
@@ -209,7 +209,7 @@ define i32 @fn2_rv(ptr %obj) #0 {
209209

210210
; CHECK-LABEL: define i32 @fn3
211211
; CHECK-NOT: call void (...) @llvm.icall.branch.funnel
212-
define i32 @fn3(ptr %obj) #0 {
212+
define i32 @fn3(ptr %obj) #0 !prof !10 {
213213
%vtable = load ptr, ptr %obj
214214
%p = call i1 @llvm.type.test(ptr %vtable, metadata !4)
215215
call void @llvm.assume(i1 %p)
@@ -222,7 +222,7 @@ define i32 @fn3(ptr %obj) #0 {
222222

223223
; CHECK-LABEL: define i32 @fn3_rv
224224
; CHECK-NOT: call void (...) @llvm.icall.branch.funnel
225-
define i32 @fn3_rv(ptr %obj) #0 {
225+
define i32 @fn3_rv(ptr %obj) #0 !prof !10 {
226226
%vtable = load ptr, ptr %obj
227227
%p = call i1 @llvm.type.test(ptr %vtable, metadata !9)
228228
call void @llvm.assume(i1 %p)
@@ -235,7 +235,7 @@ define i32 @fn3_rv(ptr %obj) #0 {
235235

236236
; CHECK-LABEL: define i32 @fn4
237237
; CHECK-NOT: call void (...) @llvm.icall.branch.funnel
238-
define i32 @fn4(ptr %obj) #0 {
238+
define i32 @fn4(ptr %obj) #0 !prof !10 {
239239
%p = call i1 @llvm.type.test(ptr @vt1_1, metadata !"typeid1")
240240
call void @llvm.assume(i1 %p)
241241
%fptr = load ptr, ptr @vt1_1
@@ -247,7 +247,7 @@ define i32 @fn4(ptr %obj) #0 {
247247

248248
; CHECK-LABEL: define i32 @fn4_cpy
249249
; CHECK-NOT: call void (...) @llvm.icall.branch.funnel
250-
define i32 @fn4_cpy(ptr %obj) #0 {
250+
define i32 @fn4_cpy(ptr %obj) #0 !prof !10 {
251251
%p = call i1 @llvm.type.test(ptr @vt1_1, metadata !"typeid1")
252252
call void @llvm.assume(i1 %p)
253253
%fptr = load ptr, ptr @vt1_1
@@ -259,7 +259,7 @@ define i32 @fn4_cpy(ptr %obj) #0 {
259259

260260
; CHECK-LABEL: define i32 @fn4_rv
261261
; CHECK-NOT: call void (...) @llvm.icall.branch.funnel
262-
define i32 @fn4_rv(ptr %obj) #0 {
262+
define i32 @fn4_rv(ptr %obj) #0 !prof !10 {
263263
%p = call i1 @llvm.type.test(ptr @vt1_1_rv, metadata !"typeid1_rv")
264264
call void @llvm.assume(i1 %p)
265265
%fptr = call ptr @llvm.load.relative.i32(ptr @vt1_1_rv, i32 0)
@@ -271,7 +271,7 @@ define i32 @fn4_rv(ptr %obj) #0 {
271271

272272
; CHECK-LABEL: define i32 @fn4_rv_cpy
273273
; CHECK-NOT: call void (...) @llvm.icall.branch.funnel
274-
define i32 @fn4_rv_cpy(ptr %obj) #0 {
274+
define i32 @fn4_rv_cpy(ptr %obj) #0 !prof !10 {
275275
%p = call i1 @llvm.type.test(ptr @vt1_1_rv, metadata !"typeid1_rv")
276276
call void @llvm.assume(i1 %p)
277277
%fptr = call ptr @llvm.load.relative.i32(ptr @vt1_1_rv, i32 0)
@@ -281,14 +281,18 @@ define i32 @fn4_rv_cpy(ptr %obj) #0 {
281281
ret i32 %result
282282
}
283283

284-
; CHECK-LABEL: define hidden void @__typeid_typeid1_0_branch_funnel(ptr nest %0, ...)
284+
; CHECK-LABEL: define hidden void @__typeid_typeid1_0_branch_funnel(ptr nest %0, ...) !prof !11
285285
; CHECK-NEXT: musttail call void (...) @llvm.icall.branch.funnel(ptr %0, ptr {{(nonnull )?}}@vt1_1, ptr {{(nonnull )?}}@vf1_1, ptr {{(nonnull )?}}@vt1_2, ptr {{(nonnull )?}}@vf1_2, ...)
286286

287-
; CHECK-LABEL: define hidden void @__typeid_typeid1_rv_0_branch_funnel(ptr nest %0, ...)
287+
; CHECK-LABEL: define hidden void @__typeid_typeid1_rv_0_branch_funnel(ptr nest %0, ...) !prof !11
288288
; CHECK-NEXT: musttail call void (...) @llvm.icall.branch.funnel(ptr %0, ptr {{(nonnull )?}}@vt1_1_rv, ptr {{(nonnull )?}}@vf1_1, ptr {{(nonnull )?}}@vt1_2_rv, ptr {{(nonnull )?}}@vf1_2, ...)
289289

290-
; CHECK: define internal void @branch_funnel(ptr
291-
; CHECK: define internal void @branch_funnel.1(ptr
290+
; CHECK: define internal void @branch_funnel(ptr {{.*}})
291+
; RETP-SAME !prof !10
292+
; NORETP-SAME !prof !11
293+
; CHECK: define internal void @branch_funnel.1(ptr {{.*}})
294+
; RETP-SAME !prof !10
295+
; NORETP-SAME !prof !11
292296

293297
declare i1 @llvm.type.test(ptr, metadata)
294298
declare void @llvm.assume(i1)
@@ -303,5 +307,6 @@ declare void @llvm.assume(i1)
303307
!7 = !{i32 0, !"typeid3_rv"}
304308
!8 = !{i32 0, !9}
305309
!9 = distinct !{}
310+
!10 = !{!"function_entry_count", i64 1000}
306311

307312
attributes #0 = { "target-features"="+retpoline" }

0 commit comments

Comments
 (0)