Skip to content

Commit d2f14bc

Browse files
authored
[profcheck][SimplifyCFG] Propagate !prof from switch to select (#159645)
Propagate `!prof`​ from `switch`​ instructions. Issue #147390
1 parent e05bad4 commit d2f14bc

File tree

2 files changed

+195
-106
lines changed

2 files changed

+195
-106
lines changed

llvm/lib/Transforms/Utils/SimplifyCFG.cpp

Lines changed: 75 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@
8484
#include <cstdint>
8585
#include <iterator>
8686
#include <map>
87+
#include <numeric>
8788
#include <optional>
8889
#include <set>
8990
#include <tuple>
@@ -6329,9 +6330,12 @@ static bool initializeUniqueCases(SwitchInst *SI, PHINode *&PHI,
63296330
// Helper function that checks if it is possible to transform a switch with only
63306331
// two cases (or two cases + default) that produces a result into a select.
63316332
// TODO: Handle switches with more than 2 cases that map to the same result.
6333+
// The branch weights correspond to the provided Condition (i.e. if Condition is
6334+
// modified from the original SwitchInst, the caller must adjust the weights)
63326335
static Value *foldSwitchToSelect(const SwitchCaseResultVectorTy &ResultVector,
63336336
Constant *DefaultResult, Value *Condition,
6334-
IRBuilder<> &Builder, const DataLayout &DL) {
6337+
IRBuilder<> &Builder, const DataLayout &DL,
6338+
ArrayRef<uint32_t> BranchWeights) {
63356339
// If we are selecting between only two cases transform into a simple
63366340
// select or a two-way select if default is possible.
63376341
// Example:
@@ -6340,6 +6344,10 @@ static Value *foldSwitchToSelect(const SwitchCaseResultVectorTy &ResultVector,
63406344
// case 20: return 2; ----> %2 = icmp eq i32 %a, 20
63416345
// default: return 4; %3 = select i1 %2, i32 2, i32 %1
63426346
// }
6347+
6348+
const bool HasBranchWeights =
6349+
!BranchWeights.empty() && !ProfcheckDisableMetadataFixes;
6350+
63436351
if (ResultVector.size() == 2 && ResultVector[0].second.size() == 1 &&
63446352
ResultVector[1].second.size() == 1) {
63456353
ConstantInt *FirstCase = ResultVector[0].second[0];
@@ -6348,13 +6356,37 @@ static Value *foldSwitchToSelect(const SwitchCaseResultVectorTy &ResultVector,
63486356
if (DefaultResult) {
63496357
Value *ValueCompare =
63506358
Builder.CreateICmpEQ(Condition, SecondCase, "switch.selectcmp");
6351-
SelectValue = Builder.CreateSelect(ValueCompare, ResultVector[1].first,
6352-
DefaultResult, "switch.select");
6359+
SelectInst *SelectValueInst = cast<SelectInst>(Builder.CreateSelect(
6360+
ValueCompare, ResultVector[1].first, DefaultResult, "switch.select"));
6361+
SelectValue = SelectValueInst;
6362+
if (HasBranchWeights) {
6363+
// We start with 3 probabilities, where the numerator is the
6364+
// corresponding BranchWeights[i], and the denominator is the sum over
6365+
// BranchWeights. We want the probability and negative probability of
6366+
// Condition == SecondCase.
6367+
assert(BranchWeights.size() == 3);
6368+
setBranchWeights(SelectValueInst, BranchWeights[2],
6369+
BranchWeights[0] + BranchWeights[1],
6370+
/*IsExpected=*/false);
6371+
}
63536372
}
63546373
Value *ValueCompare =
63556374
Builder.CreateICmpEQ(Condition, FirstCase, "switch.selectcmp");
6356-
return Builder.CreateSelect(ValueCompare, ResultVector[0].first,
6357-
SelectValue, "switch.select");
6375+
SelectInst *Ret = cast<SelectInst>(Builder.CreateSelect(
6376+
ValueCompare, ResultVector[0].first, SelectValue, "switch.select"));
6377+
if (HasBranchWeights) {
6378+
// We may have had a DefaultResult. Base the position of the first and
6379+
// second's branch weights accordingly. Also the proability that Condition
6380+
// != FirstCase needs to take that into account.
6381+
assert(BranchWeights.size() >= 2);
6382+
size_t FirstCasePos = (Condition != nullptr);
6383+
size_t SecondCasePos = FirstCasePos + 1;
6384+
uint32_t DefaultCase = (Condition != nullptr) ? BranchWeights[0] : 0;
6385+
setBranchWeights(Ret, BranchWeights[FirstCasePos],
6386+
DefaultCase + BranchWeights[SecondCasePos],
6387+
/*IsExpected=*/false);
6388+
}
6389+
return Ret;
63586390
}
63596391

63606392
// Handle the degenerate case where two cases have the same result value.
@@ -6390,8 +6422,16 @@ static Value *foldSwitchToSelect(const SwitchCaseResultVectorTy &ResultVector,
63906422
Value *And = Builder.CreateAnd(Condition, AndMask);
63916423
Value *Cmp = Builder.CreateICmpEQ(
63926424
And, Constant::getIntegerValue(And->getType(), AndMask));
6393-
return Builder.CreateSelect(Cmp, ResultVector[0].first,
6394-
DefaultResult);
6425+
SelectInst *Ret = cast<SelectInst>(
6426+
Builder.CreateSelect(Cmp, ResultVector[0].first, DefaultResult));
6427+
if (HasBranchWeights) {
6428+
// We know there's a Default case. We base the resulting branch
6429+
// weights off its probability.
6430+
assert(BranchWeights.size() >= 2);
6431+
setBranchWeights(Ret, accumulate(drop_begin(BranchWeights), 0),
6432+
BranchWeights[0], /*IsExpected=*/false);
6433+
}
6434+
return Ret;
63956435
}
63966436
}
63976437

@@ -6408,7 +6448,14 @@ static Value *foldSwitchToSelect(const SwitchCaseResultVectorTy &ResultVector,
64086448
Value *And = Builder.CreateAnd(Condition, ~BitMask, "switch.and");
64096449
Value *Cmp = Builder.CreateICmpEQ(
64106450
And, Constant::getNullValue(And->getType()), "switch.selectcmp");
6411-
return Builder.CreateSelect(Cmp, ResultVector[0].first, DefaultResult);
6451+
SelectInst *Ret = cast<SelectInst>(
6452+
Builder.CreateSelect(Cmp, ResultVector[0].first, DefaultResult));
6453+
if (HasBranchWeights) {
6454+
assert(BranchWeights.size() >= 2);
6455+
setBranchWeights(Ret, accumulate(drop_begin(BranchWeights), 0),
6456+
BranchWeights[0], /*IsExpected=*/false);
6457+
}
6458+
return Ret;
64126459
}
64136460
}
64146461

@@ -6419,7 +6466,14 @@ static Value *foldSwitchToSelect(const SwitchCaseResultVectorTy &ResultVector,
64196466
Value *Cmp2 = Builder.CreateICmpEQ(Condition, CaseValues[1],
64206467
"switch.selectcmp.case2");
64216468
Value *Cmp = Builder.CreateOr(Cmp1, Cmp2, "switch.selectcmp");
6422-
return Builder.CreateSelect(Cmp, ResultVector[0].first, DefaultResult);
6469+
SelectInst *Ret = cast<SelectInst>(
6470+
Builder.CreateSelect(Cmp, ResultVector[0].first, DefaultResult));
6471+
if (HasBranchWeights) {
6472+
assert(BranchWeights.size() >= 2);
6473+
setBranchWeights(Ret, accumulate(drop_begin(BranchWeights), 0),
6474+
BranchWeights[0], /*IsExpected=*/false);
6475+
}
6476+
return Ret;
64236477
}
64246478
}
64256479

@@ -6480,8 +6534,18 @@ static bool trySwitchToSelect(SwitchInst *SI, IRBuilder<> &Builder,
64806534

64816535
assert(PHI != nullptr && "PHI for value select not found");
64826536
Builder.SetInsertPoint(SI);
6483-
Value *SelectValue =
6484-
foldSwitchToSelect(UniqueResults, DefaultResult, Cond, Builder, DL);
6537+
SmallVector<uint32_t, 4> BranchWeights;
6538+
if (!ProfcheckDisableMetadataFixes) {
6539+
[[maybe_unused]] auto HasWeights =
6540+
extractBranchWeights(getBranchWeightMDNode(*SI), BranchWeights);
6541+
assert(!HasWeights == (BranchWeights.empty()));
6542+
}
6543+
assert(BranchWeights.empty() ||
6544+
(BranchWeights.size() >=
6545+
UniqueResults.size() + (DefaultResult != nullptr)));
6546+
6547+
Value *SelectValue = foldSwitchToSelect(UniqueResults, DefaultResult, Cond,
6548+
Builder, DL, BranchWeights);
64856549
if (!SelectValue)
64866550
return false;
64876551

0 commit comments

Comments
 (0)