Skip to content

Commit 5b2f9b5

Browse files
authored
[SimplifyCFG]: Switch on umin replaces default (#164097)
A switch on `umin` can eliminate the default case by making the `umin`'s constant the default case. Proof: https://alive2.llvm.org/ce/z/_N6nfs Fixes: #162111
1 parent f7be258 commit 5b2f9b5

File tree

4 files changed

+339
-0
lines changed

4 files changed

+339
-0
lines changed

llvm/include/llvm/IR/Instructions.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3556,6 +3556,11 @@ class SwitchInstProfUpdateWrapper {
35563556
/// correspondent branch weight.
35573557
LLVM_ABI SwitchInst::CaseIt removeCase(SwitchInst::CaseIt I);
35583558

3559+
/// Replace the default destination by given case. Delegate the call to
3560+
/// the underlying SwitchInst::setDefaultDest and remove correspondent branch
3561+
/// weight.
3562+
LLVM_ABI void replaceDefaultDest(SwitchInst::CaseIt I);
3563+
35593564
/// Delegate the call to the underlying SwitchInst::addCase() and set the
35603565
/// specified branch weight for the added case.
35613566
LLVM_ABI void addCase(ConstantInt *OnVal, BasicBlock *Dest, CaseWeightOpt W);

llvm/lib/IR/Instructions.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4171,6 +4171,16 @@ SwitchInstProfUpdateWrapper::removeCase(SwitchInst::CaseIt I) {
41714171
return SI.removeCase(I);
41724172
}
41734173

4174+
void SwitchInstProfUpdateWrapper::replaceDefaultDest(SwitchInst::CaseIt I) {
4175+
auto *DestBlock = I->getCaseSuccessor();
4176+
if (Weights) {
4177+
auto Weight = getSuccessorWeight(I->getCaseIndex() + 1);
4178+
(*Weights)[0] = Weight.value();
4179+
}
4180+
4181+
SI.setDefaultDest(DestBlock);
4182+
}
4183+
41744184
void SwitchInstProfUpdateWrapper::addCase(
41754185
ConstantInt *OnVal, BasicBlock *Dest,
41764186
SwitchInstProfUpdateWrapper::CaseWeightOpt W) {

llvm/lib/Transforms/Utils/SimplifyCFG.cpp

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7570,6 +7570,81 @@ static bool reduceSwitchRange(SwitchInst *SI, IRBuilder<> &Builder,
75707570
return true;
75717571
}
75727572

7573+
/// Tries to transform the switch when the condition is umin with a constant.
7574+
/// In that case, the default branch can be replaced by the constant's branch.
7575+
/// This method also removes dead cases when the simplification cannot replace
7576+
/// the default branch.
7577+
///
7578+
/// For example:
7579+
/// switch(umin(a, 3)) {
7580+
/// case 0:
7581+
/// case 1:
7582+
/// case 2:
7583+
/// case 3:
7584+
/// case 4:
7585+
/// // ...
7586+
/// default:
7587+
/// unreachable
7588+
/// }
7589+
///
7590+
/// Transforms into:
7591+
///
7592+
/// switch(a) {
7593+
/// case 0:
7594+
/// case 1:
7595+
/// case 2:
7596+
/// default:
7597+
/// // This is case 3
7598+
/// }
7599+
static bool simplifySwitchWhenUMin(SwitchInst *SI, DomTreeUpdater *DTU) {
7600+
Value *A;
7601+
ConstantInt *Constant;
7602+
7603+
if (!match(SI->getCondition(), m_UMin(m_Value(A), m_ConstantInt(Constant))))
7604+
return false;
7605+
7606+
SmallVector<DominatorTree::UpdateType> Updates;
7607+
SwitchInstProfUpdateWrapper SIW(*SI);
7608+
BasicBlock *BB = SIW->getParent();
7609+
7610+
// Dead cases are removed even when the simplification fails.
7611+
// A case is dead when its value is higher than the Constant.
7612+
for (auto I = SI->case_begin(), E = SI->case_end(); I != E;) {
7613+
if (!I->getCaseValue()->getValue().ugt(Constant->getValue())) {
7614+
++I;
7615+
continue;
7616+
}
7617+
BasicBlock *DeadCaseBB = I->getCaseSuccessor();
7618+
DeadCaseBB->removePredecessor(BB);
7619+
Updates.push_back({DominatorTree::Delete, BB, DeadCaseBB});
7620+
I = SIW->removeCase(I);
7621+
E = SIW->case_end();
7622+
}
7623+
7624+
auto Case = SI->findCaseValue(Constant);
7625+
// If the case value is not found, `findCaseValue` returns the default case.
7626+
// In this scenario, since there is no explicit `case 3:`, the simplification
7627+
// fails. The simplification also fails when the switch’s default destination
7628+
// is reachable.
7629+
if (!SI->defaultDestUnreachable() || Case == SI->case_default()) {
7630+
if (DTU)
7631+
DTU->applyUpdates(Updates);
7632+
return !Updates.empty();
7633+
}
7634+
7635+
BasicBlock *Unreachable = SI->getDefaultDest();
7636+
SIW.replaceDefaultDest(Case);
7637+
SIW.removeCase(Case);
7638+
SIW->setCondition(A);
7639+
7640+
Updates.push_back({DominatorTree::Delete, BB, Unreachable});
7641+
7642+
if (DTU)
7643+
DTU->applyUpdates(Updates);
7644+
7645+
return true;
7646+
}
7647+
75737648
/// Tries to transform switch of powers of two to reduce switch range.
75747649
/// For example, switch like:
75757650
/// switch (C) { case 1: case 2: case 64: case 128: }
@@ -8037,6 +8112,9 @@ bool SimplifyCFGOpt::simplifySwitch(SwitchInst *SI, IRBuilder<> &Builder) {
80378112
if (simplifyDuplicateSwitchArms(SI, DTU))
80388113
return requestResimplify();
80398114

8115+
if (simplifySwitchWhenUMin(SI, DTU))
8116+
return requestResimplify();
8117+
80408118
return false;
80418119
}
80428120

Lines changed: 246 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,246 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 6
2+
; RUN: opt -S -passes=simplifycfg < %s | FileCheck %s
3+
4+
declare void @a()
5+
declare void @b()
6+
declare void @c()
7+
declare void @d()
8+
9+
define void @switch_replace_default(i32 %x) {
10+
; CHECK-LABEL: define void @switch_replace_default(
11+
; CHECK-SAME: i32 [[X:%.*]]) {
12+
; CHECK-NEXT: [[MIN:%.*]] = call i32 @llvm.umin.i32(i32 [[X]], i32 3)
13+
; CHECK-NEXT: switch i32 [[X]], label %[[COMMON_RET:.*]] [
14+
; CHECK-NEXT: i32 0, label %[[CASE0:.*]]
15+
; CHECK-NEXT: i32 1, label %[[CASE1:.*]]
16+
; CHECK-NEXT: i32 2, label %[[CASE2:.*]]
17+
; CHECK-NEXT: ], !prof [[PROF0:![0-9]+]]
18+
; CHECK: [[COMMON_RET]]:
19+
; CHECK-NEXT: ret void
20+
; CHECK: [[CASE0]]:
21+
; CHECK-NEXT: call void @a()
22+
; CHECK-NEXT: br label %[[COMMON_RET]]
23+
; CHECK: [[CASE1]]:
24+
; CHECK-NEXT: call void @b()
25+
; CHECK-NEXT: br label %[[COMMON_RET]]
26+
; CHECK: [[CASE2]]:
27+
; CHECK-NEXT: call void @c()
28+
; CHECK-NEXT: br label %[[COMMON_RET]]
29+
;
30+
%min = call i32 @llvm.umin.i32(i32 %x, i32 3)
31+
switch i32 %min, label %unreachable [
32+
i32 0, label %case0
33+
i32 1, label %case1
34+
i32 2, label %case2
35+
i32 3, label %case3
36+
], !prof !0
37+
38+
case0:
39+
call void @a()
40+
ret void
41+
42+
case1:
43+
call void @b()
44+
ret void
45+
46+
case2:
47+
call void @c()
48+
ret void
49+
50+
case3:
51+
ret void
52+
53+
unreachable:
54+
unreachable
55+
}
56+
57+
define void @switch_replace_default_and_remove_dead_cases(i32 %x) {
58+
; CHECK-LABEL: define void @switch_replace_default_and_remove_dead_cases(
59+
; CHECK-SAME: i32 [[X:%.*]]) {
60+
; CHECK-NEXT: [[MIN:%.*]] = call i32 @llvm.umin.i32(i32 [[X]], i32 3)
61+
; CHECK-NEXT: switch i32 [[X]], label %[[COMMON_RET:.*]] [
62+
; CHECK-NEXT: i32 2, label %[[CASE2:.*]]
63+
; CHECK-NEXT: i32 1, label %[[CASE1:.*]]
64+
; CHECK-NEXT: ]
65+
; CHECK: [[COMMON_RET]]:
66+
; CHECK-NEXT: ret void
67+
; CHECK: [[CASE1]]:
68+
; CHECK-NEXT: call void @b()
69+
; CHECK-NEXT: br label %[[COMMON_RET]]
70+
; CHECK: [[CASE2]]:
71+
; CHECK-NEXT: call void @c()
72+
; CHECK-NEXT: br label %[[COMMON_RET]]
73+
;
74+
%min = call i32 @llvm.umin.i32(i32 %x, i32 3)
75+
switch i32 %min, label %unreachable [
76+
i32 4, label %case4
77+
i32 1, label %case1
78+
i32 2, label %case2
79+
i32 3, label %case3
80+
]
81+
82+
case4:
83+
call void @a()
84+
ret void
85+
86+
case1:
87+
call void @b()
88+
ret void
89+
90+
case2:
91+
call void @c()
92+
ret void
93+
94+
case3:
95+
ret void
96+
97+
unreachable:
98+
unreachable
99+
}
100+
101+
define void @switch_replace_default_when_holes(i32 %x) {
102+
; CHECK-LABEL: define void @switch_replace_default_when_holes(
103+
; CHECK-SAME: i32 [[X:%.*]]) {
104+
; CHECK-NEXT: [[MIN:%.*]] = call i32 @llvm.umin.i32(i32 [[X]], i32 3)
105+
; CHECK-NEXT: switch i32 [[X]], label %[[COMMON_RET:.*]] [
106+
; CHECK-NEXT: i32 1, label %[[CASE1:.*]]
107+
; CHECK-NEXT: i32 2, label %[[CASE2:.*]]
108+
; CHECK-NEXT: ]
109+
; CHECK: [[COMMON_RET]]:
110+
; CHECK-NEXT: ret void
111+
; CHECK: [[CASE1]]:
112+
; CHECK-NEXT: call void @b()
113+
; CHECK-NEXT: br label %[[COMMON_RET]]
114+
; CHECK: [[CASE2]]:
115+
; CHECK-NEXT: call void @c()
116+
; CHECK-NEXT: br label %[[COMMON_RET]]
117+
;
118+
%min = call i32 @llvm.umin.i32(i32 %x, i32 3)
119+
switch i32 %min, label %unreachable [
120+
i32 1, label %case1
121+
i32 2, label %case2
122+
i32 3, label %case3
123+
]
124+
125+
case1:
126+
call void @b()
127+
ret void
128+
129+
case2:
130+
call void @c()
131+
ret void
132+
133+
case3:
134+
ret void
135+
136+
unreachable:
137+
unreachable
138+
}
139+
140+
define void @do_not_switch_replace_default(i32 %x, i32 %y) {
141+
; CHECK-LABEL: define void @do_not_switch_replace_default(
142+
; CHECK-SAME: i32 [[X:%.*]], i32 [[Y:%.*]]) {
143+
; CHECK-NEXT: [[MIN:%.*]] = call i32 @llvm.umin.i32(i32 [[X]], i32 [[Y]])
144+
; CHECK-NEXT: switch i32 [[MIN]], label %[[UNREACHABLE:.*]] [
145+
; CHECK-NEXT: i32 0, label %[[CASE0:.*]]
146+
; CHECK-NEXT: i32 1, label %[[CASE1:.*]]
147+
; CHECK-NEXT: i32 2, label %[[CASE2:.*]]
148+
; CHECK-NEXT: i32 3, label %[[COMMON_RET:.*]]
149+
; CHECK-NEXT: ]
150+
; CHECK: [[COMMON_RET]]:
151+
; CHECK-NEXT: ret void
152+
; CHECK: [[CASE0]]:
153+
; CHECK-NEXT: call void @a()
154+
; CHECK-NEXT: br label %[[COMMON_RET]]
155+
; CHECK: [[CASE1]]:
156+
; CHECK-NEXT: call void @b()
157+
; CHECK-NEXT: br label %[[COMMON_RET]]
158+
; CHECK: [[CASE2]]:
159+
; CHECK-NEXT: call void @c()
160+
; CHECK-NEXT: br label %[[COMMON_RET]]
161+
; CHECK: [[UNREACHABLE]]:
162+
; CHECK-NEXT: unreachable
163+
;
164+
%min = call i32 @llvm.umin.i32(i32 %x, i32 %y)
165+
switch i32 %min, label %unreachable [
166+
i32 0, label %case0
167+
i32 1, label %case1
168+
i32 2, label %case2
169+
i32 3, label %case3
170+
]
171+
172+
case0:
173+
call void @a()
174+
ret void
175+
176+
case1:
177+
call void @b()
178+
ret void
179+
180+
case2:
181+
call void @c()
182+
ret void
183+
184+
case3:
185+
ret void
186+
187+
unreachable:
188+
unreachable
189+
}
190+
191+
define void @do_not_replace_switch_default_but_remove_dead_cases(i32 %x) {
192+
; CHECK-LABEL: define void @do_not_replace_switch_default_but_remove_dead_cases(
193+
; CHECK-SAME: i32 [[X:%.*]]) {
194+
; CHECK-NEXT: [[MIN:%.*]] = call i32 @llvm.umin.i32(i32 [[X]], i32 3)
195+
; CHECK-NEXT: switch i32 [[MIN]], label %[[CASE0:.*]] [
196+
; CHECK-NEXT: i32 3, label %[[COMMON_RET:.*]]
197+
; CHECK-NEXT: i32 1, label %[[CASE1:.*]]
198+
; CHECK-NEXT: i32 2, label %[[CASE2:.*]]
199+
; CHECK-NEXT: ]
200+
; CHECK: [[COMMON_RET]]:
201+
; CHECK-NEXT: ret void
202+
; CHECK: [[CASE0]]:
203+
; CHECK-NEXT: call void @a()
204+
; CHECK-NEXT: br label %[[COMMON_RET]]
205+
; CHECK: [[CASE1]]:
206+
; CHECK-NEXT: call void @b()
207+
; CHECK-NEXT: br label %[[COMMON_RET]]
208+
; CHECK: [[CASE2]]:
209+
; CHECK-NEXT: call void @c()
210+
; CHECK-NEXT: br label %[[COMMON_RET]]
211+
;
212+
%min = call i32 @llvm.umin.i32(i32 %x, i32 3)
213+
switch i32 %min, label %case0 [ ; default is reachable, therefore simplification not triggered
214+
i32 0, label %case0
215+
i32 1, label %case1
216+
i32 2, label %case2
217+
i32 3, label %case3
218+
i32 4, label %case4
219+
]
220+
221+
case0:
222+
call void @a()
223+
ret void
224+
225+
case1:
226+
call void @b()
227+
ret void
228+
229+
case2:
230+
call void @c()
231+
ret void
232+
233+
case3:
234+
ret void
235+
236+
case4:
237+
call void @d()
238+
ret void
239+
240+
}
241+
242+
243+
!0 = !{!"branch_weights", i32 1, i32 2, i32 3, i32 99, i32 5}
244+
;.
245+
; CHECK: [[PROF0]] = !{!"branch_weights", i32 5, i32 2, i32 3, i32 99}
246+
;.

0 commit comments

Comments
 (0)