Skip to content

Commit 2bc1823

Browse files
committed
[LTT][profcheck] Set branch weights for complex llvm.type.test lowering
1 parent aeacb0b commit 2bc1823

File tree

3 files changed

+60
-13
lines changed

3 files changed

+60
-13
lines changed

llvm/lib/Transforms/IPO/LowerTypeTests.cpp

Lines changed: 38 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#include "llvm/ADT/Statistic.h"
2626
#include "llvm/ADT/StringRef.h"
2727
#include "llvm/ADT/TinyPtrVector.h"
28+
#include "llvm/Analysis/BlockFrequencyInfo.h"
2829
#include "llvm/Analysis/LoopInfo.h"
2930
#include "llvm/Analysis/PostDominators.h"
3031
#include "llvm/Analysis/TargetTransformInfo.h"
@@ -54,6 +55,7 @@
5455
#include "llvm/IR/ModuleSummaryIndexYAML.h"
5556
#include "llvm/IR/Operator.h"
5657
#include "llvm/IR/PassManager.h"
58+
#include "llvm/IR/ProfDataUtils.h"
5759
#include "llvm/IR/ReplaceConstant.h"
5860
#include "llvm/IR/Type.h"
5961
#include "llvm/IR/Use.h"
@@ -95,6 +97,7 @@ STATISTIC(NumByteArraysCreated, "Number of byte arrays created");
9597
STATISTIC(NumTypeTestCallsLowered, "Number of type test calls lowered");
9698
STATISTIC(NumTypeIdDisjointSets, "Number of disjoint sets of type identifiers");
9799

100+
namespace llvm {
98101
static cl::opt<bool> AvoidReuse(
99102
"lowertypetests-avoid-reuse",
100103
cl::desc("Try to avoid reuse of byte array addresses using aliases"),
@@ -131,6 +134,9 @@ static cl::opt<DropTestKind>
131134
"Drop all type test sequences")),
132135
cl::Hidden, cl::init(DropTestKind::None));
133136

137+
extern cl::opt<bool> ProfcheckDisableMetadataFixes;
138+
} // namespace llvm
139+
134140
bool BitSetInfo::containsGlobalOffset(uint64_t Offset) const {
135141
if (Offset < ByteOffset)
136142
return false;
@@ -423,8 +429,10 @@ struct ScopedSaveAliaseesAndUsed {
423429
class LowerTypeTestsModule {
424430
Module &M;
425431

426-
ModuleSummaryIndex *ExportSummary;
427-
const ModuleSummaryIndex *ImportSummary;
432+
FunctionAnalysisManager &FAM;
433+
434+
ModuleSummaryIndex *const ExportSummary;
435+
const ModuleSummaryIndex *const ImportSummary;
428436
// Set when the client has invoked this to simply drop all type test assume
429437
// sequences.
430438
DropTestKind DropTypeTests;
@@ -507,9 +515,10 @@ class LowerTypeTestsModule {
507515
void allocateByteArrays();
508516
Value *createBitSetTest(IRBuilder<> &B, const TypeIdLowering &TIL,
509517
Value *BitOffset);
510-
void lowerTypeTestCalls(
511-
ArrayRef<Metadata *> TypeIds, Constant *CombinedGlobalAddr,
512-
const DenseMap<GlobalTypeMember *, uint64_t> &GlobalLayout);
518+
void
519+
lowerTypeTestCalls(ArrayRef<Metadata *> TypeIds, Constant *CombinedGlobalAddr,
520+
const DenseMap<GlobalTypeMember *, uint64_t> &GlobalLayout,
521+
uint64_t *TotalCallCount = nullptr);
513522
Value *lowerTypeTestCall(Metadata *TypeId, CallInst *CI,
514523
const TypeIdLowering &TIL);
515524

@@ -803,6 +812,8 @@ Value *LowerTypeTestsModule::lowerTypeTestCall(Metadata *TypeId, CallInst *CI,
803812
}
804813

805814
IRBuilder<> ThenB(SplitBlockAndInsertIfThen(OffsetInRange, CI, false));
815+
setExplicitlyUnknownBranchWeightsIfProfiled(*InitialBB->getTerminator(),
816+
DEBUG_TYPE);
806817

807818
// Now that we know that the offset is in range and aligned, load the
808819
// appropriate bit from the bitset.
@@ -1181,7 +1192,8 @@ buildBitSets(ArrayRef<Metadata *> TypeIds,
11811192

11821193
void LowerTypeTestsModule::lowerTypeTestCalls(
11831194
ArrayRef<Metadata *> TypeIds, Constant *CombinedGlobalAddr,
1184-
const DenseMap<GlobalTypeMember *, uint64_t> &GlobalLayout) {
1195+
const DenseMap<GlobalTypeMember *, uint64_t> &GlobalLayout,
1196+
uint64_t *TotalCallCount) {
11851197
// For each type identifier in this disjoint set...
11861198
for (const auto &[TypeId, BSI] : buildBitSets(TypeIds, GlobalLayout)) {
11871199
ByteArrayInfo *BAI = nullptr;
@@ -1227,6 +1239,18 @@ void LowerTypeTestsModule::lowerTypeTestCalls(
12271239
++NumTypeTestCallsLowered;
12281240
Value *Lowered = lowerTypeTestCall(TypeId, CI, TIL);
12291241
if (Lowered) {
1242+
if (TotalCallCount) {
1243+
auto *CIF = CI->getFunction();
1244+
if (auto EC = CIF->getEntryCount())
1245+
if (EC->getCount()) {
1246+
auto &BFI = FAM.getResult<BlockFrequencyAnalysis>(*CIF);
1247+
*TotalCallCount +=
1248+
EC->getCount() *
1249+
static_cast<double>(
1250+
BFI.getBlockFreq(CI->getParent()).getFrequency()) /
1251+
BFI.getEntryFreq().getFrequency();
1252+
}
1253+
}
12301254
CI->replaceAllUsesWith(Lowered);
12311255
CI->eraseFromParent();
12321256
}
@@ -1702,10 +1726,13 @@ void LowerTypeTestsModule::buildBitSetsFromFunctionsNative(
17021726
ArrayType *JumpTableEntryType = ArrayType::get(Int8Ty, EntrySize);
17031727
ArrayType *JumpTableType =
17041728
ArrayType::get(JumpTableEntryType, Functions.size());
1705-
auto JumpTable = ConstantExpr::getPointerCast(
1729+
auto *JumpTable = ConstantExpr::getPointerCast(
17061730
JumpTableFn, PointerType::getUnqual(M.getContext()));
17071731

1708-
lowerTypeTestCalls(TypeIds, JumpTable, GlobalLayout);
1732+
uint64_t Count = 0;
1733+
lowerTypeTestCalls(TypeIds, JumpTable, GlobalLayout, &Count);
1734+
if (!ProfcheckDisableMetadataFixes && Count)
1735+
JumpTableFn->setEntryCount(Count);
17091736

17101737
// Build aliases pointing to offsets into the jump table, and replace
17111738
// references to the original functions with references to the aliases.
@@ -1870,7 +1897,9 @@ void LowerTypeTestsModule::buildBitSetsFromDisjointSet(
18701897
LowerTypeTestsModule::LowerTypeTestsModule(
18711898
Module &M, ModuleAnalysisManager &AM, ModuleSummaryIndex *ExportSummary,
18721899
const ModuleSummaryIndex *ImportSummary, DropTestKind DropTypeTests)
1873-
: M(M), ExportSummary(ExportSummary), ImportSummary(ImportSummary),
1900+
: M(M),
1901+
FAM(AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager()),
1902+
ExportSummary(ExportSummary), ImportSummary(ImportSummary),
18741903
DropTypeTests(ClDropTypeTests > DropTypeTests ? ClDropTypeTests
18751904
: DropTypeTests) {
18761905
assert(!(ExportSummary && ImportSummary));
@@ -1879,8 +1908,6 @@ LowerTypeTestsModule::LowerTypeTestsModule(
18791908
if (Arch == Triple::arm)
18801909
CanUseArmJumpTable = true;
18811910
if (Arch == Triple::arm || Arch == Triple::thumb) {
1882-
auto &FAM =
1883-
AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
18841911
for (Function &F : M) {
18851912
// Skip declarations since we should not query the TTI for them.
18861913
if (F.isDeclaration())

llvm/test/Other/new-pm-O0-defaults.ll

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
; CHECK-PRE-LINK: Running pass: CanonicalizeAliasesPass
4545
; CHECK-PRE-LINK-NEXT: Running pass: NameAnonGlobalPass
4646
; CHECK-THINLTO: Running pass: LowerTypeTestsPass
47+
; CHECK-THINLTO: Running analysis: InnerAnalysisManagerProxy<FunctionAnalysisManager, Module> on [module]
4748
; CHECK-THINLTO-NEXT: Running pass: CoroConditionalWrapper
4849
; CHECK-THINLTO-NEXT: Running pass: EliminateAvailableExternallyPass
4950
; CHECK-THINLTO-NEXT: Running pass: GlobalDCEPass

llvm/test/Transforms/LowerTypeTests/section.ll

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,33 @@ entry:
1313
ret void
1414
}
1515

16-
define i1 @g() {
16+
define i1 @g() !prof !1 {
1717
entry:
1818
%0 = call i1 @llvm.type.test(ptr @f, metadata !"_ZTSFvE")
1919
ret i1 %0
2020
}
2121

22-
; CHECK: define private void @[[JT]]() #{{.*}} align {{.*}} {
22+
define i1 @h(i1 %c) !prof !2 {
23+
entry:
24+
br i1 %c, label %yes, label %common, !prof !3
25+
26+
yes:
27+
%0 = call i1 @llvm.type.test(ptr @f, metadata !"_ZTSFvE")
28+
ret i1 %0
29+
30+
common:
31+
ret i1 0
32+
}
33+
34+
; CHECK: define private void @[[JT]]() #{{.*}} align {{.*}} !prof !4 {
2335

2436
declare i1 @llvm.type.test(ptr, metadata) nounwind readnone
2537

2638
!0 = !{i64 0, !"_ZTSFvE"}
39+
!1 = !{!"function_entry_count", i32 20}
40+
!2 = !{!"function_entry_count", i32 40}
41+
!3 = !{!"branch_weights", i32 3, i32 5}
42+
; the entry count for the jumptable function is: 20 + 40 * (3/8) = 20 + 15
43+
; where: 20 is the entry count of g, 40 of h, and 3/8 is the frequency of the
44+
; llvm.type.test in h, relative to h's entry basic block.
45+
; CHECK !4 = !{!"function_entry_count", i64 35}

0 commit comments

Comments
 (0)