Skip to content

Commit cd3b26d

Browse files
authored
Require RefFunc to have the proper type (WebAssembly#7376)
As a holdout from before GC was implemented, we previously allowed RefFunc expressions to have type `funcref` rather than a specific signature type matching that of the referenced function. Remove this allowance and start requiring the types to be correct and precise to eliminate the possibility of stale types inhibiting (or invalidating!) optimizations. Update various older passes to update the types of RefFuncs, including those in tables, to keep their output passing validation. Also update the kitchen sink example test to construct RefFunc expressions with the correct type via the C API.
1 parent 3f341e5 commit cd3b26d

20 files changed

+146
-85
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@ Current Trunk
2121
- `string` is now a subtype of `ext` (rather than `any`). This allows better
2222
transformations for strings, like an inverse of StringLowering, but will
2323
error on codebases that depend on being able to pass strings into anyrefs.
24+
- Require the type of RefFunc expressions to match the type of the referenced
25+
function. It is no longer valid to type them as funcref in the IR.
26+
- The C and JS APIs for creating RefFunc expressions now take a HeapType
27+
instead of a Type.
2428

2529
v122
2630
----

src/binaryen-c.cpp

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1606,12 +1606,11 @@ BinaryenExpressionRef BinaryenRefAs(BinaryenModuleRef module,
16061606
Builder(*(Module*)module).makeRefAs(RefAsOp(op), (Expression*)value));
16071607
}
16081608

1609-
BinaryenExpressionRef
1610-
BinaryenRefFunc(BinaryenModuleRef module, const char* func, BinaryenType type) {
1611-
// TODO: consider changing the C API to receive a heap type
1612-
Type type_(type);
1609+
BinaryenExpressionRef BinaryenRefFunc(BinaryenModuleRef module,
1610+
const char* func,
1611+
BinaryenHeapType type) {
16131612
return static_cast<Expression*>(
1614-
Builder(*(Module*)module).makeRefFunc(func, type_.getHeapType()));
1613+
Builder(*(Module*)module).makeRefFunc(func, HeapType(type)));
16151614
}
16161615

16171616
BinaryenExpressionRef BinaryenRefEq(BinaryenModuleRef module,
@@ -6259,8 +6258,12 @@ bool TypeBuilderBuildAndDispose(TypeBuilderRef builder,
62596258
auto* B = (TypeBuilder*)builder;
62606259
auto result = B->build();
62616260
if (auto err = result.getError()) {
6262-
*errorIndex = err->index;
6263-
*errorReason = static_cast<TypeBuilderErrorReason>(err->reason);
6261+
if (errorIndex) {
6262+
*errorIndex = err->index;
6263+
}
6264+
if (errorReason) {
6265+
*errorReason = static_cast<TypeBuilderErrorReason>(err->reason);
6266+
}
62646267
delete B;
62656268
return false;
62666269
}

src/binaryen-c.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -961,7 +961,7 @@ BINARYEN_API BinaryenExpressionRef BinaryenRefAs(BinaryenModuleRef module,
961961
BinaryenExpressionRef value);
962962
BINARYEN_API BinaryenExpressionRef BinaryenRefFunc(BinaryenModuleRef module,
963963
const char* func,
964-
BinaryenType type);
964+
BinaryenHeapType type);
965965
BINARYEN_API BinaryenExpressionRef BinaryenRefEq(BinaryenModuleRef module,
966966
BinaryenExpressionRef left,
967967
BinaryenExpressionRef right);

src/ir/element-utils.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ template<typename T>
4242
inline void iterAllElementFunctionNames(const Module* wasm, T visitor) {
4343
for (auto& segment : wasm->elementSegments) {
4444
iterElementSegmentFunctionNames(
45-
segment.get(), [&](Name& name, Index i) { visitor(name); });
45+
segment.get(), [&](const Name& name, Index i) { visitor(name); });
4646
}
4747
}
4848

src/js/binaryen.js-post.js

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3275,6 +3275,7 @@ Module['getFunctionInfo'] = function(func) {
32753275
'name': UTF8ToString(Module['_BinaryenFunctionGetName'](func)),
32763276
'module': UTF8ToString(Module['_BinaryenFunctionImportGetModule'](func)),
32773277
'base': UTF8ToString(Module['_BinaryenFunctionImportGetBase'](func)),
3278+
'type': Module['_BinaryenFunctionGetType'](func),
32783279
'params': Module['_BinaryenFunctionGetParams'](func),
32793280
'results': Module['_BinaryenFunctionGetResults'](func),
32803281
'vars': getAllNested(func, Module['_BinaryenFunctionGetNumVars'], Module['_BinaryenFunctionGetVar']),
@@ -4893,6 +4894,9 @@ Module['Function'] = (() => {
48934894
Function['getName'] = function(func) {
48944895
return UTF8ToString(Module['_BinaryenFunctionGetName'](func));
48954896
};
4897+
Function['getType'] = function(func) {
4898+
return Module['_BinaryenFunctionGetType'](func);
4899+
}
48964900
Function['getParams'] = function(func) {
48974901
return Module['_BinaryenFunctionGetParams'](func);
48984902
};

src/passes/FuncCastEmulation.cpp

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -162,25 +162,34 @@ struct FuncCastEmulation : public Pass {
162162
HeapType ABIType(
163163
Signature(Type(std::vector<Type>(numParams, Type::i64)), Type::i64));
164164
// Add a thunk for each function in the table, and do the call through it.
165-
std::unordered_map<Name, Name> funcThunks;
166-
ElementUtils::iterAllElementFunctionNames(module, [&](Name& name) {
167-
auto iter = funcThunks.find(name);
168-
if (iter == funcThunks.end()) {
169-
auto thunk = makeThunk(name, module, numParams);
170-
funcThunks[name] = thunk;
171-
name = thunk;
172-
} else {
173-
name = iter->second;
165+
std::unordered_map<Name, Function*> funcThunks;
166+
for (auto& segment : module->elementSegments) {
167+
if (!segment->type.isFunction()) {
168+
continue;
174169
}
175-
});
170+
for (Index i = 0; i < segment->data.size(); ++i) {
171+
auto* ref = segment->data[i]->dynCast<RefFunc>();
172+
if (!ref) {
173+
continue;
174+
}
175+
auto [iter, inserted] = funcThunks.insert({ref->func, nullptr});
176+
if (inserted) {
177+
iter->second = makeThunk(ref->func, module, numParams);
178+
}
179+
auto* thunk = iter->second;
180+
ref->func = thunk->name;
181+
// TODO: Make this exact.
182+
ref->type = Type(thunk->type, NonNullable);
183+
}
184+
}
176185

177186
// update call_indirects
178187
ParallelFuncCastEmulation(ABIType, numParams).run(getPassRunner(), module);
179188
}
180189

181190
private:
182191
// Creates a thunk for a function, casting args and return value as needed.
183-
Name makeThunk(Name name, Module* module, Index numParams) {
192+
Function* makeThunk(Name name, Module* module, Index numParams) {
184193
Name thunk = std::string("byn$fpcast-emu$") + name.toString();
185194
if (module->getFunctionOrNull(thunk)) {
186195
Fatal() << "FuncCastEmulation::makeThunk seems a thunk name already in "
@@ -207,8 +216,7 @@ struct FuncCastEmulation : public Pass {
207216
{}, // no vars
208217
toABI(call, module));
209218
thunkFunc->hasExplicitName = true;
210-
module->addFunction(std::move(thunkFunc));
211-
return thunk;
219+
return module->addFunction(std::move(thunkFunc));
212220
}
213221
};
214222

src/passes/I64ToI32Lowering.cpp

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,39 @@ struct I64ToI32Lowering : public WalkerPass<PostWalker<I64ToI32Lowering>> {
297297
});
298298
}
299299

300+
void visitRefFunc(RefFunc* curr) {
301+
auto sig = curr->type.getHeapType().getSignature();
302+
303+
auto lowerTypes = [](Type types) {
304+
bool hasI64 = false;
305+
for (auto t : types) {
306+
if (t == Type::i64) {
307+
hasI64 = true;
308+
break;
309+
}
310+
}
311+
if (!hasI64) {
312+
return types;
313+
}
314+
std::vector<Type> newTypes;
315+
for (auto t : types) {
316+
if (t == Type::i64) {
317+
newTypes.push_back(Type::i32);
318+
newTypes.push_back(Type::i32);
319+
} else {
320+
newTypes.push_back(t);
321+
}
322+
}
323+
return Type(newTypes);
324+
};
325+
326+
auto newParams = lowerTypes(sig.params);
327+
auto newResults = lowerTypes(sig.results);
328+
if (newParams != sig.params || newResults != sig.results) {
329+
curr->type = curr->type.with(HeapType(Signature(newParams, newResults)));
330+
}
331+
}
332+
300333
void visitLocalGet(LocalGet* curr) {
301334
const auto mappedIndex = indexMap[curr->index];
302335
// Need to remap the local into the new naming scheme, regardless of

src/passes/LegalizeJSInterface.cpp

Lines changed: 24 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -108,57 +108,52 @@ struct LegalizeJSInterface : public Pass {
108108
// for each illegal import, we must call a legalized stub instead
109109
for (auto* im : originalFunctions) {
110110
if (im->imported() && isIllegal(im)) {
111-
auto funcName = makeLegalStubForCalledImport(im, module);
112-
illegalImportsToLegal[im->name] = funcName;
113-
// we need to use the legalized version in the tables, as the import
114-
// from JS is legal for JS. Our stub makes it look like a native wasm
115-
// function.
116-
ElementUtils::iterAllElementFunctionNames(module, [&](Name& name) {
117-
if (name == im->name) {
118-
name = funcName;
119-
}
120-
});
111+
auto* func = makeLegalStubForCalledImport(im, module);
112+
illegalImportsToLegal[im->name] = func;
121113
}
122114
}
123115

124116
if (!illegalImportsToLegal.empty()) {
125-
// fix up imports: call_import of an illegal must be turned to a call of a
126-
// legal. the same must be done with ref.funcs.
117+
// fix up imports: call of an illegal import must be turned to a call of a
118+
// legal import. the same must be done with ref.funcs.
127119
struct Fixer : public WalkerPass<PostWalker<Fixer>> {
128120
bool isFunctionParallel() override { return true; }
129121

130122
std::unique_ptr<Pass> create() override {
131123
return std::make_unique<Fixer>(illegalImportsToLegal);
132124
}
133125

134-
std::map<Name, Name>* illegalImportsToLegal;
126+
std::unordered_map<Name, Function*>& illegalImportsToLegal;
135127

136-
Fixer(std::map<Name, Name>* illegalImportsToLegal)
128+
Fixer(std::unordered_map<Name, Function*>& illegalImportsToLegal)
137129
: illegalImportsToLegal(illegalImportsToLegal) {}
138130

139131
void visitCall(Call* curr) {
140-
auto iter = illegalImportsToLegal->find(curr->target);
141-
if (iter == illegalImportsToLegal->end()) {
132+
auto iter = illegalImportsToLegal.find(curr->target);
133+
if (iter == illegalImportsToLegal.end()) {
142134
return;
143135
}
144136

145-
replaceCurrent(
146-
Builder(*getModule())
147-
.makeCall(
148-
iter->second, curr->operands, curr->type, curr->isReturn));
137+
replaceCurrent(Builder(*getModule())
138+
.makeCall(iter->second->name,
139+
curr->operands,
140+
curr->type,
141+
curr->isReturn));
149142
}
150143

151144
void visitRefFunc(RefFunc* curr) {
152-
auto iter = illegalImportsToLegal->find(curr->func);
153-
if (iter == illegalImportsToLegal->end()) {
145+
auto iter = illegalImportsToLegal.find(curr->func);
146+
if (iter == illegalImportsToLegal.end()) {
154147
return;
155148
}
156149

157-
curr->func = iter->second;
150+
curr->func = iter->second->name;
151+
// TODO: Make this exact.
152+
curr->type = Type(iter->second->type, NonNullable);
158153
}
159154
};
160155

161-
Fixer fixer(&illegalImportsToLegal);
156+
Fixer fixer(illegalImportsToLegal);
162157
fixer.run(getPassRunner(), module);
163158
fixer.runOnModuleCode(getPassRunner(), module);
164159

@@ -174,7 +169,7 @@ struct LegalizeJSInterface : public Pass {
174169

175170
private:
176171
// map of illegal to legal names for imports
177-
std::map<Name, Name> illegalImportsToLegal;
172+
std::unordered_map<Name, Function*> illegalImportsToLegal;
178173
bool exportedHelpers = false;
179174
Function* getTempRet0 = nullptr;
180175
Function* setTempRet0 = nullptr;
@@ -274,7 +269,7 @@ struct LegalizeJSInterface : public Pass {
274269

275270
// wasm calls the import, so it must call a stub that calls the actual legal
276271
// JS import
277-
Name makeLegalStubForCalledImport(Function* im, Module* module) {
272+
Function* makeLegalStubForCalledImport(Function* im, Module* module) {
278273
Builder builder(*module);
279274
auto legalIm = std::make_unique<Function>();
280275
legalIm->name = Name(std::string("legalimport$") + im->name.toString());
@@ -315,14 +310,14 @@ struct LegalizeJSInterface : public Pass {
315310
}
316311
legalIm->type = Signature(Type(params), call->type);
317312

318-
const auto& stubName = stub->name;
319-
if (!module->getFunctionOrNull(stubName)) {
313+
auto* stubPtr = stub.get();
314+
if (!module->getFunctionOrNull(stub->name)) {
320315
module->addFunction(std::move(stub));
321316
}
322317
if (!module->getFunctionOrNull(legalIm->name)) {
323318
module->addFunction(std::move(legalIm));
324319
}
325-
return stubName;
320+
return stubPtr;
326321
}
327322

328323
static Function*

src/passes/PrintCallGraph.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ struct PrintCallGraph : public Pass {
9696
CallPrinter printer(module);
9797

9898
// Indirect Targets
99-
ElementUtils::iterAllElementFunctionNames(module, [&](Name& name) {
99+
ElementUtils::iterAllElementFunctionNames(module, [&](Name name) {
100100
auto* func = module->getFunction(name);
101101
o << " \"" << func->name << "\" [style=\"filled, rounded\"];\n";
102102
});

src/passes/RemoveImports.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ struct RemoveImports : public WalkerPass<PostWalker<RemoveImports>> {
5151
// Do not remove names referenced in a table
5252
std::set<Name> indirectNames;
5353
ElementUtils::iterAllElementFunctionNames(
54-
curr, [&](Name& name) { indirectNames.insert(name); });
54+
curr, [&](Name name) { indirectNames.insert(name); });
5555
for (auto& name : names) {
5656
if (indirectNames.find(name) == indirectNames.end()) {
5757
curr->removeFunction(name);

0 commit comments

Comments
 (0)