8484#include < cstdint>
8585#include < iterator>
8686#include < map>
87+ #include < numeric>
8788#include < optional>
8889#include < set>
8990#include < tuple>
@@ -6329,9 +6330,12 @@ static bool initializeUniqueCases(SwitchInst *SI, PHINode *&PHI,
63296330// Helper function that checks if it is possible to transform a switch with only
63306331// two cases (or two cases + default) that produces a result into a select.
63316332// TODO: Handle switches with more than 2 cases that map to the same result.
6333+ // The branch weights correspond to the provided Condition (i.e. if Condition is
6334+ // modified from the original SwitchInst, the caller must adjust the weights)
63326335static Value *foldSwitchToSelect (const SwitchCaseResultVectorTy &ResultVector,
63336336 Constant *DefaultResult, Value *Condition,
6334- IRBuilder<> &Builder, const DataLayout &DL) {
6337+ IRBuilder<> &Builder, const DataLayout &DL,
6338+ ArrayRef<uint32_t > BranchWeights) {
63356339 // If we are selecting between only two cases transform into a simple
63366340 // select or a two-way select if default is possible.
63376341 // Example:
@@ -6340,6 +6344,10 @@ static Value *foldSwitchToSelect(const SwitchCaseResultVectorTy &ResultVector,
63406344 // case 20: return 2; ----> %2 = icmp eq i32 %a, 20
63416345 // default: return 4; %3 = select i1 %2, i32 2, i32 %1
63426346 // }
6347+
6348+ const bool HasBranchWeights =
6349+ !BranchWeights.empty () && !ProfcheckDisableMetadataFixes;
6350+
63436351 if (ResultVector.size () == 2 && ResultVector[0 ].second .size () == 1 &&
63446352 ResultVector[1 ].second .size () == 1 ) {
63456353 ConstantInt *FirstCase = ResultVector[0 ].second [0 ];
@@ -6348,13 +6356,37 @@ static Value *foldSwitchToSelect(const SwitchCaseResultVectorTy &ResultVector,
63486356 if (DefaultResult) {
63496357 Value *ValueCompare =
63506358 Builder.CreateICmpEQ (Condition, SecondCase, " switch.selectcmp" );
6351- SelectValue = Builder.CreateSelect (ValueCompare, ResultVector[1 ].first ,
6352- DefaultResult, " switch.select" );
6359+ SelectInst *SelectValueInst = cast<SelectInst>(Builder.CreateSelect (
6360+ ValueCompare, ResultVector[1 ].first , DefaultResult, " switch.select" ));
6361+ SelectValue = SelectValueInst;
6362+ if (HasBranchWeights) {
6363+ // We start with 3 probabilities, where the numerator is the
6364+ // corresponding BranchWeights[i], and the denominator is the sum over
6365+ // BranchWeights. We want the probability and negative probability of
6366+ // Condition == SecondCase.
6367+ assert (BranchWeights.size () == 3 );
6368+ setBranchWeights (SelectValueInst, BranchWeights[2 ],
6369+ BranchWeights[0 ] + BranchWeights[1 ],
6370+ /* IsExpected=*/ false );
6371+ }
63536372 }
63546373 Value *ValueCompare =
63556374 Builder.CreateICmpEQ (Condition, FirstCase, " switch.selectcmp" );
6356- return Builder.CreateSelect (ValueCompare, ResultVector[0 ].first ,
6357- SelectValue, " switch.select" );
6375+ SelectInst *Ret = cast<SelectInst>(Builder.CreateSelect (
6376+ ValueCompare, ResultVector[0 ].first , SelectValue, " switch.select" ));
6377+ if (HasBranchWeights) {
6378+ // We may have had a DefaultResult. Base the position of the first and
6379+ // second's branch weights accordingly. Also the proability that Condition
6380+ // != FirstCase needs to take that into account.
6381+ assert (BranchWeights.size () >= 2 );
6382+ size_t FirstCasePos = (Condition != nullptr );
6383+ size_t SecondCasePos = FirstCasePos + 1 ;
6384+ uint32_t DefaultCase = (Condition != nullptr ) ? BranchWeights[0 ] : 0 ;
6385+ setBranchWeights (Ret, BranchWeights[FirstCasePos],
6386+ DefaultCase + BranchWeights[SecondCasePos],
6387+ /* IsExpected=*/ false );
6388+ }
6389+ return Ret;
63586390 }
63596391
63606392 // Handle the degenerate case where two cases have the same result value.
@@ -6390,8 +6422,16 @@ static Value *foldSwitchToSelect(const SwitchCaseResultVectorTy &ResultVector,
63906422 Value *And = Builder.CreateAnd (Condition, AndMask);
63916423 Value *Cmp = Builder.CreateICmpEQ (
63926424 And, Constant::getIntegerValue (And->getType (), AndMask));
6393- return Builder.CreateSelect (Cmp, ResultVector[0 ].first ,
6394- DefaultResult);
6425+ SelectInst *Ret = cast<SelectInst>(
6426+ Builder.CreateSelect (Cmp, ResultVector[0 ].first , DefaultResult));
6427+ if (HasBranchWeights) {
6428+ // We know there's a Default case. We base the resulting branch
6429+ // weights off its probability.
6430+ assert (BranchWeights.size () >= 2 );
6431+ setBranchWeights (Ret, accumulate (drop_begin (BranchWeights), 0 ),
6432+ BranchWeights[0 ], /* IsExpected=*/ false );
6433+ }
6434+ return Ret;
63956435 }
63966436 }
63976437
@@ -6408,7 +6448,14 @@ static Value *foldSwitchToSelect(const SwitchCaseResultVectorTy &ResultVector,
64086448 Value *And = Builder.CreateAnd (Condition, ~BitMask, " switch.and" );
64096449 Value *Cmp = Builder.CreateICmpEQ (
64106450 And, Constant::getNullValue (And->getType ()), " switch.selectcmp" );
6411- return Builder.CreateSelect (Cmp, ResultVector[0 ].first , DefaultResult);
6451+ SelectInst *Ret = cast<SelectInst>(
6452+ Builder.CreateSelect (Cmp, ResultVector[0 ].first , DefaultResult));
6453+ if (HasBranchWeights) {
6454+ assert (BranchWeights.size () >= 2 );
6455+ setBranchWeights (Ret, accumulate (drop_begin (BranchWeights), 0 ),
6456+ BranchWeights[0 ], /* IsExpected=*/ false );
6457+ }
6458+ return Ret;
64126459 }
64136460 }
64146461
@@ -6419,7 +6466,14 @@ static Value *foldSwitchToSelect(const SwitchCaseResultVectorTy &ResultVector,
64196466 Value *Cmp2 = Builder.CreateICmpEQ (Condition, CaseValues[1 ],
64206467 " switch.selectcmp.case2" );
64216468 Value *Cmp = Builder.CreateOr (Cmp1, Cmp2, " switch.selectcmp" );
6422- return Builder.CreateSelect (Cmp, ResultVector[0 ].first , DefaultResult);
6469+ SelectInst *Ret = cast<SelectInst>(
6470+ Builder.CreateSelect (Cmp, ResultVector[0 ].first , DefaultResult));
6471+ if (HasBranchWeights) {
6472+ assert (BranchWeights.size () >= 2 );
6473+ setBranchWeights (Ret, accumulate (drop_begin (BranchWeights), 0 ),
6474+ BranchWeights[0 ], /* IsExpected=*/ false );
6475+ }
6476+ return Ret;
64236477 }
64246478 }
64256479
@@ -6480,8 +6534,18 @@ static bool trySwitchToSelect(SwitchInst *SI, IRBuilder<> &Builder,
64806534
64816535 assert (PHI != nullptr && " PHI for value select not found" );
64826536 Builder.SetInsertPoint (SI);
6483- Value *SelectValue =
6484- foldSwitchToSelect (UniqueResults, DefaultResult, Cond, Builder, DL);
6537+ SmallVector<uint32_t , 4 > BranchWeights;
6538+ if (!ProfcheckDisableMetadataFixes) {
6539+ [[maybe_unused]] auto HasWeights =
6540+ extractBranchWeights (getBranchWeightMDNode (*SI), BranchWeights);
6541+ assert (!HasWeights == (BranchWeights.empty ()));
6542+ }
6543+ assert (BranchWeights.empty () ||
6544+ (BranchWeights.size () >=
6545+ UniqueResults.size () + (DefaultResult != nullptr )));
6546+
6547+ Value *SelectValue = foldSwitchToSelect (UniqueResults, DefaultResult, Cond,
6548+ Builder, DL, BranchWeights);
64856549 if (!SelectValue)
64866550 return false ;
64876551
0 commit comments