Skip to content

Commit 889c289

Browse files
authored
[SimplfyCFG] Set MD_prof for select used for certain conditional simplifications (#154426)
There’s a pattern where a branch is conditioned on a conjunction or disjunction that ends up being modeled as a `select`​ where the first operand is set to `true`​ or the second to `false`​. If the branch has known branch weights, they can be copied to the `select`​. This is worth doing in case later the `select`​ gets transformed to something else (i.e. if we know the profile, we should propagate it). Issue #147390
1 parent 179f01b commit 889c289

File tree

4 files changed

+115
-62
lines changed

4 files changed

+115
-62
lines changed

llvm/lib/Transforms/Utils/SimplifyCFG.cpp

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,16 @@ class SimplifyCFGOpt {
332332
}
333333
};
334334

335+
// we synthesize a || b as select a, true, b
336+
// we synthesize a && b as select a, b, false
337+
// this function determines if SI is playing one of those roles.
338+
bool isSelectInRoleOfConjunctionOrDisjunction(const SelectInst *SI) {
339+
return ((isa<ConstantInt>(SI->getTrueValue()) &&
340+
(dyn_cast<ConstantInt>(SI->getTrueValue())->isOne())) ||
341+
(isa<ConstantInt>(SI->getFalseValue()) &&
342+
(dyn_cast<ConstantInt>(SI->getFalseValue())->isNullValue())));
343+
}
344+
335345
} // end anonymous namespace
336346

337347
/// Return true if all the PHI nodes in the basic block \p BB
@@ -4033,6 +4043,7 @@ static bool performBranchToCommonDestFolding(BranchInst *BI, BranchInst *PBI,
40334043

40344044
// Try to update branch weights.
40354045
uint64_t PredTrueWeight, PredFalseWeight, SuccTrueWeight, SuccFalseWeight;
4046+
SmallVector<uint32_t, 2> MDWeights;
40364047
if (extractPredSuccWeights(PBI, BI, PredTrueWeight, PredFalseWeight,
40374048
SuccTrueWeight, SuccFalseWeight)) {
40384049
SmallVector<uint64_t, 8> NewWeights;
@@ -4063,7 +4074,7 @@ static bool performBranchToCommonDestFolding(BranchInst *BI, BranchInst *PBI,
40634074
// Halve the weights if any of them cannot fit in an uint32_t
40644075
fitWeights(NewWeights);
40654076

4066-
SmallVector<uint32_t, 8> MDWeights(NewWeights.begin(), NewWeights.end());
4077+
append_range(MDWeights, NewWeights);
40674078
setBranchWeights(PBI, MDWeights[0], MDWeights[1], /*IsExpected=*/false);
40684079

40694080
// TODO: If BB is reachable from all paths through PredBlock, then we
@@ -4100,6 +4111,13 @@ static bool performBranchToCommonDestFolding(BranchInst *BI, BranchInst *PBI,
41004111
Value *BICond = VMap[BI->getCondition()];
41014112
PBI->setCondition(
41024113
createLogicalOp(Builder, Opc, PBI->getCondition(), BICond, "or.cond"));
4114+
if (!ProfcheckDisableMetadataFixes)
4115+
if (auto *SI = dyn_cast<SelectInst>(PBI->getCondition()))
4116+
if (!MDWeights.empty()) {
4117+
assert(isSelectInRoleOfConjunctionOrDisjunction(SI));
4118+
setBranchWeights(SI, MDWeights[0], MDWeights[1],
4119+
/*IsExpected=*/false);
4120+
}
41034121

41044122
++NumFoldBranchToCommonDest;
41054123
return true;
@@ -4812,6 +4830,18 @@ static bool SimplifyCondBranchToCondBranch(BranchInst *PBI, BranchInst *BI,
48124830
fitWeights(NewWeights);
48134831

48144832
setBranchWeights(PBI, NewWeights[0], NewWeights[1], /*IsExpected=*/false);
4833+
// Cond may be a select instruction with the first operand set to "true", or
4834+
// the second to "false" (see how createLogicalOp works for `and` and `or`)
4835+
if (!ProfcheckDisableMetadataFixes)
4836+
if (auto *SI = dyn_cast<SelectInst>(Cond)) {
4837+
assert(isSelectInRoleOfConjunctionOrDisjunction(SI));
4838+
// The select is predicated on PBICond
4839+
assert(dyn_cast<SelectInst>(SI)->getCondition() == PBICond);
4840+
// The corresponding probabilities are what was referred to above as
4841+
// PredCommon and PredOther.
4842+
setBranchWeights(SI, PredCommon, PredOther,
4843+
/*IsExpected=*/false);
4844+
}
48154845
}
48164846

48174847
// OtherDest may have phi nodes. If so, add an entry from PBI's

llvm/test/Transforms/SimplifyCFG/branch-fold-threshold.ll

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
1+
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --check-globals all --version 5
22
; RUN: opt %s -passes=simplifycfg -simplifycfg-require-and-preserve-domtree=1 -S | FileCheck %s --check-prefixes=NORMAL,BASELINE
33
; RUN: opt %s -passes=simplifycfg -simplifycfg-require-and-preserve-domtree=1 -S -bonus-inst-threshold=2 | FileCheck %s --check-prefixes=NORMAL,AGGRESSIVE
44
; RUN: opt %s -passes=simplifycfg -simplifycfg-require-and-preserve-domtree=1 -S -bonus-inst-threshold=4 | FileCheck %s --check-prefixes=WAYAGGRESSIVE
@@ -11,12 +11,12 @@ define i32 @foo(i32 %a, i32 %b, i32 %c, i32 %d, ptr %input) {
1111
; BASELINE-SAME: i32 [[A:%.*]], i32 [[B:%.*]], i32 [[C:%.*]], i32 [[D:%.*]], ptr [[INPUT:%.*]]) {
1212
; BASELINE-NEXT: [[ENTRY:.*]]:
1313
; BASELINE-NEXT: [[CMP:%.*]] = icmp sgt i32 [[D]], 3
14-
; BASELINE-NEXT: br i1 [[CMP]], label %[[COND_END:.*]], label %[[LOR_LHS_FALSE:.*]]
14+
; BASELINE-NEXT: br i1 [[CMP]], label %[[COND_END:.*]], label %[[LOR_LHS_FALSE:.*]], !prof [[PROF0:![0-9]+]]
1515
; BASELINE: [[LOR_LHS_FALSE]]:
1616
; BASELINE-NEXT: [[MUL:%.*]] = shl i32 [[C]], 1
1717
; BASELINE-NEXT: [[ADD:%.*]] = add nsw i32 [[MUL]], [[A]]
1818
; BASELINE-NEXT: [[CMP1:%.*]] = icmp slt i32 [[ADD]], [[B]]
19-
; BASELINE-NEXT: br i1 [[CMP1]], label %[[COND_FALSE:.*]], label %[[COND_END]]
19+
; BASELINE-NEXT: br i1 [[CMP1]], label %[[COND_FALSE:.*]], label %[[COND_END]], !prof [[PROF1:![0-9]+]]
2020
; BASELINE: [[COND_FALSE]]:
2121
; BASELINE-NEXT: [[TMP0:%.*]] = load i32, ptr [[INPUT]], align 4
2222
; BASELINE-NEXT: br label %[[COND_END]]
@@ -31,8 +31,8 @@ define i32 @foo(i32 %a, i32 %b, i32 %c, i32 %d, ptr %input) {
3131
; AGGRESSIVE-NEXT: [[MUL:%.*]] = shl i32 [[C]], 1
3232
; AGGRESSIVE-NEXT: [[ADD:%.*]] = add nsw i32 [[MUL]], [[A]]
3333
; AGGRESSIVE-NEXT: [[CMP1:%.*]] = icmp slt i32 [[ADD]], [[B]]
34-
; AGGRESSIVE-NEXT: [[OR_COND:%.*]] = select i1 [[CMP]], i1 [[CMP1]], i1 false
35-
; AGGRESSIVE-NEXT: br i1 [[OR_COND]], label %[[COND_FALSE:.*]], label %[[COND_END:.*]]
34+
; AGGRESSIVE-NEXT: [[OR_COND:%.*]] = select i1 [[CMP]], i1 [[CMP1]], i1 false, !prof [[PROF0:![0-9]+]]
35+
; AGGRESSIVE-NEXT: br i1 [[OR_COND]], label %[[COND_FALSE:.*]], label %[[COND_END:.*]], !prof [[PROF0]]
3636
; AGGRESSIVE: [[COND_FALSE]]:
3737
; AGGRESSIVE-NEXT: [[TMP0:%.*]] = load i32, ptr [[INPUT]], align 4
3838
; AGGRESSIVE-NEXT: br label %[[COND_END]]
@@ -47,8 +47,8 @@ define i32 @foo(i32 %a, i32 %b, i32 %c, i32 %d, ptr %input) {
4747
; WAYAGGRESSIVE-NEXT: [[MUL:%.*]] = shl i32 [[C]], 1
4848
; WAYAGGRESSIVE-NEXT: [[ADD:%.*]] = add nsw i32 [[MUL]], [[A]]
4949
; WAYAGGRESSIVE-NEXT: [[CMP1:%.*]] = icmp slt i32 [[ADD]], [[B]]
50-
; WAYAGGRESSIVE-NEXT: [[OR_COND:%.*]] = select i1 [[CMP]], i1 [[CMP1]], i1 false
51-
; WAYAGGRESSIVE-NEXT: br i1 [[OR_COND]], label %[[COND_FALSE:.*]], label %[[COND_END:.*]]
50+
; WAYAGGRESSIVE-NEXT: [[OR_COND:%.*]] = select i1 [[CMP]], i1 [[CMP1]], i1 false, !prof [[PROF0:![0-9]+]]
51+
; WAYAGGRESSIVE-NEXT: br i1 [[OR_COND]], label %[[COND_FALSE:.*]], label %[[COND_END:.*]], !prof [[PROF0]]
5252
; WAYAGGRESSIVE: [[COND_FALSE]]:
5353
; WAYAGGRESSIVE-NEXT: [[TMP0:%.*]] = load i32, ptr [[INPUT]], align 4
5454
; WAYAGGRESSIVE-NEXT: br label %[[COND_END]]
@@ -58,13 +58,13 @@ define i32 @foo(i32 %a, i32 %b, i32 %c, i32 %d, ptr %input) {
5858
;
5959
entry:
6060
%cmp = icmp sgt i32 %d, 3
61-
br i1 %cmp, label %cond.end, label %lor.lhs.false
61+
br i1 %cmp, label %cond.end, label %lor.lhs.false, !prof !0
6262

6363
lor.lhs.false:
6464
%mul = shl i32 %c, 1
6565
%add = add nsw i32 %mul, %a
6666
%cmp1 = icmp slt i32 %add, %b
67-
br i1 %cmp1, label %cond.false, label %cond.end
67+
br i1 %cmp1, label %cond.false, label %cond.end, !prof !1
6868

6969
cond.false:
7070
%0 = load i32, ptr %input, align 4
@@ -160,3 +160,14 @@ cond.end:
160160
%cond = phi i32 [ %0, %cond.false ], [ 0, %lor.lhs.false ],[ 0, %pred_a ],[ 0, %pred_b ]
161161
ret i32 %cond
162162
}
163+
164+
!0 = !{!"branch_weights", i32 7, i32 11}
165+
!1 = !{!"branch_weights", i32 13, i32 5}
166+
;.
167+
; BASELINE: [[PROF0]] = !{!"branch_weights", i32 7, i32 11}
168+
; BASELINE: [[PROF1]] = !{!"branch_weights", i32 13, i32 5}
169+
;.
170+
; AGGRESSIVE: [[PROF0]] = !{!"branch_weights", i32 143, i32 181}
171+
;.
172+
; WAYAGGRESSIVE: [[PROF0]] = !{!"branch_weights", i32 143, i32 181}
173+
;.

llvm/test/Transforms/SimplifyCFG/branch-fold.ll

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
1-
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
1+
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --check-globals
22
; RUN: opt < %s -passes=simplifycfg -simplifycfg-require-and-preserve-domtree=1 -S | FileCheck %s
33

44
define void @test(ptr %P, ptr %Q, i1 %A, i1 %B) {
55
; CHECK-LABEL: @test(
66
; CHECK-NEXT: entry:
77
; CHECK-NEXT: [[A_NOT:%.*]] = xor i1 [[A:%.*]], true
8-
; CHECK-NEXT: [[BRMERGE:%.*]] = select i1 [[A_NOT]], i1 true, i1 [[B:%.*]]
9-
; CHECK-NEXT: br i1 [[BRMERGE]], label [[B:%.*]], label [[COMMON_RET:%.*]]
8+
; CHECK-NEXT: [[BRMERGE:%.*]] = select i1 [[A_NOT]], i1 true, i1 [[B:%.*]], !prof [[PROF0:![0-9]+]]
9+
; CHECK-NEXT: br i1 [[BRMERGE]], label [[B:%.*]], label [[COMMON_RET:%.*]], !prof [[PROF1:![0-9]+]]
1010
; CHECK: common.ret:
1111
; CHECK-NEXT: ret void
1212
; CHECK: b:
@@ -15,9 +15,9 @@ define void @test(ptr %P, ptr %Q, i1 %A, i1 %B) {
1515
;
1616

1717
entry:
18-
br i1 %A, label %a, label %b
18+
br i1 %A, label %a, label %b, !prof !0
1919
a:
20-
br i1 %B, label %b, label %c
20+
br i1 %B, label %b, label %c, !prof !1
2121
b:
2222
store i32 123, ptr %P
2323
ret void
@@ -146,3 +146,12 @@ Succ:
146146
}
147147

148148
declare void @dummy()
149+
150+
!0 = !{!"branch_weights", i32 3, i32 7}
151+
!1 = !{!"branch_weights", i32 11, i32 4}
152+
;.
153+
; CHECK: attributes #[[ATTR0:[0-9]+]] = { nounwind ssp memory(read) uwtable }
154+
;.
155+
; CHECK: [[PROF0]] = !{!"branch_weights", i32 7, i32 3}
156+
; CHECK: [[PROF1]] = !{!"branch_weights", i32 138, i32 12}
157+
;.

0 commit comments

Comments
 (0)