Skip to content

Commit 0d5a39f

Browse files
authored
br_if-to-table (#1313)
Implements #1309: consecutive br_ifs that compare the same value against different constants are converted into a single br_table inside a new block:

  (br_if $x (i32.eq (get_local $a) (i32.const 0)))
  (br_if $y (i32.eq (get_local $a) (i32.const 1)))
  (br_if $z (i32.eq (get_local $a) (i32.const 2)))
==>
  (block $tablify
    (br_table $x $y $z $tablify
      (get_local $a)
    )
  )

The constants that decide when to apply this (e.g., not when the range of values would produce a huge jump table) are, I think, fairly conservative, though it is hard to tell. They should probably be tweaked later based on our experience with the pass in practice.
1 parent e91d1bf commit 0d5a39f

File tree

4 files changed

+968
-6
lines changed

4 files changed

+968
-6
lines changed

src/passes/RemoveUnusedBrs.cpp

Lines changed: 170 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
#include <wasm.h>
2222
#include <pass.h>
23+
#include <parsing.h>
2324
#include <ir/utils.h>
2425
#include <ir/branch-utils.h>
2526
#include <ir/effects.h>
@@ -444,9 +445,11 @@ struct RemoveUnusedBrs : public WalkerPass<PostWalker<RemoveUnusedBrs>> {
444445

445446
// perform some final optimizations
446447
struct FinalOptimizer : public PostWalker<FinalOptimizer> {
447-
bool selectify;
448+
bool shrink;
448449
PassOptions& passOptions;
449450

451+
bool needUniqify = false;
452+
450453
FinalOptimizer(PassOptions& passOptions) : passOptions(passOptions) {}
451454

452455
void visitBlock(Block* curr) {
@@ -479,9 +482,9 @@ struct RemoveUnusedBrs : public WalkerPass<PostWalker<RemoveUnusedBrs>> {
479482
}
480483
}
481484
if (list.size() >= 2) {
482-
if (selectify) {
483-
// Join adjacent br_ifs to the same target, making one br_if with
484-
// a "selectified" condition that executes both.
485+
// Join adjacent br_ifs to the same target, making one br_if with
486+
// a "selectified" condition that executes both.
487+
if (shrink) {
485488
for (Index i = 0; i < list.size() - 1; i++) {
486489
auto* br1 = list[i]->dynCast<Break>();
487490
// avoid unreachable brs, as they are dead code anyhow, and after merging
@@ -500,6 +503,9 @@ struct RemoveUnusedBrs : public WalkerPass<PostWalker<RemoveUnusedBrs>> {
500503
}
501504
}
502505
}
506+
// combine adjacent br_ifs that test the same value into a br_table,
507+
// when that makes sense
508+
tablify(curr);
503509
// Restructuring of ifs: if we have
504510
// (block $x
505511
// (br_if $x (cond))
@@ -535,7 +541,7 @@ struct RemoveUnusedBrs : public WalkerPass<PostWalker<RemoveUnusedBrs>> {
535541
void visitIf(If* curr) {
536542
// we may have simplified ifs enough to turn them into selects
537543
// this is helpful for code size, but can be a tradeoff with performance as we run both code paths
538-
if (!selectify) return;
544+
if (!shrink) return;
539545
if (curr->ifFalse && isConcreteWasmType(curr->ifTrue->type) && isConcreteWasmType(curr->ifFalse->type)) {
540546
// if with else, consider turning it into a select if there is no control flow
541547
// TODO: estimate cost
@@ -556,11 +562,169 @@ struct RemoveUnusedBrs : public WalkerPass<PostWalker<RemoveUnusedBrs>> {
556562
}
557563
}
558564
}
565+
566+
// (br_if)+ => br_table
567+
// we look for the specific pattern of
568+
// (br_if ..target1..
569+
// (i32.eq
570+
// (..input..)
571+
// (i32.const ..value1..)
572+
// )
573+
// )
574+
// (br_if ..target2..
575+
// (i32.eq
576+
// (..input..)
577+
// (i32.const ..value2..)
578+
// )
579+
// )
580+
// TODO: consider also looking at <= etc. and not just eq
581+
// Scans the block's list for runs of adjacent br_ifs of the form
//   (br_if $target (i32.eq (input) (i32.const K)))
// that all test an equal, side-effect-free input against distinct constants.
// When a run is long enough and its constants are dense enough, the last
// br_if of the run is replaced by a new block containing a single switch on
// the input (offset by the smallest constant), and the other br_ifs of the
// run are turned into nops. Sets needUniqify when any replacement is made.
void tablify(Block* block) {
  auto& items = block->list;
  if (items.size() < 2) return;

  // Heuristics. These are slightly inspired by the constants from the asm.js backend.

  // Minimum number of br_ifs in a run before we consider converting it.
  const uint32_t MIN_NUM = 3;
  // A spread of constant values above this is always rejected.
  const uint32_t MAX_RANGE = 1024;
  // The spread must also be at most (run length * this factor); a higher
  // factor tolerates sparser tables.
  const uint32_t NUM_TO_RANGE_FACTOR = 3;

  // Returns curr as a Break if it is a value-less conditional break of type
  // none whose condition is an i32.eq with a Const on the right, and that
  // constant (viewed as unsigned) is below int32 max, which avoids overflow
  // concerns. Returns nullptr otherwise. Unreachable br_ifs are rejected by
  // the type check; dce will clean those up.
  auto getProperBrIf = [](Expression* curr) -> Break* {
    auto* br = curr->dynCast<Break>();
    if (!br || !br->condition || br->value) return nullptr;
    if (br->type != none) return nullptr;
    auto* eq = br->condition->dynCast<Binary>();
    if (!eq || eq->op != EqInt32) return nullptr;
    auto* rhs = eq->right->dynCast<Const>();
    if (!rhs) return nullptr;
    const uint32_t constant = rhs->value.geti32();
    if (constant >= uint32_t(std::numeric_limits<int32_t>::max())) return nullptr;
    return br;
  };

  // Left operand of the i32.eq (the value being tested) when curr is a
  // proper br_if, or nullptr when it is not.
  auto inputOf = [&getProperBrIf](Expression* curr) -> Expression* {
    if (auto* br = getProperBrIf(curr)) {
      return br->condition->cast<Binary>()->left;
    }
    return nullptr;
  };

  // Constant compared against, as a uint32_t. Must only be called on an
  // expression that getProperBrIf accepts.
  auto constantOf = [&getProperBrIf](Expression* curr) -> uint32_t {
    auto* eq = getProperBrIf(curr)->condition->cast<Binary>();
    return eq->right->cast<Const>()->value.geti32();
  };

  for (Index start = 0; start < items.size() - 1;) {
    auto* input = inputOf(items[start]);
    // Not a candidate at all, or the tested value has side effects, in which
    // case many evaluations of it cannot be merged into a single one.
    if (!input || EffectAnalyzer(passOptions, input).hasSideEffects()) {
      start++;
      continue;
    }

    // Grow the run [start, end) while the following items are proper br_ifs
    // on an equal input with constants not seen yet. A repeated constant
    // ends the run: the first branch would win anyway, but the later br_if
    // cannot be removed (it may be the only branch keeping a block
    // reachable), which may make this bad for code size.
    std::unordered_set<uint32_t> seenConstants;
    seenConstants.insert(constantOf(items[start]));
    Index end = start + 1;
    for (; end < items.size(); end++) {
      if (!ExpressionAnalyzer::equal(inputOf(items[end]), input)) break;
      if (!seenConstants.insert(constantOf(items[end])).second) break;
    }

    const auto count = end - start;
    if (count < 2 || count < MIN_NUM) {
      // run too short to be worth it
      start = end;
      continue;
    }

    // Find the smallest and largest constants used in the run.
    uint32_t lowest = constantOf(items[start]);
    uint32_t highest = lowest;
    for (Index k = start + 1; k < end; k++) {
      const uint32_t constant = constantOf(items[k]);
      lowest = std::min(lowest, constant);
      highest = std::max(highest, constant);
    }

    // Decision time: reject tables that would be too large or too sparse.
    const uint32_t range = highest - lowest;
    if (range > MAX_RANGE || range > count * NUM_TO_RANGE_FACTOR) {
      start = end;
      continue;
    }

    // Pick a name for the default target (the new enclosing block) that does
    // not collide with any branch target used by the run.
    std::unordered_set<Name> targetNames;
    for (Index k = start; k < end; k++) {
      targetNames.insert(getProperBrIf(items[k])->name);
    }
    Name defaultName;
    for (Index suffix = 0;; suffix++) {
      defaultName = "tablify|" + std::to_string(suffix);
      if (targetNames.count(defaultName) == 0) break;
    }

    // Build the table: one slot per value in [lowest, highest], each going to
    // the default unless some br_if in the run claims it.
    std::vector<Name> table(range + 1, defaultName);
    for (Index k = start; k < end; k++) {
      const uint32_t slot = constantOf(items[k]) - lowest;
      assert(table[slot] == defaultName); // we should have made sure there are no overlaps
      table[slot] = getProperBrIf(items[k])->name;
    }

    Builder builder(*getModule());
    // The table starts at the lowest constant, so offset the input to match.
    Expression* index = input;
    if (lowest != 0) {
      index = builder.makeBinary(
        SubInt32,
        input,
        builder.makeConst(Literal(int32_t(lowest)))
      );
    }

    // The new block takes the place of the last br_if of the run; the
    // earlier ones become nops.
    items[end - 1] = builder.makeBlock(
      defaultName,
      builder.makeSwitch(table, defaultName, index)
    );
    for (Index k = start; k < end - 1; k++) {
      ExpressionManipulator::nop(items[k]);
    }

    // the defaultName may exist elsewhere in this function,
    // uniquify it later
    needUniqify = true;

    start = end;
  }
}
559720
};
560721
FinalOptimizer finalOptimizer(getPassOptions());
561722
finalOptimizer.setModule(getModule());
562-
finalOptimizer.selectify = getPassRunner()->options.shrinkLevel > 0;
723+
finalOptimizer.shrink = getPassRunner()->options.shrinkLevel > 0;
563724
finalOptimizer.walkFunction(func);
725+
if (finalOptimizer.needUniqify) {
726+
wasm::UniqueNameMapper::uniquify(func->body);
727+
}
564728
}
565729
};
566730

src/wasm-builder.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ class Builder {
8282
// Builds a block via makeBlock(first), gives it the provided name, and then
// calls finalize() on it so that the finalization happens with the name
// already set.
Block* makeBlock(Name name, Expression* first = nullptr) {
  Block* named = makeBlock(first);
  named->name = name;
  named->finalize();
  return named;
}
8788
Block* makeBlock(const std::vector<Expression*>& items) {

0 commit comments

Comments
 (0)