Skip to content

Commit 0b6a1f5

Browse files
authored
Optimize validation of many nested blocks (#1576)
On the testcase from tweag/asterius#19 (comment), this makes us almost 3x faster and use 25% less memory. The main improvement here is to simplify and optimize the data structures the validator uses to validate br targets: use unordered maps, and use one fewer of them. There are also some speedups from using that map more effectively (using iterators to avoid multiple lookups). This commit also moves the duplicate-node checks to the internal IR validation section, which makes more sense anyhow: it's not wasm validation, it's internal IR validation, which, like the check for stale internal types, we do only if debugging.
1 parent 706b3f6 commit 0b6a1f5

File tree

4 files changed

+50
-72
lines changed

4 files changed

+50
-72
lines changed

src/emscripten-optimizer/istring.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,8 +159,7 @@ namespace std {
159159

160160
template <> struct hash<cashew::IString> : public unary_function<cashew::IString, size_t> {
161161
size_t operator()(const cashew::IString& str) const {
162-
size_t hash = size_t(str.str);
163-
return hash = ((hash << 5) + hash) ^ 5381; /* (hash * 33) ^ c */
162+
return std::hash<size_t>{}(size_t(str.str));
164163
}
165164
};
166165

src/wasm/wasm-validator.cpp

Lines changed: 49 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -176,20 +176,27 @@ struct FunctionValidator : public WalkerPass<PostWalker<FunctionValidator>> {
176176
FunctionValidator(ValidationInfo* info) : info(*info) {}
177177

178178
struct BreakInfo {
179+
enum {
180+
UnsetArity = Index(-1),
181+
PoisonArity = Index(-2)
182+
};
183+
179184
Type type;
180185
Index arity;
181-
BreakInfo() {}
186+
BreakInfo() : arity(UnsetArity) {}
182187
BreakInfo(Type type, Index arity) : type(type), arity(arity) {}
188+
189+
bool hasBeenSet() {
190+
// Compare to the impossible value.
191+
return arity != UnsetArity;
192+
}
183193
};
184194

185-
std::map<Name, Expression*> breakTargets;
186-
std::map<Expression*, BreakInfo> breakInfos;
195+
std::unordered_map<Name, BreakInfo> breakInfos;
187196

188197
Type returnType = unreachable; // type used in returns
189198

190-
std::set<Name> labelNames; // Binaryen IR requires that label names must be unique - IR generators must ensure that
191-
192-
std::unordered_set<Expression*> seenExpressions; // expressions must not appear twice
199+
std::unordered_set<Name> labelNames; // Binaryen IR requires that label names must be unique - IR generators must ensure that
193200

194201
void noteLabelName(Name name);
195202

@@ -198,14 +205,14 @@ struct FunctionValidator : public WalkerPass<PostWalker<FunctionValidator>> {
198205

199206
static void visitPreBlock(FunctionValidator* self, Expression** currp) {
200207
auto* curr = (*currp)->cast<Block>();
201-
if (curr->name.is()) self->breakTargets[curr->name] = curr;
208+
if (curr->name.is()) self->breakInfos[curr->name];
202209
}
203210

204211
void visitBlock(Block* curr);
205212

206213
static void visitPreLoop(FunctionValidator* self, Expression** currp) {
207214
auto* curr = (*currp)->cast<Loop>();
208-
if (curr->name.is()) self->breakTargets[curr->name] = curr;
215+
if (curr->name.is()) self->breakInfos[curr->name];
209216
}
210217

211218
void visitLoop(Loop* curr);
@@ -285,16 +292,19 @@ struct FunctionValidator : public WalkerPass<PostWalker<FunctionValidator>> {
285292

286293
void FunctionValidator::noteLabelName(Name name) {
287294
if (!name.is()) return;
288-
shouldBeTrue(labelNames.find(name) == labelNames.end(), name, "names in Binaryen IR must be unique - IR generators must ensure that");
289-
labelNames.insert(name);
295+
bool inserted;
296+
std::tie(std::ignore, inserted) = labelNames.insert(name);
297+
shouldBeTrue(inserted, name, "names in Binaryen IR must be unique - IR generators must ensure that");
290298
}
291299

292300
void FunctionValidator::visitBlock(Block* curr) {
293301
// if we are break'ed to, then the value must be right for us
294302
if (curr->name.is()) {
295303
noteLabelName(curr->name);
296-
if (breakInfos.count(curr) > 0) {
297-
auto& info = breakInfos[curr];
304+
auto iter = breakInfos.find(curr->name);
305+
assert(iter != breakInfos.end()); // we set it ourselves
306+
auto& info = iter->second;
307+
if (info.hasBeenSet()) {
298308
if (isConcreteType(curr->type)) {
299309
shouldBeTrue(info.arity != 0, curr, "break arities must be > 0 if block has a value");
300310
} else {
@@ -307,7 +317,7 @@ void FunctionValidator::visitBlock(Block* curr) {
307317
if (isConcreteType(curr->type) && info.arity && info.type != unreachable) {
308318
shouldBeEqual(curr->type, info.type, curr, "block+breaks must have right type if breaks have arity");
309319
}
310-
shouldBeTrue(info.arity != Index(-1), curr, "break arities must match");
320+
shouldBeTrue(info.arity != BreakInfo::PoisonArity, curr, "break arities must match");
311321
if (curr->list.size() > 0) {
312322
auto last = curr->list.back()->type;
313323
if (isConcreteType(last) && info.type != unreachable) {
@@ -318,7 +328,7 @@ void FunctionValidator::visitBlock(Block* curr) {
318328
}
319329
}
320330
}
321-
breakTargets.erase(curr->name);
331+
breakInfos.erase(iter);
322332
}
323333
if (curr->list.size() > 1) {
324334
for (Index i = 0; i < curr->list.size() - 1; i++) {
@@ -347,11 +357,13 @@ void FunctionValidator::visitBlock(Block* curr) {
347357
void FunctionValidator::visitLoop(Loop* curr) {
348358
if (curr->name.is()) {
349359
noteLabelName(curr->name);
350-
breakTargets.erase(curr->name);
351-
if (breakInfos.count(curr) > 0) {
352-
auto& info = breakInfos[curr];
360+
auto iter = breakInfos.find(curr->name);
361+
assert(iter != breakInfos.end()); // we set it ourselves
362+
auto& info = iter->second;
363+
if (info.hasBeenSet()) {
353364
shouldBeEqual(info.arity, Index(0), curr, "breaks to a loop cannot pass a value");
354365
}
366+
breakInfos.erase(iter);
355367
}
356368
if (curr->type == none) {
357369
shouldBeFalse(isConcreteType(curr->body->type), curr, "bad body for a loop that has no value");
@@ -394,12 +406,12 @@ void FunctionValidator::noteBreak(Name name, Expression* value, Expression* curr
394406
shouldBeUnequal(valueType, none, curr, "breaks must have a valid value");
395407
arity = 1;
396408
}
397-
if (!shouldBeTrue(breakTargets.count(name) > 0, curr, "all break targets must be valid")) return;
398-
auto* target = breakTargets[name];
399-
if (breakInfos.count(target) == 0) {
400-
breakInfos[target] = BreakInfo(valueType, arity);
409+
auto iter = breakInfos.find(name);
410+
if (!shouldBeTrue(iter != breakInfos.end(), curr, "all break targets must be valid")) return;
411+
auto& info = iter->second;
412+
if (!info.hasBeenSet()) {
413+
info = BreakInfo(valueType, arity);
401414
} else {
402-
auto& info = breakInfos[target];
403415
if (info.type == unreachable) {
404416
info.type = valueType;
405417
} else if (valueType != unreachable) {
@@ -408,7 +420,7 @@ void FunctionValidator::noteBreak(Name name, Expression* value, Expression* curr
408420
}
409421
}
410422
if (arity != info.arity) {
411-
info.arity = Index(-1); // a poison value
423+
info.arity = BreakInfo::PoisonArity;
412424
}
413425
}
414426
}
@@ -810,7 +822,7 @@ void FunctionValidator::visitFunction(Function* curr) {
810822
if (returnType != unreachable) {
811823
shouldBeEqual(curr->result, returnType, curr->body, "function result must match, if function has returns");
812824
}
813-
shouldBeTrue(breakTargets.empty(), curr->body, "all named break targets must exist");
825+
shouldBeTrue(breakInfos.empty(), curr->body, "all named break targets must exist");
814826
returnType = unreachable;
815827
labelNames.clear();
816828
// if function has a named type, it must match up with the function's params and result
@@ -819,24 +831,6 @@ void FunctionValidator::visitFunction(Function* curr) {
819831
shouldBeTrue(ft->params == curr->params, curr->name, "function params must match its declared type");
820832
shouldBeTrue(ft->result == curr->result, curr->name, "function result must match its declared type");
821833
}
822-
// expressions must not be seen more than once
823-
struct Walker : public PostWalker<Walker, UnifiedExpressionVisitor<Walker>> {
824-
std::unordered_set<Expression*>& seen;
825-
std::vector<Expression*> dupes;
826-
827-
Walker(std::unordered_set<Expression*>& seen) : seen(seen) {}
828-
829-
void visitExpression(Expression* curr) {
830-
bool inserted;
831-
std::tie(std::ignore, inserted) = seen.insert(curr);
832-
if (!inserted) dupes.push_back(curr);
833-
}
834-
};
835-
Walker walker(seenExpressions);
836-
walker.walk(curr->body);
837-
for (auto* bad : walker.dupes) {
838-
info.fail("expression seen more than once in the tree", bad, getFunction());
839-
}
840834
}
841835

842836
static bool checkOffset(Expression* curr, Address add, Address max) {
@@ -890,9 +884,12 @@ static void validateBinaryenIR(Module& wasm, ValidationInfo& info) {
890884
struct BinaryenIRValidator : public PostWalker<BinaryenIRValidator, UnifiedExpressionVisitor<BinaryenIRValidator>> {
891885
ValidationInfo& info;
892886

887+
std::unordered_set<Expression*> seen;
888+
893889
BinaryenIRValidator(ValidationInfo& info) : info(info) {}
894890

895891
void visitExpression(Expression* curr) {
892+
auto scope = getFunction() ? getFunction()->name : Name("(global scope)");
896893
// check if a node type is 'stale', i.e., we forgot to finalize() the node.
897894
auto oldType = curr->type;
898895
ReFinalizeNode().visit(curr);
@@ -907,11 +904,19 @@ static void validateBinaryenIR(Module& wasm, ValidationInfo& info) {
907904
// ok for it to be either i32 or unreachable.
908905
if (!(isConcreteType(oldType) && newType == unreachable)) {
909906
std::ostringstream ss;
910-
ss << "stale type found in " << (getFunction() ? getFunction()->name : Name("(global scope)")) << " on " << curr << "\n(marked as " << printType(oldType) << ", should be " << printType(newType) << ")\n";
907+
ss << "stale type found in " << scope << " on " << curr << "\n(marked as " << printType(oldType) << ", should be " << printType(newType) << ")\n";
911908
info.fail(ss.str(), curr, getFunction());
912909
}
913910
curr->type = oldType;
914911
}
912+
// check if a node is a duplicate - expressions must not be seen more than once
913+
bool inserted;
914+
std::tie(std::ignore, inserted) = seen.insert(curr);
915+
if (!inserted) {
916+
std::ostringstream ss;
917+
ss << "expression seen more than once in the tree in " << scope << " on " << curr << '\n';
918+
info.fail(ss.str(), curr, getFunction());
919+
}
915920
}
916921
};
917922
BinaryenIRValidator binaryenIRValidator(info);
@@ -952,7 +957,7 @@ static void validateExports(Module& module, ValidationInfo& info) {
952957
}
953958
}
954959
}
955-
std::set<Name> exportNames;
960+
std::unordered_set<Name> exportNames;
956961
for (auto& exp : module.exports) {
957962
Name name = exp->value;
958963
if (exp->kind == ExternalKind::Function) {

test/example/c-api-kitchen-sink.c

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -567,22 +567,6 @@ void test_nonvalid() {
567567
BinaryenModulePrint(module);
568568
printf("validation: %d\n", BinaryenModuleValidate(module));
569569

570-
BinaryenModuleDispose(module);
571-
}
572-
// validation failure due to duplicate nodes
573-
{
574-
BinaryenModuleRef module = BinaryenModuleCreate();
575-
576-
BinaryenFunctionTypeRef v = BinaryenAddFunctionType(module, "i", BinaryenTypeInt32(), NULL, 0);
577-
BinaryenType localTypes[] = { };
578-
BinaryenExpressionRef num = makeInt32(module, 1234);
579-
BinaryenFunctionRef func = BinaryenAddFunction(module, "func", v, NULL, 0,
580-
BinaryenBinary(module, BinaryenTypeInt32(), num, num) // incorrectly use num twice
581-
);
582-
583-
BinaryenModulePrint(module);
584-
printf("validation: %d\n", BinaryenModuleValidate(module));
585-
586570
BinaryenModuleDispose(module);
587571
}
588572
}

test/example/c-api-kitchen-sink.txt

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1070,16 +1070,6 @@ module loaded from binary form:
10701070
)
10711071
)
10721072
validation: 0
1073-
(module
1074-
(type $i (func (result i32)))
1075-
(func $func (; 0 ;) (type $i) (result i32)
1076-
(i32.sub
1077-
(i32.const 1234)
1078-
(i32.const 1234)
1079-
)
1080-
)
1081-
)
1082-
validation: 0
10831073
// beginning a Binaryen API trace
10841074
#include <math.h>
10851075
#include <map>

0 commit comments

Comments
 (0)