Skip to content

Commit 3097688

Browse files
authored
[SimplifyCFG] Set branch weights when merging conditional store to address (#154841)
1 parent ca09801 commit 3097688

File tree

3 files changed

+54
-5
lines changed

3 files changed

+54
-5
lines changed

llvm/include/llvm/IR/ProfDataUtils.h

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
#define LLVM_IR_PROFDATAUTILS_H
1717

1818
#include "llvm/ADT/SmallVector.h"
19-
#include "llvm/ADT/Twine.h"
2019
#include "llvm/IR/Metadata.h"
2120
#include "llvm/Support/Compiler.h"
2221

@@ -197,5 +196,33 @@ LLVM_ABI bool hasExplicitlyUnknownBranchWeights(const Instruction &I);
197196
/// Scaling the profile data attached to 'I' using the ratio of S/T.
198197
LLVM_ABI void scaleProfData(Instruction &I, uint64_t S, uint64_t T);
199198

199+
/// Get the branch weights of a branch conditioned on b1 || b2, where b1 and b2
200+
/// are 2 booleans that are the conditions of 2 branches for which we have the
201+
/// branch weights B1 and B2, respectively. In both B1 and B2, the first
202+
/// position (index 0) is for the 'true' branch, and the second position (index
203+
/// 1) is for the 'false' branch.
204+
inline SmallVector<uint64_t, 2>
205+
getDisjunctionWeights(const SmallVector<uint32_t, 2> &B1,
206+
const SmallVector<uint32_t, 2> &B2) {
207+
// For the first conditional branch, the probability the "true" case is taken
208+
// is p(b1) = B1[0] / (B1[0] + B1[1]). The "false" case's probability is
209+
// p(not b1) = B1[1] / (B1[0] + B1[1]).
210+
// Similarly for the second conditional branch and B2.
211+
//
212+
// The probability of the new branch NOT being taken is:
213+
// not P = p((not b1) and (not b2)) =
214+
// = B1[1] / (B1[0]+B1[1]) * B2[1] / (B2[0]+B2[1]) =
215+
// = B1[1] * B2[1] / (B1[0] + B1[1]) * (B2[0] + B2[1])
216+
// Then the probability of it being taken is: P = 1 - (not P).
217+
// The denominator will be the same as above, and the numerator of P will be:
218+
// (B1[0] + B1[1]) * (B2[0] + B2[1]) - B1[1]*B2[1]
219+
// Which then reduces to what's shown below (out of the 4 terms coming out of
220+
// the product of sums, the subtracted one cancels out).
221+
assert(B1.size() == 2);
222+
assert(B2.size() == 2);
223+
auto FalseWeight = B1[1] * B2[1];
224+
auto TrueWeight = B1[0] * B2[0] + B1[0] * B2[1] + B1[1] * B2[0];
225+
return {TrueWeight, FalseWeight};
226+
}
200227
} // namespace llvm
201228
#endif

llvm/lib/Transforms/Utils/SimplifyCFG.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,8 @@ static cl::opt<unsigned> MaxJumpThreadingLiveBlocks(
203203
cl::desc("Limit number of blocks a define in a threaded block is allowed "
204204
"to be live in"));
205205

206+
extern cl::opt<bool> ProfcheckDisableMetadataFixes;
207+
206208
STATISTIC(NumBitMaps, "Number of switch instructions turned into bitmaps");
207209
STATISTIC(NumLinearMaps,
208210
"Number of switch instructions turned into linear mapping");
@@ -4438,6 +4440,20 @@ static bool mergeConditionalStoreToAddress(
44384440
auto *T = SplitBlockAndInsertIfThen(CombinedPred, InsertPt,
44394441
/*Unreachable=*/false,
44404442
/*BranchWeights=*/nullptr, DTU);
4443+
if (hasBranchWeightMD(*PBranch) && hasBranchWeightMD(*QBranch) &&
4444+
!ProfcheckDisableMetadataFixes) {
4445+
SmallVector<uint32_t, 2> PWeights, QWeights;
4446+
extractBranchWeights(*PBranch, PWeights);
4447+
extractBranchWeights(*QBranch, QWeights);
4448+
if (InvertPCond)
4449+
std::swap(PWeights[0], PWeights[1]);
4450+
if (InvertQCond)
4451+
std::swap(QWeights[0], QWeights[1]);
4452+
auto CombinedWeights = getDisjunctionWeights(PWeights, QWeights);
4453+
setBranchWeights(PostBB->getTerminator(), CombinedWeights[0],
4454+
CombinedWeights[1],
4455+
/*IsExpected=*/false);
4456+
}
44414457

44424458
QB.SetInsertPoint(T);
44434459
StoreInst *SI = cast<StoreInst>(QB.CreateStore(QPHI, Address));

llvm/test/Transforms/SimplifyCFG/merge-cond-stores.ll

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
1+
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --check-globals
22
; RUN: opt -passes=simplifycfg,instcombine -simplifycfg-require-and-preserve-domtree=1 < %s -simplifycfg-merge-cond-stores=true -simplifycfg-merge-cond-stores-aggressively=false -phi-node-folding-threshold=2 -S | FileCheck %s
33

44
; This test should succeed and end up if-converted.
@@ -43,7 +43,7 @@ define void @test_simple_commuted(ptr %p, i32 %a, i32 %b) {
4343
; CHECK-NEXT: [[X2:%.*]] = icmp eq i32 [[B:%.*]], 0
4444
; CHECK-NEXT: [[X3:%.*]] = icmp eq i32 [[B1:%.*]], 0
4545
; CHECK-NEXT: [[TMP0:%.*]] = or i1 [[X2]], [[X3]]
46-
; CHECK-NEXT: br i1 [[TMP0]], label [[TMP1:%.*]], label [[TMP2:%.*]]
46+
; CHECK-NEXT: br i1 [[TMP0]], label [[TMP1:%.*]], label [[TMP2:%.*]], !prof [[PROF0:![0-9]+]]
4747
; CHECK: 1:
4848
; CHECK-NEXT: [[SPEC_SELECT:%.*]] = zext i1 [[X3]] to i32
4949
; CHECK-NEXT: store i32 [[SPEC_SELECT]], ptr [[P:%.*]], align 4
@@ -53,15 +53,15 @@ define void @test_simple_commuted(ptr %p, i32 %a, i32 %b) {
5353
;
5454
entry:
5555
%x1 = icmp eq i32 %a, 0
56-
br i1 %x1, label %yes1, label %fallthrough
56+
br i1 %x1, label %yes1, label %fallthrough, !prof !0
5757

5858
yes1:
5959
store i32 0, ptr %p
6060
br label %fallthrough
6161

6262
fallthrough:
6363
%x2 = icmp eq i32 %b, 0
64-
br i1 %x2, label %yes2, label %end
64+
br i1 %x2, label %yes2, label %end, !prof !1
6565

6666
yes2:
6767
store i32 1, ptr %p
@@ -406,3 +406,9 @@ yes2:
406406
end:
407407
ret void
408408
}
409+
410+
!0 = !{!"branch_weights", i32 7, i32 13}
411+
!1 = !{!"branch_weights", i32 3, i32 11}
412+
;.
413+
; CHECK: [[PROF0]] = !{!"branch_weights", i32 137, i32 143}
414+
;.

0 commit comments

Comments
 (0)