84
84
#include < cstdint>
85
85
#include < iterator>
86
86
#include < map>
87
+ #include < numeric>
87
88
#include < optional>
88
89
#include < set>
89
90
#include < tuple>
@@ -6329,9 +6330,12 @@ static bool initializeUniqueCases(SwitchInst *SI, PHINode *&PHI,
6329
6330
// Helper function that checks if it is possible to transform a switch with only
6330
6331
// two cases (or two cases + default) that produces a result into a select.
6331
6332
// TODO: Handle switches with more than 2 cases that map to the same result.
6333
+ // The branch weights correspond to the provided Condition (i.e. if Condition is
6334
+ // modified from the original SwitchInst, the caller must adjust the weights)
6332
6335
static Value *foldSwitchToSelect (const SwitchCaseResultVectorTy &ResultVector,
6333
6336
Constant *DefaultResult, Value *Condition,
6334
- IRBuilder<> &Builder, const DataLayout &DL) {
6337
+ IRBuilder<> &Builder, const DataLayout &DL,
6338
+ ArrayRef<uint32_t > BranchWeights) {
6335
6339
// If we are selecting between only two cases transform into a simple
6336
6340
// select or a two-way select if default is possible.
6337
6341
// Example:
@@ -6340,6 +6344,10 @@ static Value *foldSwitchToSelect(const SwitchCaseResultVectorTy &ResultVector,
6340
6344
// case 20: return 2; ----> %2 = icmp eq i32 %a, 20
6341
6345
// default: return 4; %3 = select i1 %2, i32 2, i32 %1
6342
6346
// }
6347
+
6348
+ const bool HasBranchWeights =
6349
+ !BranchWeights.empty () && !ProfcheckDisableMetadataFixes;
6350
+
6343
6351
if (ResultVector.size () == 2 && ResultVector[0 ].second .size () == 1 &&
6344
6352
ResultVector[1 ].second .size () == 1 ) {
6345
6353
ConstantInt *FirstCase = ResultVector[0 ].second [0 ];
@@ -6348,13 +6356,37 @@ static Value *foldSwitchToSelect(const SwitchCaseResultVectorTy &ResultVector,
6348
6356
if (DefaultResult) {
6349
6357
Value *ValueCompare =
6350
6358
Builder.CreateICmpEQ (Condition, SecondCase, " switch.selectcmp" );
6351
- SelectValue = Builder.CreateSelect (ValueCompare, ResultVector[1 ].first ,
6352
- DefaultResult, " switch.select" );
6359
+ SelectInst *SelectValueInst = cast<SelectInst>(Builder.CreateSelect (
6360
+ ValueCompare, ResultVector[1 ].first , DefaultResult, " switch.select" ));
6361
+ SelectValue = SelectValueInst;
6362
+ if (HasBranchWeights) {
6363
+ // We start with 3 probabilities, where the numerator is the
6364
+ // corresponding BranchWeights[i], and the denominator is the sum over
6365
+ // BranchWeights. We want the probability and negative probability of
6366
+ // Condition == SecondCase.
6367
+ assert (BranchWeights.size () == 3 );
6368
+ setBranchWeights (SelectValueInst, BranchWeights[2 ],
6369
+ BranchWeights[0 ] + BranchWeights[1 ],
6370
+ /* IsExpected=*/ false );
6371
+ }
6353
6372
}
6354
6373
Value *ValueCompare =
6355
6374
Builder.CreateICmpEQ (Condition, FirstCase, " switch.selectcmp" );
6356
- return Builder.CreateSelect (ValueCompare, ResultVector[0 ].first ,
6357
- SelectValue, " switch.select" );
6375
+ SelectInst *Ret = cast<SelectInst>(Builder.CreateSelect (
6376
+ ValueCompare, ResultVector[0 ].first , SelectValue, " switch.select" ));
6377
+ if (HasBranchWeights) {
6378
+ // We may have had a DefaultResult. Base the position of the first and
6379
+ // second's branch weights accordingly. Also the proability that Condition
6380
+ // != FirstCase needs to take that into account.
6381
+ assert (BranchWeights.size () >= 2 );
6382
+ size_t FirstCasePos = (Condition != nullptr );
6383
+ size_t SecondCasePos = FirstCasePos + 1 ;
6384
+ uint32_t DefaultCase = (Condition != nullptr ) ? BranchWeights[0 ] : 0 ;
6385
+ setBranchWeights (Ret, BranchWeights[FirstCasePos],
6386
+ DefaultCase + BranchWeights[SecondCasePos],
6387
+ /* IsExpected=*/ false );
6388
+ }
6389
+ return Ret;
6358
6390
}
6359
6391
6360
6392
// Handle the degenerate case where two cases have the same result value.
@@ -6390,8 +6422,16 @@ static Value *foldSwitchToSelect(const SwitchCaseResultVectorTy &ResultVector,
6390
6422
Value *And = Builder.CreateAnd (Condition, AndMask);
6391
6423
Value *Cmp = Builder.CreateICmpEQ (
6392
6424
And, Constant::getIntegerValue (And->getType (), AndMask));
6393
- return Builder.CreateSelect (Cmp, ResultVector[0 ].first ,
6394
- DefaultResult);
6425
+ SelectInst *Ret = cast<SelectInst>(
6426
+ Builder.CreateSelect (Cmp, ResultVector[0 ].first , DefaultResult));
6427
+ if (HasBranchWeights) {
6428
+ // We know there's a Default case. We base the resulting branch
6429
+ // weights off its probability.
6430
+ assert (BranchWeights.size () >= 2 );
6431
+ setBranchWeights (Ret, accumulate (drop_begin (BranchWeights), 0 ),
6432
+ BranchWeights[0 ], /* IsExpected=*/ false );
6433
+ }
6434
+ return Ret;
6395
6435
}
6396
6436
}
6397
6437
@@ -6408,7 +6448,14 @@ static Value *foldSwitchToSelect(const SwitchCaseResultVectorTy &ResultVector,
6408
6448
Value *And = Builder.CreateAnd (Condition, ~BitMask, " switch.and" );
6409
6449
Value *Cmp = Builder.CreateICmpEQ (
6410
6450
And, Constant::getNullValue (And->getType ()), " switch.selectcmp" );
6411
- return Builder.CreateSelect (Cmp, ResultVector[0 ].first , DefaultResult);
6451
+ SelectInst *Ret = cast<SelectInst>(
6452
+ Builder.CreateSelect (Cmp, ResultVector[0 ].first , DefaultResult));
6453
+ if (HasBranchWeights) {
6454
+ assert (BranchWeights.size () >= 2 );
6455
+ setBranchWeights (Ret, accumulate (drop_begin (BranchWeights), 0 ),
6456
+ BranchWeights[0 ], /* IsExpected=*/ false );
6457
+ }
6458
+ return Ret;
6412
6459
}
6413
6460
}
6414
6461
@@ -6419,7 +6466,14 @@ static Value *foldSwitchToSelect(const SwitchCaseResultVectorTy &ResultVector,
6419
6466
Value *Cmp2 = Builder.CreateICmpEQ (Condition, CaseValues[1 ],
6420
6467
" switch.selectcmp.case2" );
6421
6468
Value *Cmp = Builder.CreateOr (Cmp1, Cmp2, " switch.selectcmp" );
6422
- return Builder.CreateSelect (Cmp, ResultVector[0 ].first , DefaultResult);
6469
+ SelectInst *Ret = cast<SelectInst>(
6470
+ Builder.CreateSelect (Cmp, ResultVector[0 ].first , DefaultResult));
6471
+ if (HasBranchWeights) {
6472
+ assert (BranchWeights.size () >= 2 );
6473
+ setBranchWeights (Ret, accumulate (drop_begin (BranchWeights), 0 ),
6474
+ BranchWeights[0 ], /* IsExpected=*/ false );
6475
+ }
6476
+ return Ret;
6423
6477
}
6424
6478
}
6425
6479
@@ -6480,8 +6534,18 @@ static bool trySwitchToSelect(SwitchInst *SI, IRBuilder<> &Builder,
6480
6534
6481
6535
assert (PHI != nullptr && " PHI for value select not found" );
6482
6536
Builder.SetInsertPoint (SI);
6483
- Value *SelectValue =
6484
- foldSwitchToSelect (UniqueResults, DefaultResult, Cond, Builder, DL);
6537
+ SmallVector<uint32_t , 4 > BranchWeights;
6538
+ if (!ProfcheckDisableMetadataFixes) {
6539
+ [[maybe_unused]] auto HasWeights =
6540
+ extractBranchWeights (getBranchWeightMDNode (*SI), BranchWeights);
6541
+ assert (!HasWeights == (BranchWeights.empty ()));
6542
+ }
6543
+ assert (BranchWeights.empty () ||
6544
+ (BranchWeights.size () >=
6545
+ UniqueResults.size () + (DefaultResult != nullptr )));
6546
+
6547
+ Value *SelectValue = foldSwitchToSelect (UniqueResults, DefaultResult, Cond,
6548
+ Builder, DL, BranchWeights);
6485
6549
if (!SelectValue)
6486
6550
return false ;
6487
6551
0 commit comments