Skip to content

Commit d216f87

Browse files
authored
wasm2js: avoid reinterprets (#2094)
In JS a reinterpret is especially expensive, as we implement it as a write to a temp buffer and a read using another view. This finds places where we load a value from memory, then reinterpret it later - in that case, we can load it using another view, at the cost of another load and another local. This is helpful on things like Box2D, where there are many reinterprets due to the main 2D vector class being an union over two floats/ints, and LLVM likes to do a single i64 load of them.
1 parent bdfdbfb commit d216f87

21 files changed

+689
-326
lines changed

build-js.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ echo "building shared bitcode"
9292
$BINARYEN_SRC/ir/ReFinalize.cpp \
9393
$BINARYEN_SRC/passes/pass.cpp \
9494
$BINARYEN_SRC/passes/AlignmentLowering.cpp \
95+
$BINARYEN_SRC/passes/AvoidReinterprets.cpp \
9596
$BINARYEN_SRC/passes/CoalesceLocals.cpp \
9697
$BINARYEN_SRC/passes/DeadArgumentElimination.cpp \
9798
$BINARYEN_SRC/passes/CodeFolding.cpp \

src/passes/AvoidReinterprets.cpp

Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
/*
2+
* Copyright 2017 WebAssembly Community Group participants
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
// Avoids reinterprets by using more loads: if we load a value and
18+
// reinterpret it, we could have loaded it with the other type
19+
// anyhow. This uses more locals and loads, so it is not generally
20+
// beneficial, unless reinterprets are very costly.
21+
22+
#include <ir/local-graph.h>
23+
#include <ir/properties.h>
24+
#include <pass.h>
25+
#include <wasm-builder.h>
26+
#include <wasm.h>
27+
28+
namespace wasm {
29+
30+
static Load* getSingleLoad(LocalGraph* localGraph, GetLocal* get) {
31+
while (1) {
32+
auto& sets = localGraph->getSetses[get];
33+
if (sets.size() != 1) {
34+
return nullptr;
35+
}
36+
auto* set = *sets.begin();
37+
if (!set) {
38+
return nullptr;
39+
}
40+
auto* value = Properties::getFallthrough(set->value);
41+
if (auto* parentGet = value->dynCast<GetLocal>()) {
42+
get = parentGet;
43+
continue;
44+
}
45+
if (auto* load = value->dynCast<Load>()) {
46+
return load;
47+
}
48+
return nullptr;
49+
}
50+
}
51+
52+
static bool isReinterpret(Unary* curr) {
53+
return curr->op == ReinterpretInt32 || curr->op == ReinterpretInt64 ||
54+
curr->op == ReinterpretFloat32 || curr->op == ReinterpretFloat64;
55+
}
56+
57+
struct AvoidReinterprets : public WalkerPass<PostWalker<AvoidReinterprets>> {
58+
bool isFunctionParallel() override { return true; }
59+
60+
Pass* create() override { return new AvoidReinterprets; }
61+
62+
struct Info {
63+
// Info used when analyzing.
64+
bool reinterpreted;
65+
// Info used when optimizing.
66+
Index ptrLocal;
67+
Index reinterpretedLocal;
68+
};
69+
std::map<Load*, Info> infos;
70+
71+
LocalGraph* localGraph;
72+
73+
void doWalkFunction(Function* func) {
74+
// prepare
75+
LocalGraph localGraph_(func);
76+
localGraph = &localGraph_;
77+
// walk
78+
PostWalker<AvoidReinterprets>::doWalkFunction(func);
79+
// optimize
80+
optimize(func);
81+
}
82+
83+
void visitUnary(Unary* curr) {
84+
if (isReinterpret(curr)) {
85+
if (auto* get =
86+
Properties::getFallthrough(curr->value)->dynCast<GetLocal>()) {
87+
if (auto* load = getSingleLoad(localGraph, get)) {
88+
auto& info = infos[load];
89+
info.reinterpreted = true;
90+
}
91+
}
92+
}
93+
}
94+
95+
void optimize(Function* func) {
96+
std::set<Load*> unoptimizables;
97+
for (auto& pair : infos) {
98+
auto* load = pair.first;
99+
auto& info = pair.second;
100+
if (info.reinterpreted && load->type != unreachable) {
101+
// We should use another load here, to avoid reinterprets.
102+
info.ptrLocal = Builder::addVar(func, i32);
103+
info.reinterpretedLocal =
104+
Builder::addVar(func, reinterpretType(load->type));
105+
} else {
106+
unoptimizables.insert(load);
107+
}
108+
}
109+
for (auto* load : unoptimizables) {
110+
infos.erase(load);
111+
}
112+
// We now know which we can optimize, and how.
113+
struct FinalOptimizer : public PostWalker<FinalOptimizer> {
114+
std::map<Load*, Info>& infos;
115+
LocalGraph* localGraph;
116+
Module* module;
117+
118+
FinalOptimizer(std::map<Load*, Info>& infos,
119+
LocalGraph* localGraph,
120+
Module* module)
121+
: infos(infos), localGraph(localGraph), module(module) {}
122+
123+
void visitUnary(Unary* curr) {
124+
if (isReinterpret(curr)) {
125+
auto* value = Properties::getFallthrough(curr->value);
126+
if (auto* load = value->dynCast<Load>()) {
127+
// A reinterpret of a load - flip it right here.
128+
replaceCurrent(makeReinterpretedLoad(load, load->ptr));
129+
} else if (auto* get = value->dynCast<GetLocal>()) {
130+
if (auto* load = getSingleLoad(localGraph, get)) {
131+
auto iter = infos.find(load);
132+
if (iter != infos.end()) {
133+
auto& info = iter->second;
134+
// A reinterpret of a get of a load - use the new local.
135+
Builder builder(*module);
136+
replaceCurrent(builder.makeGetLocal(
137+
info.reinterpretedLocal, reinterpretType(load->type)));
138+
}
139+
}
140+
}
141+
}
142+
}
143+
144+
void visitLoad(Load* curr) {
145+
auto iter = infos.find(curr);
146+
if (iter != infos.end()) {
147+
auto& info = iter->second;
148+
Builder builder(*module);
149+
auto* ptr = curr->ptr;
150+
curr->ptr = builder.makeGetLocal(info.ptrLocal, i32);
151+
// Note that the other load can have its sign set to false - if the
152+
// original were an integer, the other is a float anyhow; and if
153+
// original were a float, we don't know what sign to use.
154+
replaceCurrent(builder.makeBlock(
155+
{builder.makeSetLocal(info.ptrLocal, ptr),
156+
builder.makeSetLocal(
157+
info.reinterpretedLocal,
158+
makeReinterpretedLoad(curr,
159+
builder.makeGetLocal(info.ptrLocal, i32))),
160+
curr}));
161+
}
162+
}
163+
164+
Load* makeReinterpretedLoad(Load* load, Expression* ptr) {
165+
Builder builder(*module);
166+
return builder.makeLoad(load->bytes,
167+
false,
168+
load->offset,
169+
load->align,
170+
ptr,
171+
reinterpretType(load->type));
172+
}
173+
} finalOptimizer(infos, localGraph, getModule());
174+
175+
finalOptimizer.walk(func->body);
176+
}
177+
};
178+
179+
Pass* createAvoidReinterpretsPass() { return new AvoidReinterprets(); }
180+
181+
} // namespace wasm

src/passes/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ add_custom_command(
66
SET(passes_SOURCES
77
pass.cpp
88
AlignmentLowering.cpp
9+
AvoidReinterprets.cpp
910
CoalesceLocals.cpp
1011
CodePushing.cpp
1112
CodeFolding.cpp

src/passes/pass.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,11 +71,14 @@ std::string PassRegistry::getPassDescription(std::string name) {
7171
// PassRunner
7272

7373
void PassRegistry::registerPasses() {
74-
registerPass(
75-
"dae", "removes arguments to calls in an lto-like manner", createDAEPass);
7674
registerPass("alignment-lowering",
7775
"lower unaligned loads and stores to smaller aligned ones",
7876
createAlignmentLoweringPass);
77+
registerPass("avoid-reinterprets",
78+
"Tries to avoid reinterpret operations via more loads",
79+
createAvoidReinterpretsPass);
80+
registerPass(
81+
"dae", "removes arguments to calls in an lto-like manner", createDAEPass);
7982
registerPass("dae-optimizing",
8083
"removes arguments to calls in an lto-like manner, and "
8184
"optimizes where we removed",

src/passes/passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ class Pass;
2323

2424
// All passes:
2525
Pass* createAlignmentLoweringPass();
26+
Pass* createAvoidReinterpretsPass();
2627
Pass* createCoalesceLocalsPass();
2728
Pass* createCoalesceLocalsWithLearningPass();
2829
Pass* createCodeFoldingPass();

src/wasm-type.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ bool isFloatType(Type type);
4545
bool isIntegerType(Type type);
4646
bool isVectorType(Type type);
4747
bool isReferenceType(Type type);
48+
Type reinterpretType(Type type);
4849

4950
} // namespace wasm
5051

src/wasm/wasm-type.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,4 +118,23 @@ bool isReferenceType(Type type) {
118118
return type == except_ref;
119119
}
120120

121+
Type reinterpretType(Type type) {
122+
switch (type) {
123+
case Type::i32:
124+
return f32;
125+
case Type::i64:
126+
return f64;
127+
case Type::f32:
128+
return i32;
129+
case Type::f64:
130+
return i64;
131+
case Type::v128:
132+
case Type::except_ref:
133+
case Type::none:
134+
case Type::unreachable:
135+
WASM_UNREACHABLE();
136+
}
137+
WASM_UNREACHABLE();
138+
}
139+
121140
} // namespace wasm

src/wasm2js.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,7 +296,18 @@ Ref Wasm2JSBuilder::processWasm(Module* wasm, Name funcName) {
296296
// Next, optimize that as best we can. This should not generate
297297
// non-JS-friendly things.
298298
if (options.optimizeLevel > 0) {
299+
// It is especially import to propagate constants after the lowering.
300+
// However, this can be a slow operation, especially after flattening;
301+
// some local simplification helps.
302+
if (options.optimizeLevel >= 3 || options.shrinkLevel >= 1) {
303+
runner.add("simplify-locals-nonesting");
304+
runner.add("precompute-propagate");
305+
// Avoiding reinterpretation is helped by propagation. We also run
306+
// it later down as default optimizations help as well.
307+
runner.add("avoid-reinterprets");
308+
}
299309
runner.addDefaultOptimizationPasses();
310+
runner.add("avoid-reinterprets");
300311
}
301312
// Finally, get the code into the flat form we need for wasm2js itself, and
302313
// optimize that a little in a way that keeps flat property.

0 commit comments

Comments
 (0)