Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 13 additions & 8 deletions llvm/lib/Transforms/Utils/SimplifyCFG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7731,19 +7731,24 @@ static bool simplifySwitchOfPowersOfTwo(SwitchInst *SI, IRBuilder<> &Builder,
// label. The other is those powers of 2 that don't appear in the case
// statement. We don't know the distribution of the values coming in, so
// the safest is to split 50-50 the original probability to `default`.
uint64_t OrigDenominator = sum_of(map_range(
Weights, [](const auto &V) { return static_cast<uint64_t>(V); }));
uint64_t OrigDenominator =
sum_of(map_range(Weights, StaticCastTo<uint64_t>));
SmallVector<uint64_t> NewWeights(2);
NewWeights[1] = Weights[0] / 2;
NewWeights[0] = OrigDenominator - NewWeights[1];
setFittedBranchWeights(*BI, NewWeights, /*IsExpected=*/false);

// For the original switch, we reduce the weight of the default by the
// amount by which the previous branch contributes to getting to default,
// and then make sure the remaining weights have the same relative ratio
// wrt eachother.
// The probability of executing the default block stays constant. It was
// p_d = Weights[0] / OrigDenominator
// we rewrite as W/D
// We want to find the probability of the default branch of the switch
// statement. Let's call it X. We have W/D = W/2D + X * (1-W/2D)
// i.e. the original probability is the probability we go to the default
// branch from the BI branch, or we take the default branch on the SI.
// Meaning X = W / (2D - W), or (W/2) / (D - W/2)
// This matches using W/2 for the default branch probability numerator and
// D-W/2 as the denominator.
Weights[0] = NewWeights[1];
uint64_t CasesDenominator = OrigDenominator - Weights[0];
Weights[0] /= 2;
for (auto &W : drop_begin(Weights))
W = NewWeights[0] * static_cast<double>(W) / CasesDenominator;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,5 +141,5 @@ return:
;.
; CHECK: [[PROF0]] = !{!"function_entry_count", i32 10}
; CHECK: [[PROF1]] = !{!"branch_weights", i32 58, i32 5}
; CHECK: [[PROF2]] = !{!"branch_weights", i32 56, i32 5}
; CHECK: [[PROF2]] = !{!"branch_weights", i32 53, i32 5}
;.
Loading