-
Notifications
You must be signed in to change notification settings - Fork 15k
[SimplifyCFG] Set branch weights when merging conditional store to address #154841
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SimplifyCFG] Set branch weights when merging conditional store to address #154841
Conversation
This stack of pull requests is managed by Graphite. Learn more about stacking. |
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
f4441cb to
c77048c
Compare
7be4626 to
18cf46f
Compare
c77048c to
975b3e3
Compare
975b3e3 to
24507f7
Compare
d6bf02f to
d52a459
Compare
|
@llvm/pr-subscribers-llvm-ir @llvm/pr-subscribers-llvm-transforms Author: Mircea Trofin (mtrofin) Changes — Full diff: https://github.com/llvm/llvm-project/pull/154841.diff 3 Files Affected:
diff --git a/llvm/include/llvm/IR/ProfDataUtils.h b/llvm/include/llvm/IR/ProfDataUtils.h
index 404875285beae..ebf8559cd3d91 100644
--- a/llvm/include/llvm/IR/ProfDataUtils.h
+++ b/llvm/include/llvm/IR/ProfDataUtils.h
@@ -15,6 +15,7 @@
#ifndef LLVM_IR_PROFDATAUTILS_H
#define LLVM_IR_PROFDATAUTILS_H
+#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Twine.h"
#include "llvm/IR/Metadata.h"
@@ -186,5 +187,31 @@ LLVM_ABI bool hasExplicitlyUnknownBranchWeights(const Instruction &I);
/// Scaling the profile data attached to 'I' using the ratio of S/T.
LLVM_ABI void scaleProfData(Instruction &I, uint64_t S, uint64_t T);
+/// Get the branch weights of a branch conditioned on b1 || b2, where b1 and b2
+/// are 2 booleans that are the conditions of 2 branches for which we have the
+/// branch weights B1 and B2, respectively.
+///
+/// In both B1 and B2, index 0 holds the weight of the "true" case and index 1
+/// holds the weight of the "false" case. The returned vector uses the same
+/// layout: {weight of (b1 || b2) being true, weight of it being false}.
+/// Both inputs must have exactly 2 elements.
+inline SmallVector<uint64_t, 2>
+getDisjunctionWeights(const SmallVector<uint32_t, 2> &B1,
+                      const SmallVector<uint32_t, 2> &B2) {
+  // For the first conditional branch, the probability the "true" case is taken
+  // is p(b1) = B1[0] / (B1[0] + B1[1]). The "false" case's probability is
+  // p(not b1) = B1[1] / (B1[0] + B1[1]).
+  // Similarly for the second conditional branch and B2.
+  //
+  // Treating b1 and b2 as independent, the probability of the new branch NOT
+  // being taken is:
+  // not P = p((not b1) and (not b2)) =
+  //       = B1[1] / (B1[0] + B1[1]) * B2[1] / (B2[0] + B2[1]) =
+  //       = B1[1] * B2[1] / ((B1[0] + B1[1]) * (B2[0] + B2[1]))
+  // Then the probability of it being taken is: P = 1 - (not P).
+  // The denominator will be the same as above, and the numerator of P will be
+  // (B1[0] + B1[1]) * (B2[0] + B2[1]) - B1[1] * B2[1]
+  // which then reduces to what's shown below (out of the 4 terms coming out of
+  // the product of sums, the subtracted one cancels out).
+  assert(B1.size() == 2);
+  assert(B2.size() == 2);
+  // Widen to 64 bits *before* multiplying: the product of two uint32_t values
+  // is otherwise computed in 32 bits and wraps before it is stored in the
+  // 64-bit result.
+  const uint64_t B1True = B1[0];
+  const uint64_t B1False = B1[1];
+  const uint64_t B2True = B2[0];
+  const uint64_t B2False = B2[1];
+  const uint64_t FalseWeight = B1False * B2False;
+  // NOTE: each product fits in 64 bits, but their sum can still wrap when the
+  // input weights are all close to UINT32_MAX; that case is not handled here.
+  const uint64_t TrueWeight =
+      B1True * B2True + B1True * B2False + B1False * B2True;
+  return {TrueWeight, FalseWeight};
+}
} // namespace llvm
#endif
diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
index 270598e2b674b..370b282d1b14d 100644
--- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -203,6 +203,8 @@ static cl::opt<unsigned> MaxJumpThreadingLiveBlocks(
cl::desc("Limit number of blocks a define in a threaded block is allowed "
"to be live in"));
+extern cl::opt<bool> ProfcheckDisableMetadataFixes;
+
STATISTIC(NumBitMaps, "Number of switch instructions turned into bitmaps");
STATISTIC(NumLinearMaps,
"Number of switch instructions turned into linear mapping");
@@ -4431,6 +4433,20 @@ static bool mergeConditionalStoreToAddress(
auto *T = SplitBlockAndInsertIfThen(CombinedPred, InsertPt,
/*Unreachable=*/false,
/*BranchWeights=*/nullptr, DTU);
+ if (hasBranchWeightMD(*PBranch) && hasBranchWeightMD(*QBranch) &&
+ !ProfcheckDisableMetadataFixes) {
+ SmallVector<uint32_t, 2> PWeights, QWeights;
+ extractBranchWeights(*PBranch, PWeights);
+ extractBranchWeights(*QBranch, QWeights);
+ if (InvertPCond)
+ std::swap(PWeights[0], PWeights[1]);
+ if (InvertQCond)
+ std::swap(QWeights[0], QWeights[1]);
+ auto CombinedWeights = getDisjunctionWeights(PWeights, QWeights);
+ setBranchWeights(PostBB->getTerminator(), CombinedWeights[0],
+ CombinedWeights[1],
+ /*IsExpected=*/false);
+ }
QB.SetInsertPoint(T);
StoreInst *SI = cast<StoreInst>(QB.CreateStore(QPHI, Address));
diff --git a/llvm/test/Transforms/SimplifyCFG/merge-cond-stores.ll b/llvm/test/Transforms/SimplifyCFG/merge-cond-stores.ll
index b5c4b8aa51db4..ee723463d4b06 100644
--- a/llvm/test/Transforms/SimplifyCFG/merge-cond-stores.ll
+++ b/llvm/test/Transforms/SimplifyCFG/merge-cond-stores.ll
@@ -1,4 +1,4 @@
-; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --check-globals
; RUN: opt -passes=simplifycfg,instcombine -simplifycfg-require-and-preserve-domtree=1 < %s -simplifycfg-merge-cond-stores=true -simplifycfg-merge-cond-stores-aggressively=false -phi-node-folding-threshold=2 -S | FileCheck %s
; This test should succeed and end up if-converted.
@@ -43,7 +43,7 @@ define void @test_simple_commuted(ptr %p, i32 %a, i32 %b) {
; CHECK-NEXT: [[X1_NOT:%.*]] = icmp eq i32 [[A:%.*]], 0
; CHECK-NEXT: [[X2:%.*]] = icmp eq i32 [[B:%.*]], 0
; CHECK-NEXT: [[TMP0:%.*]] = or i1 [[X1_NOT]], [[X2]]
-; CHECK-NEXT: br i1 [[TMP0]], label [[TMP1:%.*]], label [[TMP2:%.*]]
+; CHECK-NEXT: br i1 [[TMP0]], label [[TMP1:%.*]], label [[TMP2:%.*]], !prof [[PROF0:![0-9]+]]
; CHECK: 1:
; CHECK-NEXT: [[SPEC_SELECT:%.*]] = zext i1 [[X2]] to i32
; CHECK-NEXT: store i32 [[SPEC_SELECT]], ptr [[P:%.*]], align 4
@@ -53,7 +53,7 @@ define void @test_simple_commuted(ptr %p, i32 %a, i32 %b) {
;
entry:
%x1 = icmp eq i32 %a, 0
- br i1 %x1, label %yes1, label %fallthrough
+ br i1 %x1, label %yes1, label %fallthrough, !prof !0
yes1:
store i32 0, ptr %p
@@ -61,7 +61,7 @@ yes1:
fallthrough:
%x2 = icmp eq i32 %b, 0
- br i1 %x2, label %yes2, label %end
+ br i1 %x2, label %yes2, label %end, !prof !1
yes2:
store i32 1, ptr %p
@@ -406,3 +406,9 @@ yes2:
end:
ret void
}
+
+!0 = !{!"branch_weights", i32 7, i32 13}
+!1 = !{!"branch_weights", i32 3, i32 11}
+;.
+; CHECK: [[PROF0]] = !{!"branch_weights", i32 259, i32 21}
+;.
|
24507f7 to
2495127
Compare
30960f9 to
74a231d
Compare
2495127 to
c2f5866
Compare
f33e18a to
5e759bb
Compare
ad7d4ee to
5ae9f83
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should probably change this to [[A:%.*]]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's autogenerated by update_test_checks.py, not much we can do here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In my experience: This is a case where the script is confused by trying to respect the current FileCheck variable names and the LLVM IR names they are intended to match. If you fix them both (B->A and B1->B) before running the script again, it will usually respect the corrected names.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[[B1:%.*]] --> [[B:%.*]]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same - autogen-ed.
38e0ee4 to
55e88b2
Compare
| auto *T = SplitBlockAndInsertIfThen(CombinedPred, InsertPt, | ||
| /*Unreachable=*/false, | ||
| /*BranchWeights=*/nullptr, DTU); | ||
| if (hasBranchWeightMD(*PBranch) && hasBranchWeightMD(*QBranch) && |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Probably useful to have a utility function getDisjunctionWeights with instructions in the interface. I can see it being used elsewhere.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I couldn't use it here though, the conditions need flipping.
| // for the first conditional branch, the probability the "true" case is taken | ||
| // is p(b1) = B1[0] / (B1[0] + B2[0]). The "false" case's probability is | ||
| // p(not b1) = B1[1] / (B1[0] + B1[1]). | ||
| // Similarly for the second conditional branch and B2. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I suggest moving the above comments to the outer doxygen comments. That is, if someone does not understand this, they probably won't call the function with correct arguments. In contrast, the comments that follow are more about the implementation not the interface. [Edit: To clarify my point: what is missing from the outer comments is which weight is for true vs. false in each of B1 and B2. I suppose the exact formulas above are not really needed in the doxygen comments.]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ptal
03afd6e to
f3540ec
Compare
llvm/include/llvm/IR/ProfDataUtils.h
Outdated
| #ifndef LLVM_IR_PROFDATAUTILS_H | ||
| #define LLVM_IR_PROFDATAUTILS_H | ||
|
|
||
| #include "llvm/ADT/STLExtras.h" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What is this include for?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Removed, same for Twine below.
| cl::desc("Limit number of blocks a define in a threaded block is allowed " | ||
| "to be live in")); | ||
|
|
||
| extern cl::opt<bool> ProfcheckDisableMetadataFixes; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What is the value of setting this flag to true? Should there be a test for that case?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's solely for a one-off ablation test I want to do at the end of this, to evaluate the performance impact of improving !prof propagation, after which I'll remove it. It's false by default.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That plan seems worth a brief source comment here, possibly with a link to the larger discussion.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh, it's an extern, and the comment I want is on the definition in a separate .cpp file. Why is the extern not in a .h (preferably with the comment for all uses to easily see)?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wanted to keep the churn to a minimum; the flag is basically an eyesore, but it'll serve its purpose. I'm not opposed to moving it to a header, I'm just not sure it's worth it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If it's short lived, then maybe it's not worth changing now.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have not fully digested the surrounding code, but this LGTM, other than the nits I pointed out.
557c37c to
abe428a
Compare
abe428a to
8d1aaac
Compare

No description provided.