Skip to content

Commit 7ca9e24

Browse files
authored
Switch optimizations in remove-unused-brs (#1753)
* Switch optimizations in remove-unused-brs: thread switch jumps, and turn a switch whose targets are all identical into a br
* ReFinalize between intermediate operations in remove-unused-brs, since stale types left by an earlier step can confuse the later ones
1 parent 801ff52 commit 7ca9e24

13 files changed

+307
-172
lines changed

src/ir/branch-utils.h

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,40 @@ inline bool isBranchReachable(Expression* expr) {
4545
WASM_UNREACHABLE();
4646
}
4747

48+
// Gathers every distinct label that a switch may branch to: its default
// target together with each entry of its target table. Labels that appear
// more than once are collapsed by the set.
inline std::set<Name> getUniqueTargets(Switch* sw) {
  std::set<Name> unique;
  unique.insert(sw->default_);
  for (auto name : sw->targets) {
    unique.insert(name);
  }
  return unique;
}
56+
57+
// If we branch to 'from', change that to 'to' instead.
58+
inline bool replacePossibleTarget(Expression* branch, Name from, Name to) {
59+
bool worked = false;
60+
if (auto* br = branch->dynCast<Break>()) {
61+
if (br->name == from) {
62+
br->name = to;
63+
worked = true;
64+
}
65+
} else if (auto* sw = branch->dynCast<Switch>()) {
66+
for (auto& target : sw->targets) {
67+
if (target == from) {
68+
target = to;
69+
worked = true;
70+
}
71+
}
72+
if (sw->default_ == from) {
73+
sw->default_ = to;
74+
worked = true;
75+
}
76+
} else {
77+
WASM_UNREACHABLE();
78+
}
79+
return worked;
80+
}
81+
4882
// returns the set of targets to which we branch that are
4983
// outside of a node
5084
inline std::set<Name> getExitingBranches(Expression* ast) {

src/passes/Flatten.cpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,9 @@
5353
#include <wasm.h>
5454
#include <pass.h>
5555
#include <wasm-builder.h>
56-
#include <ir/utils.h>
56+
#include <ir/branch-utils.h>
5757
#include <ir/effects.h>
58+
#include <ir/utils.h>
5859

5960
namespace wasm {
6061

@@ -232,11 +233,7 @@ struct Flatten : public WalkerPass<ExpressionStackWalker<Flatten, UnifiedExpress
232233
Index temp = builder.addVar(getFunction(), type);
233234
ourPreludes.push_back(builder.makeSetLocal(temp, sw->value));
234235
// we don't know which break target will be hit - assign to them all
235-
std::set<Name> names;
236-
for (auto target : sw->targets) {
237-
names.insert(target);
238-
}
239-
names.insert(sw->default_);
236+
auto names = BranchUtils::getUniqueTargets(sw);
240237
for (auto name : names) {
241238
ourPreludes.push_back(builder.makeSetLocal(
242239
getTempForBreakTarget(name, type),

src/passes/RemoveUnusedBrs.cpp

Lines changed: 55 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -435,7 +435,7 @@ struct RemoveUnusedBrs : public WalkerPass<PostWalker<RemoveUnusedBrs>> {
435435
}
436436
}
437437

438-
void sinkBlocks(Function* func) {
438+
bool sinkBlocks(Function* func) {
439439
struct Sinker : public PostWalker<Sinker> {
440440
bool worked = false;
441441

@@ -501,13 +501,14 @@ struct RemoveUnusedBrs : public WalkerPass<PostWalker<RemoveUnusedBrs>> {
501501

502502
sinker.doWalkFunction(func);
503503
if (sinker.worked) {
504-
anotherCycle = true;
504+
ReFinalize().walkFunctionInModule(func, getModule());
505+
return true;
505506
}
507+
return false;
506508
}
507509

508510
void doWalkFunction(Function* func) {
509511
// multiple cycles may be needed
510-
bool worked = false;
511512
do {
512513
anotherCycle = false;
513514
super::doWalkFunction(func);
@@ -532,32 +533,39 @@ struct RemoveUnusedBrs : public WalkerPass<PostWalker<RemoveUnusedBrs>> {
532533
anotherCycle |= optimizeLoop(loop);
533534
}
534535
loops.clear();
536+
if (anotherCycle) {
537+
ReFinalize().walkFunctionInModule(func, getModule());
538+
}
535539
// sink blocks
536-
sinkBlocks(func);
537-
if (anotherCycle) worked = true;
540+
if (sinkBlocks(func)) {
541+
anotherCycle = true;
542+
}
538543
} while (anotherCycle);
539544

540-
if (worked) {
541-
// Our work may alter block and if types, they may now return values that we made flow through them
542-
ReFinalize().walkFunctionInModule(func, getModule());
543-
}
544-
545545
// thread trivial jumps
546546
struct JumpThreader : public ControlFlowWalker<JumpThreader> {
547-
// map of all value-less breaks going to a block (and not a loop)
548-
std::map<Block*, std::vector<Break*>> breaksToBlock;
547+
// map of all value-less breaks and switches going to a block (and not a loop)
548+
std::map<Block*, std::vector<Expression*>> branchesToBlock;
549549

550-
// the names to update
551-
std::map<Break*, Name> newNames;
550+
bool worked = false;
552551

553552
void visitBreak(Break* curr) {
554553
if (!curr->value) {
555554
if (auto* target = findBreakTarget(curr->name)->dynCast<Block>()) {
556-
breaksToBlock[target].push_back(curr);
555+
branchesToBlock[target].push_back(curr);
556+
}
557+
}
558+
}
559+
// Records a value-less switch under every distinct Block it can branch to,
// so that those branches can be redirected later. Switches that carry a
// value are ignored, as are targets that are not Blocks.
void visitSwitch(Switch* curr) {
  if (curr->value) {
    return;
  }
  for (auto name : BranchUtils::getUniqueTargets(curr)) {
    auto* block = findBreakTarget(name)->dynCast<Block>();
    if (block) {
      branchesToBlock[block].push_back(curr);
    }
  }
}
560-
// TODO: Switch?
561569
void visitBlock(Block* curr) {
562570
auto& list = curr->list;
563571
if (list.size() == 1 && curr->name.is()) {
@@ -566,41 +574,36 @@ struct RemoveUnusedBrs : public WalkerPass<PostWalker<RemoveUnusedBrs>> {
566574
// the two blocks must have the same type for us to update the branch, as otherwise
567575
// one block may be unreachable and the other concrete, so one might lack a value
568576
if (child->name.is() && child->name != curr->name && child->type == curr->type) {
569-
auto& breaks = breaksToBlock[child];
570-
for (auto* br : breaks) {
571-
newNames[br] = curr->name;
572-
breaksToBlock[curr].push_back(br); // update the list - we may push it even more later
573-
}
574-
breaksToBlock.erase(child);
577+
redirectBranches(child, curr->name);
575578
}
576579
}
577580
} else if (list.size() == 2) {
578581
// if this block has two children, a child-block and a simple jump, then jumps to child-block can be replaced with jumps to the new target
579582
auto* child = list[0]->dynCast<Block>();
580583
auto* jump = list[1]->dynCast<Break>();
581584
if (child && child->name.is() && jump && ExpressionAnalyzer::isSimple(jump)) {
582-
auto& breaks = breaksToBlock[child];
583-
for (auto* br : breaks) {
584-
newNames[br] = jump->name;
585-
}
586-
// if the jump is to another block then we can update the list, and maybe push it even more later
587-
if (auto* newTarget = findBreakTarget(jump->name)->dynCast<Block>()) {
588-
for (auto* br : breaks) {
589-
breaksToBlock[newTarget].push_back(br);
590-
}
591-
}
592-
breaksToBlock.erase(child);
585+
redirectBranches(child, jump->name);
593586
}
594587
}
595588
}
596589

597-
void finish(Function* func) {
598-
for (auto& iter : newNames) {
599-
auto* br = iter.first;
600-
auto name = iter.second;
601-
br->name = name;
590+
void redirectBranches(Block* from, Name to) {
591+
auto& branches = branchesToBlock[from];
592+
for (auto* branch : branches) {
593+
if (BranchUtils::replacePossibleTarget(branch, from->name, to)) {
594+
worked = true;
595+
}
596+
}
597+
// if the jump is to another block then we can update the list, and maybe push it even more later
598+
if (auto* newTarget = findBreakTarget(to)->dynCast<Block>()) {
599+
for (auto* branch : branches) {
600+
branchesToBlock[newTarget].push_back(branch);
601+
}
602602
}
603-
if (newNames.size() > 0) {
603+
}
604+
605+
void finish(Function* func) {
606+
if (worked) {
604607
// by changing where brs go, we may change block types etc.
605608
ReFinalize().walkFunctionInModule(func, getModule());
606609
}
@@ -686,6 +689,19 @@ struct RemoveUnusedBrs : public WalkerPass<PostWalker<RemoveUnusedBrs>> {
686689
}
687690
}
688691

692+
// A switch whose table entries and default all name the same label branches
// to that label no matter what the condition is, so it can be replaced with
// a plain br. The condition is preserved inside a drop, as it might have
// side effects.
void visitSwitch(Switch* curr) {
  if (BranchUtils::getUniqueTargets(curr).size() != 1) {
    return;
  }
  // The replacement runs the condition (in the drop) *before* the br, while a
  // switch evaluates its value operand before its condition. With a value
  // present, the rewrite would therefore swap the evaluation order of the two
  // operands, which is observable if both have side effects. Only rewrite the
  // value-less form, where there is nothing to reorder.
  // NOTE(review): the value-carrying case could also be handled when an
  // effects analysis shows the two operands can be reordered safely.
  if (curr->value) {
    return;
  }
  Builder builder(*getModule());
  replaceCurrent(
    builder.makeSequence(
      builder.makeDrop(curr->condition), // might have side effects
      builder.makeBreak(curr->default_)
    )
  );
}
704+
689705
// Restructuring of ifs: if we have
690706
// (block $x
691707
// (br_if $x (cond))

src/passes/SimplifyLocals.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
#include <wasm-builder.h>
5151
#include <wasm-traversal.h>
5252
#include <pass.h>
53+
#include <ir/branch-utils.h>
5354
#include <ir/count.h>
5455
#include <ir/effects.h>
5556
#include "ir/equivalent_sets.h"
@@ -128,10 +129,10 @@ struct SimplifyLocals : public WalkerPass<LinearExecutionWalker<SimplifyLocals<a
128129
assert(!curr->cast<If>()->ifFalse); // if-elses are handled by doNoteIfElse* methods
129130
} else if (curr->is<Switch>()) {
130131
auto* sw = curr->cast<Switch>();
131-
for (auto target : sw->targets) {
132+
auto targets = BranchUtils::getUniqueTargets(sw);
133+
for (auto target : targets) {
132134
self->unoptimizableBlocks.insert(target);
133135
}
134-
self->unoptimizableBlocks.insert(sw->default_);
135136
// TODO: we could use this info to stop gathering data on these blocks
136137
}
137138
self->sinkables.clear();

test/emcc_hello_world.fromasm

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3309,19 +3309,19 @@
33093309
(block $__rjti$4
33103310
(block $__rjti$3
33113311
(block $switch-default120
3312-
(block $switch-case42
3312+
(block $switch-case119
33133313
(block $switch-case41
33143314
(block $switch-case40
33153315
(block $switch-case39
33163316
(block $switch-case38
33173317
(block $switch-case37
33183318
(block $switch-case36
3319-
(block $switch-case34
3319+
(block $switch-case35
33203320
(block $switch-case33
3321-
(block $switch-case29
3321+
(block $switch-case30
33223322
(block $switch-case28
33233323
(block $switch-case27
3324-
(br_table $switch-case42 $switch-default120 $switch-case40 $switch-default120 $switch-case42 $switch-case42 $switch-case42 $switch-default120 $switch-default120 $switch-default120 $switch-default120 $switch-default120 $switch-default120 $switch-default120 $switch-default120 $switch-default120 $switch-default120 $switch-default120 $switch-case41 $switch-default120 $switch-default120 $switch-default120 $switch-default120 $switch-case29 $switch-default120 $switch-default120 $switch-default120 $switch-default120 $switch-default120 $switch-default120 $switch-default120 $switch-default120 $switch-case42 $switch-default120 $switch-case37 $switch-case34 $switch-case42 $switch-case42 $switch-case42 $switch-default120 $switch-case34 $switch-default120 $switch-default120 $switch-default120 $switch-case38 $switch-case27 $switch-case33 $switch-case28 $switch-default120 $switch-default120 $switch-case39 $switch-default120 $switch-case36 $switch-default120 $switch-default120 $switch-case29 $switch-default120
3324+
(br_table $switch-case119 $switch-default120 $switch-case40 $switch-default120 $switch-case119 $switch-case119 $switch-case119 $switch-default120 $switch-default120 $switch-default120 $switch-default120 $switch-default120 $switch-default120 $switch-default120 $switch-default120 $switch-default120 $switch-default120 $switch-default120 $switch-case41 $switch-default120 $switch-default120 $switch-default120 $switch-default120 $switch-case30 $switch-default120 $switch-default120 $switch-default120 $switch-default120 $switch-default120 $switch-default120 $switch-default120 $switch-default120 $switch-case119 $switch-default120 $switch-case37 $switch-case35 $switch-case119 $switch-case119 $switch-case119 $switch-default120 $switch-case35 $switch-default120 $switch-default120 $switch-default120 $switch-case38 $switch-case27 $switch-case33 $switch-case28 $switch-default120 $switch-default120 $switch-case39 $switch-default120 $switch-case36 $switch-default120 $switch-default120 $switch-case30 $switch-default120
33253325
(i32.sub
33263326
(tee_local $19
33273327
(select
@@ -6915,7 +6915,7 @@
69156915
(get_local $1)
69166916
(i32.const 20)
69176917
)
6918-
(block $switch-default
6918+
(block $label$break$L1
69196919
(block $switch-case9
69206920
(block $switch-case8
69216921
(block $switch-case7
@@ -6926,7 +6926,7 @@
69266926
(block $switch-case2
69276927
(block $switch-case1
69286928
(block $switch-case
6929-
(br_table $switch-case $switch-case1 $switch-case2 $switch-case3 $switch-case4 $switch-case5 $switch-case6 $switch-case7 $switch-case8 $switch-case9 $switch-default
6929+
(br_table $switch-case $switch-case1 $switch-case2 $switch-case3 $switch-case4 $switch-case5 $switch-case6 $switch-case7 $switch-case8 $switch-case9 $label$break$L1
69306930
(i32.sub
69316931
(get_local $1)
69326932
(i32.const 9)
@@ -6959,7 +6959,7 @@
69596959
(get_local $0)
69606960
(get_local $3)
69616961
)
6962-
(br $switch-default)
6962+
(br $label$break$L1)
69636963
)
69646964
(set_local $1
69656965
(i32.load
@@ -7000,7 +7000,7 @@
70007000
(i32.const 31)
70017001
)
70027002
)
7003-
(br $switch-default)
7003+
(br $label$break$L1)
70047004
)
70057005
(set_local $3
70067006
(i32.load
@@ -7032,7 +7032,7 @@
70327032
(get_local $0)
70337033
(i32.const 0)
70347034
)
7035-
(br $switch-default)
7035+
(br $label$break$L1)
70367036
)
70377037
(set_local $5
70387038
(i32.load
@@ -7071,7 +7071,7 @@
70717071
(get_local $0)
70727072
(get_local $3)
70737073
)
7074-
(br $switch-default)
7074+
(br $label$break$L1)
70757075
)
70767076
(set_local $3
70777077
(i32.load
@@ -7123,7 +7123,7 @@
71237123
(i32.const 31)
71247124
)
71257125
)
7126-
(br $switch-default)
7126+
(br $label$break$L1)
71277127
)
71287128
(set_local $3
71297129
(i32.load
@@ -7158,7 +7158,7 @@
71587158
(get_local $0)
71597159
(i32.const 0)
71607160
)
7161-
(br $switch-default)
7161+
(br $label$break$L1)
71627162
)
71637163
(set_local $3
71647164
(i32.load
@@ -7210,7 +7210,7 @@
72107210
(i32.const 31)
72117211
)
72127212
)
7213-
(br $switch-default)
7213+
(br $label$break$L1)
72147214
)
72157215
(set_local $3
72167216
(i32.load
@@ -7245,7 +7245,7 @@
72457245
(get_local $0)
72467246
(i32.const 0)
72477247
)
7248-
(br $switch-default)
7248+
(br $label$break$L1)
72497249
)
72507250
(set_local $4
72517251
(f64.load
@@ -7273,7 +7273,7 @@
72737273
(get_local $0)
72747274
(get_local $4)
72757275
)
7276-
(br $switch-default)
7276+
(br $label$break$L1)
72777277
)
72787278
(set_local $4
72797279
(f64.load

0 commit comments

Comments
 (0)