Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 41 additions & 9 deletions llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/BasicAliasAnalysis.h"
#include "llvm/Analysis/BlockFrequencyInfo.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/TypeMetadataUtils.h"
#include "llvm/Bitcode/BitcodeReader.h"
Expand All @@ -84,6 +85,7 @@
#include "llvm/IR/Module.h"
#include "llvm/IR/ModuleSummaryIndexYAML.h"
#include "llvm/IR/PassManager.h"
#include "llvm/IR/ProfDataUtils.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Errc.h"
Expand All @@ -98,6 +100,7 @@
#include "llvm/Transforms/Utils/CallPromotionUtils.h"
#include "llvm/Transforms/Utils/Evaluator.h"
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <map>
#include <set>
Expand Down Expand Up @@ -170,6 +173,8 @@ static cl::list<std::string>
cl::desc("Prevent function(s) from being devirtualized"),
cl::Hidden, cl::CommaSeparated);

extern cl::opt<bool> ProfcheckDisableMetadataFixes;

/// With Clang, a pure virtual class's deleting destructor is emitted as a
/// `llvm.trap` intrinsic followed by an unreachable IR instruction. In the
/// context of whole program devirtualization, the deleting destructor of a pure
Expand Down Expand Up @@ -657,7 +662,7 @@ struct DevirtModule {
VTableSlotInfo &SlotInfo,
WholeProgramDevirtResolution *Res);

void applyICallBranchFunnel(VTableSlotInfo &SlotInfo, Constant *JT,
void applyICallBranchFunnel(VTableSlotInfo &SlotInfo, Function &JT,
bool &IsExported);
void tryICallBranchFunnel(MutableArrayRef<VirtualCallTarget> TargetsForSlot,
VTableSlotInfo &SlotInfo,
Expand Down Expand Up @@ -1484,13 +1489,19 @@ void DevirtModule::tryICallBranchFunnel(
ReturnInst::Create(M.getContext(), nullptr, BB);

bool IsExported = false;
applyICallBranchFunnel(SlotInfo, JT, IsExported);
applyICallBranchFunnel(SlotInfo, *JT, IsExported);
if (IsExported)
Res->TheKind = WholeProgramDevirtResolution::BranchFunnel;

if (!JT->getEntryCount().has_value()) {
// FIXME: we could pass through thinlto the necessary information.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What info would we need?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

calling callsites' hotness

setExplicitlyUnknownFunctionEntryCount(*JT);
}
}

void DevirtModule::applyICallBranchFunnel(VTableSlotInfo &SlotInfo,
Constant *JT, bool &IsExported) {
Function &JT, bool &IsExported) {
DenseMap<Function *, double> FunctionEntryCounts;
auto Apply = [&](CallSiteInfo &CSInfo) {
if (CSInfo.isExported())
IsExported = true;
Expand Down Expand Up @@ -1518,8 +1529,7 @@ void DevirtModule::applyICallBranchFunnel(VTableSlotInfo &SlotInfo,

NumBranchFunnel++;
if (RemarksEnabled)
VCallSite.emitRemark("branch-funnel",
JT->stripPointerCasts()->getName(), OREGetter);
VCallSite.emitRemark("branch-funnel", JT.getName(), OREGetter);

// Pass the address of the vtable in the nest register, which is r10 on
// x86_64.
Expand All @@ -1535,11 +1545,28 @@ void DevirtModule::applyICallBranchFunnel(VTableSlotInfo &SlotInfo,
llvm::append_range(Args, CB.args());

CallBase *NewCS = nullptr;
if (!JT.isDeclaration() && !ProfcheckDisableMetadataFixes) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add some comments about the approach used below to compute the entry counts?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ptal

// Accumulate the call frequencies of the original call site, and use
// that as total entry count for the funnel function.
auto &F = *CB.getCaller();
auto &BFI = FAM.getResult<BlockFrequencyAnalysis>(F);
auto EC = BFI.getBlockFreq(&F.getEntryBlock());
auto CC = F.getEntryCount(/*AllowSynthetic=*/true);
double CallCount = 0.0;
if (EC.getFrequency() != 0 && CC && CC->getCount() != 0) {
double CallFreq =
static_cast<double>(
BFI.getBlockFreq(CB.getParent()).getFrequency()) /
EC.getFrequency();
CallCount = CallFreq * CC->getCount();
}
FunctionEntryCounts[&JT] += CallCount;
}
if (isa<CallInst>(CB))
NewCS = IRB.CreateCall(NewFT, JT, Args);
NewCS = IRB.CreateCall(NewFT, &JT, Args);
else
NewCS =
IRB.CreateInvoke(NewFT, JT, cast<InvokeInst>(CB).getNormalDest(),
IRB.CreateInvoke(NewFT, &JT, cast<InvokeInst>(CB).getNormalDest(),
cast<InvokeInst>(CB).getUnwindDest(), Args);
NewCS->setCallingConv(CB.getCallingConv());

Expand Down Expand Up @@ -1573,6 +1600,11 @@ void DevirtModule::applyICallBranchFunnel(VTableSlotInfo &SlotInfo,
Apply(SlotInfo.CSInfo);
for (auto &P : SlotInfo.ConstCSInfo)
Apply(P.second);
for (auto &[F, C] : FunctionEntryCounts) {
assert(!F->getEntryCount(/*AllowSynthetic=*/true) &&
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is false because this is a func that was initially set to have unknown earlier?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's set to unknown after this (==applyICallBranchFunnel) is called, if we didn't have module-internal callsites to use.

but to the original question, the functions appearing here are funnels we just generated, so should have no entry count.

"Unexpected entry count for funnel that was freshly synthesized");
F->setEntryCount(static_cast<uint64_t>(std::round(C)));
}
}

bool DevirtModule::tryEvaluateFunctionsWithArgs(
Expand Down Expand Up @@ -2246,12 +2278,12 @@ void DevirtModule::importResolution(VTableSlot Slot, VTableSlotInfo &SlotInfo) {
if (Res.TheKind == WholeProgramDevirtResolution::BranchFunnel) {
// The type of the function is irrelevant, because it's bitcast at calls
// anyhow.
Constant *JT = cast<Constant>(
auto *JT = cast<Function>(
M.getOrInsertFunction(getGlobalName(Slot, {}, "branch_funnel"),
Type::getVoidTy(M.getContext()))
.getCallee());
bool IsExported = false;
applyICallBranchFunnel(SlotInfo, JT, IsExported);
applyICallBranchFunnel(SlotInfo, *JT, IsExported);
assert(!IsExported);
}
}
Expand Down
203 changes: 203 additions & 0 deletions llvm/test/Transforms/WholeProgramDevirt/branch-funnel-profile.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
; A variant of branch-funnel.ll where we just check that the funnels' entry counts
; are correctly set.
;
; RUN: opt -S -passes=wholeprogramdevirt -whole-program-visibility %s | FileCheck --check-prefixes=RETP %s
; RUN: sed -e 's,+retpoline,-retpoline,g' %s | opt -S -passes=wholeprogramdevirt -whole-program-visibility | FileCheck --check-prefixes=NORETP %s
; RUN: opt -passes=wholeprogramdevirt -whole-program-visibility -wholeprogramdevirt-summary-action=export -wholeprogramdevirt-read-summary=%S/Inputs/export.yaml -wholeprogramdevirt-write-summary=%t -S -o - %s | FileCheck --check-prefixes=RETP %s
; RUN: opt -passes='wholeprogramdevirt,default<O3>' -whole-program-visibility -wholeprogramdevirt-summary-action=export -wholeprogramdevirt-read-summary=%S/Inputs/export.yaml -wholeprogramdevirt-write-summary=%t -S -o - %s | FileCheck --check-prefixes=O3 %s

; RETP: define hidden void @__typeid_typeid1_0_branch_funnel(ptr nest %0, ...) !prof !11
; RETP: define hidden void @__typeid_typeid1_rv_0_branch_funnel(ptr nest %0, ...) !prof !11
; RETP: define internal void @branch_funnel(ptr nest %0, ...) !prof !10
; RETP: define internal void @branch_funnel.1(ptr nest %0, ...) !prof !10
; RETP: !10 = !{!"function_entry_count", i64 1000}
; RETP: !11 = !{!"function_entry_count", i64 3000}

; NORETP: define hidden void @__typeid_typeid1_0_branch_funnel(ptr nest %0, ...) !prof !11
; NORETP: define hidden void @__typeid_typeid1_rv_0_branch_funnel(ptr nest %0, ...) !prof !11
; NORETP: define internal void @branch_funnel(ptr nest %0, ...) !prof !11
; NORETP: define internal void @branch_funnel.1(ptr nest %0, ...) !prof !11
; NORETP: !11 = !{!"unknown"}

; O3: define hidden void @__typeid_typeid1_0_branch_funnel(ptr nest %0, ...) local_unnamed_addr #5 !prof !11
; O3: define hidden void @__typeid_typeid1_rv_0_branch_funnel(ptr nest %0, ...) local_unnamed_addr #5 !prof !11
; O3: define internal void @branch_funnel(ptr nest %0, ...) unnamed_addr #5 !prof !10
; O3: define internal void @branch_funnel.1(ptr nest %0, ...) unnamed_addr #5 !prof !10
; O3: define hidden void @__typeid_typeid3_0_branch_funnel(ptr nest %0, ...) local_unnamed_addr #5 !prof !12
; O3: define hidden void @__typeid_typeid3_rv_0_branch_funnel(ptr nest %0, ...) local_unnamed_addr #5 !prof !12
; O3: !10 = !{!"function_entry_count", i64 1000}
; O3: !11 = !{!"function_entry_count", i64 3000}
; O3: !12 = !{!"unknown"}

target datalayout = "e-p:64:64"
target triple = "x86_64-unknown-linux-gnu"

@vt1_1 = constant [1 x ptr] [ptr @vf1_1], !type !0
@vt1_2 = constant [1 x ptr] [ptr @vf1_2], !type !0

declare i32 @vf1_1(ptr %this, i32 %arg)
declare i32 @vf1_2(ptr %this, i32 %arg)

@vt2_1 = constant [1 x ptr] [ptr @vf2_1], !type !1
@vt2_2 = constant [1 x ptr] [ptr @vf2_2], !type !1
@vt2_3 = constant [1 x ptr] [ptr @vf2_3], !type !1
@vt2_4 = constant [1 x ptr] [ptr @vf2_4], !type !1
@vt2_5 = constant [1 x ptr] [ptr @vf2_5], !type !1
@vt2_6 = constant [1 x ptr] [ptr @vf2_6], !type !1
@vt2_7 = constant [1 x ptr] [ptr @vf2_7], !type !1
@vt2_8 = constant [1 x ptr] [ptr @vf2_8], !type !1
@vt2_9 = constant [1 x ptr] [ptr @vf2_9], !type !1
@vt2_10 = constant [1 x ptr] [ptr @vf2_10], !type !1
@vt2_11 = constant [1 x ptr] [ptr @vf2_11], !type !1

declare i32 @vf2_1(ptr %this, i32 %arg)
declare i32 @vf2_2(ptr %this, i32 %arg)
declare i32 @vf2_3(ptr %this, i32 %arg)
declare i32 @vf2_4(ptr %this, i32 %arg)
declare i32 @vf2_5(ptr %this, i32 %arg)
declare i32 @vf2_6(ptr %this, i32 %arg)
declare i32 @vf2_7(ptr %this, i32 %arg)
declare i32 @vf2_8(ptr %this, i32 %arg)
declare i32 @vf2_9(ptr %this, i32 %arg)
declare i32 @vf2_10(ptr %this, i32 %arg)
declare i32 @vf2_11(ptr %this, i32 %arg)

@vt3_1 = constant [1 x ptr] [ptr @vf3_1], !type !2
@vt3_2 = constant [1 x ptr] [ptr @vf3_2], !type !2

declare i32 @vf3_1(ptr %this, i32 %arg)
declare i32 @vf3_2(ptr %this, i32 %arg)

@vt4_1 = constant [1 x ptr] [ptr @vf4_1], !type !3
@vt4_2 = constant [1 x ptr] [ptr @vf4_2], !type !3

declare i32 @vf4_1(ptr %this, i32 %arg)
declare i32 @vf4_2(ptr %this, i32 %arg)

declare ptr @llvm.load.relative.i32(ptr, i32)

;; These are relative vtables equivalent to the ones above.
@vt1_1_rv = constant [1 x i32] [i32 trunc (i64 sub (i64 ptrtoint (ptr dso_local_equivalent @vf1_1 to i64), i64 ptrtoint (ptr @vt1_1_rv to i64)) to i32)], !type !5
@vt1_2_rv = constant [1 x i32] [i32 trunc (i64 sub (i64 ptrtoint (ptr dso_local_equivalent @vf1_2 to i64), i64 ptrtoint (ptr @vt1_2_rv to i64)) to i32)], !type !5

@vt2_1_rv = constant [1 x i32] [i32 trunc (i64 sub (i64 ptrtoint (ptr dso_local_equivalent @vf2_1 to i64), i64 ptrtoint (ptr @vt2_1_rv to i64)) to i32)], !type !6
@vt2_2_rv = constant [1 x i32] [i32 trunc (i64 sub (i64 ptrtoint (ptr dso_local_equivalent @vf2_2 to i64), i64 ptrtoint (ptr @vt2_2_rv to i64)) to i32)], !type !6
@vt2_3_rv = constant [1 x i32] [i32 trunc (i64 sub (i64 ptrtoint (ptr dso_local_equivalent @vf2_3 to i64), i64 ptrtoint (ptr @vt2_3_rv to i64)) to i32)], !type !6
@vt2_4_rv = constant [1 x i32] [i32 trunc (i64 sub (i64 ptrtoint (ptr dso_local_equivalent @vf2_4 to i64), i64 ptrtoint (ptr @vt2_4_rv to i64)) to i32)], !type !6
@vt2_5_rv = constant [1 x i32] [i32 trunc (i64 sub (i64 ptrtoint (ptr dso_local_equivalent @vf2_5 to i64), i64 ptrtoint (ptr @vt2_5_rv to i64)) to i32)], !type !6
@vt2_6_rv = constant [1 x i32] [i32 trunc (i64 sub (i64 ptrtoint (ptr dso_local_equivalent @vf2_6 to i64), i64 ptrtoint (ptr @vt2_6_rv to i64)) to i32)], !type !6
@vt2_7_rv = constant [1 x i32] [i32 trunc (i64 sub (i64 ptrtoint (ptr dso_local_equivalent @vf2_7 to i64), i64 ptrtoint (ptr @vt2_7_rv to i64)) to i32)], !type !6
@vt2_8_rv = constant [1 x i32] [i32 trunc (i64 sub (i64 ptrtoint (ptr dso_local_equivalent @vf2_8 to i64), i64 ptrtoint (ptr @vt2_8_rv to i64)) to i32)], !type !6
@vt2_9_rv = constant [1 x i32] [i32 trunc (i64 sub (i64 ptrtoint (ptr dso_local_equivalent @vf2_9 to i64), i64 ptrtoint (ptr @vt2_9_rv to i64)) to i32)], !type !6
@vt2_10_rv = constant [1 x i32] [i32 trunc (i64 sub (i64 ptrtoint (ptr dso_local_equivalent @vf2_10 to i64), i64 ptrtoint (ptr @vt2_10_rv to i64)) to i32)], !type !6
@vt2_11_rv = constant [1 x i32] [i32 trunc (i64 sub (i64 ptrtoint (ptr dso_local_equivalent @vf2_11 to i64), i64 ptrtoint (ptr @vt2_11_rv to i64)) to i32)], !type !6

@vt3_1_rv = constant [1 x i32] [i32 trunc (i64 sub (i64 ptrtoint (ptr dso_local_equivalent @vf3_1 to i64), i64 ptrtoint (ptr @vt3_1_rv to i64)) to i32)], !type !7
@vt3_2_rv = constant [1 x i32] [i32 trunc (i64 sub (i64 ptrtoint (ptr dso_local_equivalent @vf3_2 to i64), i64 ptrtoint (ptr @vt3_2_rv to i64)) to i32)], !type !7

@vt4_1_rv = constant [1 x i32] [i32 trunc (i64 sub (i64 ptrtoint (ptr dso_local_equivalent @vf4_1 to i64), i64 ptrtoint (ptr @vt4_1_rv to i64)) to i32)], !type !8
@vt4_2_rv = constant [1 x i32] [i32 trunc (i64 sub (i64 ptrtoint (ptr dso_local_equivalent @vf4_2 to i64), i64 ptrtoint (ptr @vt4_2_rv to i64)) to i32)], !type !8


define i32 @fn1(ptr %obj) #0 !prof !10 {
%vtable = load ptr, ptr %obj
%p = call i1 @llvm.type.test(ptr %vtable, metadata !"typeid1")
call void @llvm.assume(i1 %p)
%fptr = load ptr, ptr %vtable
%result = call i32 %fptr(ptr %obj, i32 1)
ret i32 %result
}

define i32 @fn1_rv(ptr %obj) #0 !prof !10 {
%vtable = load ptr, ptr %obj
%p = call i1 @llvm.type.test(ptr %vtable, metadata !"typeid1_rv")
call void @llvm.assume(i1 %p)
%fptr = call ptr @llvm.load.relative.i32(ptr %vtable, i32 0)
%result = call i32 %fptr(ptr %obj, i32 1)
ret i32 %result
}

define i32 @fn2(ptr %obj) #0 !prof !10 {
%vtable = load ptr, ptr %obj
%p = call i1 @llvm.type.test(ptr %vtable, metadata !"typeid2")
call void @llvm.assume(i1 %p)
%fptr = load ptr, ptr %vtable
%result = call i32 %fptr(ptr %obj, i32 1)
ret i32 %result
}

define i32 @fn2_rv(ptr %obj) #0 !prof !10 {
%vtable = load ptr, ptr %obj
%p = call i1 @llvm.type.test(ptr %vtable, metadata !"typeid2_rv")
call void @llvm.assume(i1 %p)
%fptr = call ptr @llvm.load.relative.i32(ptr %vtable, i32 0)
%result = call i32 %fptr(ptr %obj, i32 1)
ret i32 %result
}

define i32 @fn3(ptr %obj) #0 !prof !10 {
%vtable = load ptr, ptr %obj
%p = call i1 @llvm.type.test(ptr %vtable, metadata !4)
call void @llvm.assume(i1 %p)
%fptr = load ptr, ptr %vtable
%result = call i32 %fptr(ptr %obj, i32 1)
ret i32 %result
}

define i32 @fn3_rv(ptr %obj) #0 !prof !10 {
%vtable = load ptr, ptr %obj
%p = call i1 @llvm.type.test(ptr %vtable, metadata !9)
call void @llvm.assume(i1 %p)
%fptr = call ptr @llvm.load.relative.i32(ptr %vtable, i32 0)
%result = call i32 %fptr(ptr %obj, i32 1)
ret i32 %result
}

define i32 @fn4(ptr %obj) #0 !prof !10 {
%p = call i1 @llvm.type.test(ptr @vt1_1, metadata !"typeid1")
call void @llvm.assume(i1 %p)
%fptr = load ptr, ptr @vt1_1
%result = call i32 %fptr(ptr %obj, i32 1)
ret i32 %result
}

define i32 @fn4_cpy(ptr %obj) #0 !prof !10 {
%p = call i1 @llvm.type.test(ptr @vt1_1, metadata !"typeid1")
call void @llvm.assume(i1 %p)
%fptr = load ptr, ptr @vt1_1
%result = call i32 %fptr(ptr %obj, i32 1)
ret i32 %result
}

define i32 @fn4_rv(ptr %obj) #0 !prof !10 {
%p = call i1 @llvm.type.test(ptr @vt1_1_rv, metadata !"typeid1_rv")
call void @llvm.assume(i1 %p)
%fptr = call ptr @llvm.load.relative.i32(ptr @vt1_1_rv, i32 0)
%result = call i32 %fptr(ptr %obj, i32 1)
ret i32 %result
}

define i32 @fn4_rv_cpy(ptr %obj) #0 !prof !10 {
%p = call i1 @llvm.type.test(ptr @vt1_1_rv, metadata !"typeid1_rv")
call void @llvm.assume(i1 %p)
%fptr = call ptr @llvm.load.relative.i32(ptr @vt1_1_rv, i32 0)
%result = call i32 %fptr(ptr %obj, i32 1)
ret i32 %result
}

declare i1 @llvm.type.test(ptr, metadata)
declare void @llvm.assume(i1)

!0 = !{i32 0, !"typeid1"}
!1 = !{i32 0, !"typeid2"}
!2 = !{i32 0, !"typeid3"}
!3 = !{i32 0, !4}
!4 = distinct !{}
!5 = !{i32 0, !"typeid1_rv"}
!6 = !{i32 0, !"typeid2_rv"}
!7 = !{i32 0, !"typeid3_rv"}
!8 = !{i32 0, !9}
!9 = distinct !{}
!10 = !{!"function_entry_count", i64 1000}

attributes #0 = { "target-features"="+retpoline" }