Skip to content

Commit fdb2ab1

Browse files
antoniofrighettomahesh-attarde
authored andcommitted
[SimplifyCFG] Ensure selects have not been constant folded in foldSwitchToSelect
Make sure selects do exist prior to assigning weights to edges. Fixes: llvm#161137.
1 parent ab199f8 commit fdb2ab1

File tree

2 files changed

+40
-22
lines changed

2 files changed

+40
-22
lines changed

llvm/lib/Transforms/Utils/SimplifyCFG.cpp

Lines changed: 21 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,6 @@
8484
#include <cstdint>
8585
#include <iterator>
8686
#include <map>
87-
#include <numeric>
8887
#include <optional>
8988
#include <set>
9089
#include <tuple>
@@ -6356,33 +6355,33 @@ static Value *foldSwitchToSelect(const SwitchCaseResultVectorTy &ResultVector,
63566355
if (DefaultResult) {
63576356
Value *ValueCompare =
63586357
Builder.CreateICmpEQ(Condition, SecondCase, "switch.selectcmp");
6359-
SelectInst *SelectValueInst = cast<SelectInst>(Builder.CreateSelect(
6360-
ValueCompare, ResultVector[1].first, DefaultResult, "switch.select"));
6361-
SelectValue = SelectValueInst;
6362-
if (HasBranchWeights) {
6358+
SelectValue = Builder.CreateSelect(ValueCompare, ResultVector[1].first,
6359+
DefaultResult, "switch.select");
6360+
if (auto *SI = dyn_cast<SelectInst>(SelectValue);
6361+
SI && HasBranchWeights) {
63636362
// We start with 3 probabilities, where the numerator is the
63646363
// corresponding BranchWeights[i], and the denominator is the sum over
63656364
// BranchWeights. We want the probability and negative probability of
63666365
// Condition == SecondCase.
63676366
assert(BranchWeights.size() == 3);
6368-
setBranchWeights(SelectValueInst, BranchWeights[2],
6367+
setBranchWeights(SI, BranchWeights[2],
63696368
BranchWeights[0] + BranchWeights[1],
63706369
/*IsExpected=*/false);
63716370
}
63726371
}
63736372
Value *ValueCompare =
63746373
Builder.CreateICmpEQ(Condition, FirstCase, "switch.selectcmp");
6375-
SelectInst *Ret = cast<SelectInst>(Builder.CreateSelect(
6376-
ValueCompare, ResultVector[0].first, SelectValue, "switch.select"));
6377-
if (HasBranchWeights) {
6374+
Value *Ret = Builder.CreateSelect(ValueCompare, ResultVector[0].first,
6375+
SelectValue, "switch.select");
6376+
if (auto *SI = dyn_cast<SelectInst>(Ret); SI && HasBranchWeights) {
63786377
// We may have had a DefaultResult. Base the position of the first and
63796378
// second's branch weights accordingly. Also the proability that Condition
63806379
// != FirstCase needs to take that into account.
63816380
assert(BranchWeights.size() >= 2);
63826381
size_t FirstCasePos = (Condition != nullptr);
63836382
size_t SecondCasePos = FirstCasePos + 1;
63846383
uint32_t DefaultCase = (Condition != nullptr) ? BranchWeights[0] : 0;
6385-
setBranchWeights(Ret, BranchWeights[FirstCasePos],
6384+
setBranchWeights(SI, BranchWeights[FirstCasePos],
63866385
DefaultCase + BranchWeights[SecondCasePos],
63876386
/*IsExpected=*/false);
63886387
}
@@ -6422,13 +6421,13 @@ static Value *foldSwitchToSelect(const SwitchCaseResultVectorTy &ResultVector,
64226421
Value *And = Builder.CreateAnd(Condition, AndMask);
64236422
Value *Cmp = Builder.CreateICmpEQ(
64246423
And, Constant::getIntegerValue(And->getType(), AndMask));
6425-
SelectInst *Ret = cast<SelectInst>(
6426-
Builder.CreateSelect(Cmp, ResultVector[0].first, DefaultResult));
6427-
if (HasBranchWeights) {
6424+
Value *Ret =
6425+
Builder.CreateSelect(Cmp, ResultVector[0].first, DefaultResult);
6426+
if (auto *SI = dyn_cast<SelectInst>(Ret); SI && HasBranchWeights) {
64286427
// We know there's a Default case. We base the resulting branch
64296428
// weights off its probability.
64306429
assert(BranchWeights.size() >= 2);
6431-
setBranchWeights(Ret, accumulate(drop_begin(BranchWeights), 0),
6430+
setBranchWeights(SI, accumulate(drop_begin(BranchWeights), 0),
64326431
BranchWeights[0], /*IsExpected=*/false);
64336432
}
64346433
return Ret;
@@ -6448,11 +6447,11 @@ static Value *foldSwitchToSelect(const SwitchCaseResultVectorTy &ResultVector,
64486447
Value *And = Builder.CreateAnd(Condition, ~BitMask, "switch.and");
64496448
Value *Cmp = Builder.CreateICmpEQ(
64506449
And, Constant::getNullValue(And->getType()), "switch.selectcmp");
6451-
SelectInst *Ret = cast<SelectInst>(
6452-
Builder.CreateSelect(Cmp, ResultVector[0].first, DefaultResult));
6453-
if (HasBranchWeights) {
6450+
Value *Ret =
6451+
Builder.CreateSelect(Cmp, ResultVector[0].first, DefaultResult);
6452+
if (auto *SI = dyn_cast<SelectInst>(Ret); SI && HasBranchWeights) {
64546453
assert(BranchWeights.size() >= 2);
6455-
setBranchWeights(Ret, accumulate(drop_begin(BranchWeights), 0),
6454+
setBranchWeights(SI, accumulate(drop_begin(BranchWeights), 0),
64566455
BranchWeights[0], /*IsExpected=*/false);
64576456
}
64586457
return Ret;
@@ -6466,11 +6465,11 @@ static Value *foldSwitchToSelect(const SwitchCaseResultVectorTy &ResultVector,
64666465
Value *Cmp2 = Builder.CreateICmpEQ(Condition, CaseValues[1],
64676466
"switch.selectcmp.case2");
64686467
Value *Cmp = Builder.CreateOr(Cmp1, Cmp2, "switch.selectcmp");
6469-
SelectInst *Ret = cast<SelectInst>(
6470-
Builder.CreateSelect(Cmp, ResultVector[0].first, DefaultResult));
6471-
if (HasBranchWeights) {
6468+
Value *Ret =
6469+
Builder.CreateSelect(Cmp, ResultVector[0].first, DefaultResult);
6470+
if (auto *SI = dyn_cast<SelectInst>(Ret); SI && HasBranchWeights) {
64726471
assert(BranchWeights.size() >= 2);
6473-
setBranchWeights(Ret, accumulate(drop_begin(BranchWeights), 0),
6472+
setBranchWeights(SI, accumulate(drop_begin(BranchWeights), 0),
64746473
BranchWeights[0], /*IsExpected=*/false);
64756474
}
64766475
return Ret;

llvm/test/Transforms/SimplifyCFG/switch-to-select-two-case.ll

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -755,6 +755,25 @@ bb3:
755755
ret i1 %phi
756756
}
757757

758+
define i32 @negative_constfold_select() {
759+
; CHECK-LABEL: @negative_constfold_select(
760+
; CHECK-NEXT: entry:
761+
; CHECK-NEXT: ret i32 poison
762+
;
763+
entry:
764+
switch i32 poison, label %default [
765+
i32 0, label %bb
766+
i32 2, label %bb
767+
]
768+
769+
bb:
770+
br label %default
771+
772+
default:
773+
%ret = phi i32 [ poison, %entry ], [ poison, %bb ]
774+
ret i32 %ret
775+
}
776+
758777
!0 = !{!"function_entry_count", i64 1000}
759778
!1 = !{!"branch_weights", i32 3, i32 5, i32 7}
760779
!2 = !{!"branch_weights", i32 3, i32 5, i32 7, i32 11, i32 13}

0 commit comments

Comments
 (0)