Skip to content

Commit 352bc16

Browse files
dianqkaokblast
authored andcommitted
[SimplifyCFG] Fold the contiguous wrapping cases into ICmp. (llvm#161000)
Fixes llvm#157113. Take the following IR as an example; we know the destination of the `[1, 3]` cases is `%else`. ```llvm define i32 @src(i8 range(i8 0, 6) %arg) { switch i8 %arg, label %else [ i8 0, label %if i8 4, label %if i8 5, label %if ] if: ret i32 0 else: ret i32 1 } ``` We can first try the non-wrapping range for both destinations, but I don't see how that would be any better. Proof: https://alive2.llvm.org/ce/z/acdWD4.
1 parent fe31442 commit 352bc16

File tree

4 files changed

+331
-47
lines changed

4 files changed

+331
-47
lines changed

llvm/lib/Transforms/Utils/SimplifyCFG.cpp

Lines changed: 112 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -5734,15 +5734,66 @@ bool SimplifyCFGOpt::simplifyUnreachable(UnreachableInst *UI) {
57345734
return Changed;
57355735
}
57365736

5737-
static bool casesAreContiguous(SmallVectorImpl<ConstantInt *> &Cases) {
5737+
struct ContiguousCasesResult {
5738+
ConstantInt *Min;
5739+
ConstantInt *Max;
5740+
BasicBlock *Dest;
5741+
BasicBlock *OtherDest;
5742+
SmallVectorImpl<ConstantInt *> *Cases;
5743+
SmallVectorImpl<ConstantInt *> *OtherCases;
5744+
};
5745+
5746+
static std::optional<ContiguousCasesResult>
5747+
findContiguousCases(Value *Condition, SmallVectorImpl<ConstantInt *> &Cases,
5748+
SmallVectorImpl<ConstantInt *> &OtherCases,
5749+
BasicBlock *Dest, BasicBlock *OtherDest) {
57385750
assert(Cases.size() >= 1);
57395751

57405752
array_pod_sort(Cases.begin(), Cases.end(), constantIntSortPredicate);
5741-
for (size_t I = 1, E = Cases.size(); I != E; ++I) {
5742-
if (Cases[I - 1]->getValue() != Cases[I]->getValue() + 1)
5743-
return false;
5753+
const APInt &Min = Cases.back()->getValue();
5754+
const APInt &Max = Cases.front()->getValue();
5755+
APInt Offset = Max - Min;
5756+
size_t ContiguousOffset = Cases.size() - 1;
5757+
if (Offset == ContiguousOffset) {
5758+
return ContiguousCasesResult{
5759+
/*Min=*/Cases.back(),
5760+
/*Max=*/Cases.front(),
5761+
/*Dest=*/Dest,
5762+
/*OtherDest=*/OtherDest,
5763+
/*Cases=*/&Cases,
5764+
/*OtherCases=*/&OtherCases,
5765+
};
57445766
}
5745-
return true;
5767+
ConstantRange CR = computeConstantRange(Condition, /*ForSigned=*/false);
5768+
// If this is a wrapping contiguous range, that is, [Min, OtherMin] +
5769+
// [OtherMax, Max] (also [OtherMax, OtherMin]), [OtherMin+1, OtherMax-1] is a
5770+
// contiguous range for the other destination. N.B. If CR is not a full range,
5771+
// Max+1 is not equal to Min. It's not continuous in arithmetic.
5772+
if (Max == CR.getUnsignedMax() && Min == CR.getUnsignedMin()) {
5773+
assert(Cases.size() >= 2);
5774+
auto *It =
5775+
std::adjacent_find(Cases.begin(), Cases.end(), [](auto L, auto R) {
5776+
return L->getValue() != R->getValue() + 1;
5777+
});
5778+
if (It == Cases.end())
5779+
return std::nullopt;
5780+
auto [OtherMax, OtherMin] = std::make_pair(*It, *std::next(It));
5781+
if ((Max - OtherMax->getValue()) + (OtherMin->getValue() - Min) ==
5782+
Cases.size() - 2) {
5783+
return ContiguousCasesResult{
5784+
/*Min=*/cast<ConstantInt>(
5785+
ConstantInt::get(OtherMin->getType(), OtherMin->getValue() + 1)),
5786+
/*Max=*/
5787+
cast<ConstantInt>(
5788+
ConstantInt::get(OtherMax->getType(), OtherMax->getValue() - 1)),
5789+
/*Dest=*/OtherDest,
5790+
/*OtherDest=*/Dest,
5791+
/*Cases=*/&OtherCases,
5792+
/*OtherCases=*/&Cases,
5793+
};
5794+
}
5795+
}
5796+
return std::nullopt;
57465797
}
57475798

57485799
static void createUnreachableSwitchDefault(SwitchInst *Switch,
@@ -5779,7 +5830,6 @@ bool SimplifyCFGOpt::turnSwitchRangeIntoICmp(SwitchInst *SI,
57795830
bool HasDefault = !SI->defaultDestUnreachable();
57805831

57815832
auto *BB = SI->getParent();
5782-
57835833
// Partition the cases into two sets with different destinations.
57845834
BasicBlock *DestA = HasDefault ? SI->getDefaultDest() : nullptr;
57855835
BasicBlock *DestB = nullptr;
@@ -5813,37 +5863,62 @@ bool SimplifyCFGOpt::turnSwitchRangeIntoICmp(SwitchInst *SI,
58135863
assert(!CasesA.empty() || HasDefault);
58145864

58155865
// Figure out if one of the sets of cases form a contiguous range.
5816-
SmallVectorImpl<ConstantInt *> *ContiguousCases = nullptr;
5817-
BasicBlock *ContiguousDest = nullptr;
5818-
BasicBlock *OtherDest = nullptr;
5819-
if (!CasesA.empty() && casesAreContiguous(CasesA)) {
5820-
ContiguousCases = &CasesA;
5821-
ContiguousDest = DestA;
5822-
OtherDest = DestB;
5823-
} else if (casesAreContiguous(CasesB)) {
5824-
ContiguousCases = &CasesB;
5825-
ContiguousDest = DestB;
5826-
OtherDest = DestA;
5827-
} else
5828-
return false;
5866+
std::optional<ContiguousCasesResult> ContiguousCases;
5867+
5868+
// Only one icmp is needed when there is only one case.
5869+
if (!HasDefault && CasesA.size() == 1)
5870+
ContiguousCases = ContiguousCasesResult{
5871+
/*Min=*/CasesA[0],
5872+
/*Max=*/CasesA[0],
5873+
/*Dest=*/DestA,
5874+
/*OtherDest=*/DestB,
5875+
/*Cases=*/&CasesA,
5876+
/*OtherCases=*/&CasesB,
5877+
};
5878+
else if (CasesB.size() == 1)
5879+
ContiguousCases = ContiguousCasesResult{
5880+
/*Min=*/CasesB[0],
5881+
/*Max=*/CasesB[0],
5882+
/*Dest=*/DestB,
5883+
/*OtherDest=*/DestA,
5884+
/*Cases=*/&CasesB,
5885+
/*OtherCases=*/&CasesA,
5886+
};
5887+
// Correctness: Cases to the default destination cannot be contiguous cases.
5888+
else if (!HasDefault)
5889+
ContiguousCases =
5890+
findContiguousCases(SI->getCondition(), CasesA, CasesB, DestA, DestB);
58295891

5830-
// Start building the compare and branch.
5892+
if (!ContiguousCases)
5893+
ContiguousCases =
5894+
findContiguousCases(SI->getCondition(), CasesB, CasesA, DestB, DestA);
58315895

5832-
Constant *Offset = ConstantExpr::getNeg(ContiguousCases->back());
5833-
Constant *NumCases =
5834-
ConstantInt::get(Offset->getType(), ContiguousCases->size());
5896+
if (!ContiguousCases)
5897+
return false;
58355898

5836-
Value *Sub = SI->getCondition();
5837-
if (!Offset->isNullValue())
5838-
Sub = Builder.CreateAdd(Sub, Offset, Sub->getName() + ".off");
5899+
auto [Min, Max, Dest, OtherDest, Cases, OtherCases] = *ContiguousCases;
58395900

5840-
Value *Cmp;
5901+
// Start building the compare and branch.
5902+
5903+
Constant *Offset = ConstantExpr::getNeg(Min);
5904+
Constant *NumCases = ConstantInt::get(Offset->getType(),
5905+
Max->getValue() - Min->getValue() + 1);
5906+
BranchInst *NewBI;
5907+
if (NumCases->isOneValue()) {
5908+
assert(Max->getValue() == Min->getValue());
5909+
Value *Cmp = Builder.CreateICmpEQ(SI->getCondition(), Min);
5910+
NewBI = Builder.CreateCondBr(Cmp, Dest, OtherDest);
5911+
}
58415912
// If NumCases overflowed, then all possible values jump to the successor.
5842-
if (NumCases->isNullValue() && !ContiguousCases->empty())
5843-
Cmp = ConstantInt::getTrue(SI->getContext());
5844-
else
5845-
Cmp = Builder.CreateICmpULT(Sub, NumCases, "switch");
5846-
BranchInst *NewBI = Builder.CreateCondBr(Cmp, ContiguousDest, OtherDest);
5913+
else if (NumCases->isNullValue() && !Cases->empty()) {
5914+
NewBI = Builder.CreateBr(Dest);
5915+
} else {
5916+
Value *Sub = SI->getCondition();
5917+
if (!Offset->isNullValue())
5918+
Sub = Builder.CreateAdd(Sub, Offset, Sub->getName() + ".off");
5919+
Value *Cmp = Builder.CreateICmpULT(Sub, NumCases, "switch");
5920+
NewBI = Builder.CreateCondBr(Cmp, Dest, OtherDest);
5921+
}
58475922

58485923
// Update weight for the newly-created conditional branch.
58495924
if (hasBranchWeightMD(*SI)) {
@@ -5853,7 +5928,7 @@ bool SimplifyCFGOpt::turnSwitchRangeIntoICmp(SwitchInst *SI,
58535928
uint64_t TrueWeight = 0;
58545929
uint64_t FalseWeight = 0;
58555930
for (size_t I = 0, E = Weights.size(); I != E; ++I) {
5856-
if (SI->getSuccessor(I) == ContiguousDest)
5931+
if (SI->getSuccessor(I) == Dest)
58575932
TrueWeight += Weights[I];
58585933
else
58595934
FalseWeight += Weights[I];
@@ -5868,15 +5943,15 @@ bool SimplifyCFGOpt::turnSwitchRangeIntoICmp(SwitchInst *SI,
58685943
}
58695944

58705945
// Prune obsolete incoming values off the successors' PHI nodes.
5871-
for (auto BBI = ContiguousDest->begin(); isa<PHINode>(BBI); ++BBI) {
5872-
unsigned PreviousEdges = ContiguousCases->size();
5873-
if (ContiguousDest == SI->getDefaultDest())
5946+
for (auto BBI = Dest->begin(); isa<PHINode>(BBI); ++BBI) {
5947+
unsigned PreviousEdges = Cases->size();
5948+
if (Dest == SI->getDefaultDest())
58745949
++PreviousEdges;
58755950
for (unsigned I = 0, E = PreviousEdges - 1; I != E; ++I)
58765951
cast<PHINode>(BBI)->removeIncomingValue(SI->getParent());
58775952
}
58785953
for (auto BBI = OtherDest->begin(); isa<PHINode>(BBI); ++BBI) {
5879-
unsigned PreviousEdges = SI->getNumCases() - ContiguousCases->size();
5954+
unsigned PreviousEdges = OtherCases->size();
58805955
if (OtherDest == SI->getDefaultDest())
58815956
++PreviousEdges;
58825957
for (unsigned I = 0, E = PreviousEdges - 1; I != E; ++I)

llvm/test/Transforms/Coroutines/coro-catchswitch-cleanuppad.ll

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,8 @@ cleanup2:
8080
; CHECK: cleanup2.corodispatch:
8181
; CHECK: %1 = phi i8 [ 0, %handler2 ], [ 1, %catch.dispatch.2 ]
8282
; CHECK: %2 = cleanuppad within %h1 []
83-
; CHECK: %switch = icmp ult i8 %1, 1
84-
; CHECK: br i1 %switch, label %cleanup2.from.handler2, label %cleanup2.from.catch.dispatch.2
83+
; CHECK: %3 = icmp eq i8 %1, 0
84+
; CHECK: br i1 %3, label %cleanup2.from.handler2, label %cleanup2.from.catch.dispatch.2
8585

8686
; CHECK: cleanup2.from.handler2:
8787
; CHECK: %valueB.reload = load i32, ptr %valueB.spill.addr, align 4

llvm/test/Transforms/SimplifyCFG/switch-dead-default.ll

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,7 @@ declare void @foo(i32)
77
define void @test(i1 %a) {
88
; CHECK-LABEL: define void @test(
99
; CHECK-SAME: i1 [[A:%.*]]) {
10-
; CHECK-NEXT: [[A_OFF:%.*]] = add i1 [[A]], true
11-
; CHECK-NEXT: [[SWITCH:%.*]] = icmp ult i1 [[A_OFF]], true
10+
; CHECK-NEXT: [[SWITCH:%.*]] = icmp eq i1 [[A]], true
1211
; CHECK-NEXT: br i1 [[SWITCH]], label [[TRUE:%.*]], label [[FALSE:%.*]]
1312
; CHECK: common.ret:
1413
; CHECK-NEXT: ret void
@@ -209,8 +208,7 @@ define void @test5(i8 %a) {
209208
; CHECK-SAME: i8 [[A:%.*]]) {
210209
; CHECK-NEXT: [[CMP:%.*]] = icmp ult i8 [[A]], 2
211210
; CHECK-NEXT: call void @llvm.assume(i1 [[CMP]])
212-
; CHECK-NEXT: [[A_OFF:%.*]] = add i8 [[A]], -1
213-
; CHECK-NEXT: [[SWITCH:%.*]] = icmp ult i8 [[A_OFF]], 1
211+
; CHECK-NEXT: [[SWITCH:%.*]] = icmp eq i8 [[A]], 1
214212
; CHECK-NEXT: br i1 [[SWITCH]], label [[TRUE:%.*]], label [[FALSE:%.*]]
215213
; CHECK: common.ret:
216214
; CHECK-NEXT: ret void
@@ -243,8 +241,7 @@ define void @test6(i8 %a) {
243241
; CHECK-NEXT: [[AND:%.*]] = and i8 [[A]], -2
244242
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i8 [[AND]], -2
245243
; CHECK-NEXT: call void @llvm.assume(i1 [[CMP]])
246-
; CHECK-NEXT: [[A_OFF:%.*]] = add i8 [[A]], 1
247-
; CHECK-NEXT: [[SWITCH:%.*]] = icmp ult i8 [[A_OFF]], 1
244+
; CHECK-NEXT: [[SWITCH:%.*]] = icmp eq i8 [[A]], -1
248245
; CHECK-NEXT: br i1 [[SWITCH]], label [[TRUE:%.*]], label [[FALSE:%.*]]
249246
; CHECK: common.ret:
250247
; CHECK-NEXT: ret void
@@ -279,8 +276,7 @@ define void @test7(i8 %a) {
279276
; CHECK-NEXT: [[AND:%.*]] = and i8 [[A]], -2
280277
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i8 [[AND]], -2
281278
; CHECK-NEXT: call void @llvm.assume(i1 [[CMP]])
282-
; CHECK-NEXT: [[A_OFF:%.*]] = add i8 [[A]], 1
283-
; CHECK-NEXT: [[SWITCH:%.*]] = icmp ult i8 [[A_OFF]], 1
279+
; CHECK-NEXT: [[SWITCH:%.*]] = icmp eq i8 [[A]], -1
284280
; CHECK-NEXT: br i1 [[SWITCH]], label [[TRUE:%.*]], label [[FALSE:%.*]]
285281
; CHECK: common.ret:
286282
; CHECK-NEXT: ret void

0 commit comments

Comments
 (0)