Skip to content

Commit bea7880

Browse files
committed
Implement wasm function interop call adapter for argument and return value handling.
Fixes a compilation error in WasmFunctionInstance.execute due to WasmLanguage not being PE-constant when accessing the multi-value result stack.
1 parent 682668c commit bea7880

File tree

6 files changed

+259
-105
lines changed

6 files changed

+259
-105
lines changed

wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/Linker.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -469,7 +469,8 @@ void resolveFunctionImport(WasmContext context, WasmInstance instance, WasmFunct
469469
}
470470

471471
void resolveFunctionExport(WasmModule module, int functionIndex, String exportedFunctionName) {
472-
final ImportDescriptor importDescriptor = module.symbolTable().function(functionIndex).importDescriptor();
472+
final WasmFunction function = module.symbolTable().function(functionIndex);
473+
final ImportDescriptor importDescriptor = function.importDescriptor();
473474
final Sym[] dependencies = (importDescriptor != null) ? new Sym[]{new ImportFunctionSym(module.name(), importDescriptor, functionIndex)} : ResolutionDag.NO_DEPENDENCIES;
474475
resolutionDag.resolveLater(new ExportFunctionSym(module.name(), exportedFunctionName), dependencies, NO_RESOLVE_ACTION);
475476
}

wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/SymbolTable.java

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -86,9 +86,9 @@ public abstract class SymbolTable {
8686
private static final int NO_EQUIVALENCE_CLASS = 0;
8787
static final int FIRST_EQUIVALENCE_CLASS = NO_EQUIVALENCE_CLASS + 1;
8888

89-
public static class FunctionType {
90-
private final byte[] paramTypes;
91-
private final byte[] resultTypes;
89+
public static final class FunctionType {
90+
@CompilationFinal(dimensions = 1) private final byte[] paramTypes;
91+
@CompilationFinal(dimensions = 1) private final byte[] resultTypes;
9292
private final int hashCode;
9393

9494
FunctionType(byte[] paramTypes, byte[] resultTypes) {
@@ -112,10 +112,9 @@ public int hashCode() {
112112

113113
@Override
114114
public boolean equals(Object object) {
115-
if (!(object instanceof FunctionType)) {
115+
if (!(object instanceof FunctionType that)) {
116116
return false;
117117
}
118-
FunctionType that = (FunctionType) object;
119118
if (this.paramTypes.length != that.paramTypes.length) {
120119
return false;
121120
}
@@ -146,7 +145,7 @@ public String toString() {
146145
for (int i = 0; i < resultTypes.length; i++) {
147146
resultNames[i] = WasmType.toString(resultTypes[i]);
148147
}
149-
return Arrays.toString(paramNames) + " -> " + Arrays.toString(resultNames);
148+
return "(" + String.join(" ", paramNames) + ")->(" + String.join(" ", resultNames) + ")";
150149
}
151150
}
152151

wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmFunction.java

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2019, 2024, Oracle and/or its affiliates. All rights reserved.
2+
* Copyright (c) 2019, 2025, Oracle and/or its affiliates. All rights reserved.
33
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
44
*
55
* The Universal Permissive License (UPL), Version 1.0
@@ -128,6 +128,14 @@ public String importedFunctionName() {
128128
return isImported() ? importDescriptor.memberName() : null;
129129
}
130130

131+
public String exportedFunctionName() {
132+
return symbolTable.exportedFunctionName(index);
133+
}
134+
135+
public boolean isExported() {
136+
return exportedFunctionName() != null;
137+
}
138+
131139
public int typeIndex() {
132140
return typeIndex;
133141
}

wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmFunctionInstance.java

Lines changed: 16 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -42,25 +42,22 @@
4242

4343
import java.util.Objects;
4444

45-
import org.graalvm.wasm.api.InteropArray;
46-
import org.graalvm.wasm.api.Vector128;
47-
import org.graalvm.wasm.exception.Failure;
48-
import org.graalvm.wasm.exception.WasmException;
45+
import org.graalvm.wasm.nodes.WasmCallNode;
4946
import org.graalvm.wasm.nodes.WasmIndirectCallNode;
5047

5148
import com.oracle.truffle.api.CallTarget;
5249
import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
5350
import com.oracle.truffle.api.RootCallTarget;
5451
import com.oracle.truffle.api.TruffleContext;
5552
import com.oracle.truffle.api.dsl.Cached;
56-
import com.oracle.truffle.api.interop.ArityException;
53+
import com.oracle.truffle.api.dsl.ImportStatic;
5754
import com.oracle.truffle.api.interop.InteropLibrary;
5855
import com.oracle.truffle.api.interop.TruffleObject;
59-
import com.oracle.truffle.api.interop.UnsupportedTypeException;
6056
import com.oracle.truffle.api.library.CachedLibrary;
6157
import com.oracle.truffle.api.library.ExportLibrary;
6258
import com.oracle.truffle.api.library.ExportMessage;
6359

60+
@ImportStatic(WasmCallNode.class)
6461
@ExportLibrary(InteropLibrary.class)
6562
public final class WasmFunctionInstance extends EmbedderDataHolder implements TruffleObject {
6663

@@ -69,7 +66,15 @@ public final class WasmFunctionInstance extends EmbedderDataHolder implements Tr
6966
private final WasmFunction function;
7067
private final CallTarget target;
7168
private final TruffleContext truffleContext;
69+
/**
70+
* Stores the imported function object for {@link org.graalvm.wasm.api.ExecuteHostFunctionNode}.
71+
* Initialized during linking.
72+
*/
7273
private Object importedFunction;
74+
/**
75+
* Interop call adapter for exported functions, converting parameter and result values.
76+
*/
77+
private final CallTarget interopCallAdapter;
7378

7479
/**
7580
* Represents a call target that is a WebAssembly function or an imported function.
@@ -84,6 +89,7 @@ public WasmFunctionInstance(WasmContext context, WasmInstance moduleInstance, Wa
8489
this.function = Objects.requireNonNull(function, "function must be non-null");
8590
this.target = Objects.requireNonNull(target, "Call target must be non-null");
8691
this.truffleContext = context.environment().getContext();
92+
this.interopCallAdapter = context.language().interopCallAdapterFor(function.type());
8793
assert ((RootCallTarget) target).getRootNode().getLanguage(WasmLanguage.class) == context.language();
8894
}
8995

@@ -137,102 +143,15 @@ boolean isExecutable() {
137143
@ExportMessage
138144
Object execute(Object[] arguments,
139145
@CachedLibrary("this") InteropLibrary self,
140-
@Cached WasmIndirectCallNode callNode) throws ArityException, UnsupportedTypeException {
146+
@Cached WasmIndirectCallNode callNode) {
141147
TruffleContext c = getTruffleContext();
142148
Object prev = c.enter(self);
143149
try {
144-
Object result = callNode.execute(target, WasmArguments.create(moduleInstance, validateArguments(arguments)));
145-
146-
// For external calls of a WebAssembly function we have to materialize the multi-value
147-
// stack.
148-
// At this point the multi-value stack has already been populated, therefore, we don't
149-
// have to check the size of the multi-value stack.
150-
if (result == WasmConstant.MULTI_VALUE) {
151-
WasmLanguage language = context.language();
152-
assert language == WasmLanguage.get(null);
153-
return multiValueStackAsArray(language);
154-
}
155-
return result;
150+
CallTarget callAdapter = Objects.requireNonNull(this.interopCallAdapter);
151+
return callNode.execute(callAdapter, WasmArguments.create(this, arguments));
152+
// throws ArityException, UnsupportedTypeException
156153
} finally {
157154
c.leave(self, prev);
158155
}
159156
}
160-
161-
private Object[] validateArguments(Object[] arguments) throws ArityException, UnsupportedTypeException {
162-
if (function == null) {
163-
return arguments;
164-
}
165-
final int paramCount = function.paramCount();
166-
if (arguments.length != paramCount) {
167-
throw ArityException.create(paramCount, paramCount, arguments.length);
168-
}
169-
for (int i = 0; i < paramCount; i++) {
170-
byte paramType = function.paramTypeAt(i);
171-
Object value = arguments[i];
172-
switch (paramType) {
173-
case WasmType.I32_TYPE -> {
174-
if (value instanceof Integer) {
175-
continue;
176-
}
177-
}
178-
case WasmType.I64_TYPE -> {
179-
if (value instanceof Long) {
180-
continue;
181-
}
182-
}
183-
case WasmType.F32_TYPE -> {
184-
if (value instanceof Float) {
185-
continue;
186-
}
187-
}
188-
case WasmType.F64_TYPE -> {
189-
if (value instanceof Double) {
190-
continue;
191-
}
192-
}
193-
case WasmType.V128_TYPE -> {
194-
if (value instanceof Vector128) {
195-
continue;
196-
}
197-
}
198-
case WasmType.FUNCREF_TYPE -> {
199-
if (value instanceof WasmFunctionInstance || value == WasmConstant.NULL) {
200-
continue;
201-
}
202-
}
203-
case WasmType.EXTERNREF_TYPE -> {
204-
continue;
205-
}
206-
default -> throw WasmException.create(Failure.UNKNOWN_TYPE);
207-
}
208-
throw UnsupportedTypeException.create(arguments);
209-
}
210-
return arguments;
211-
}
212-
213-
private Object multiValueStackAsArray(WasmLanguage language) {
214-
final var multiValueStack = language.multiValueStack();
215-
final long[] primitiveMultiValueStack = multiValueStack.primitiveStack();
216-
final Object[] objectMultiValueStack = multiValueStack.objectStack();
217-
final int resultCount = function.resultCount();
218-
assert primitiveMultiValueStack.length >= resultCount;
219-
assert objectMultiValueStack.length >= resultCount;
220-
final Object[] values = new Object[resultCount];
221-
for (int i = 0; i < resultCount; i++) {
222-
byte resultType = function.resultTypeAt(i);
223-
values[i] = switch (resultType) {
224-
case WasmType.I32_TYPE -> (int) primitiveMultiValueStack[i];
225-
case WasmType.I64_TYPE -> primitiveMultiValueStack[i];
226-
case WasmType.F32_TYPE -> Float.intBitsToFloat((int) primitiveMultiValueStack[i]);
227-
case WasmType.F64_TYPE -> Double.longBitsToDouble(primitiveMultiValueStack[i]);
228-
case WasmType.V128_TYPE, WasmType.FUNCREF_TYPE, WasmType.EXTERNREF_TYPE -> {
229-
Object obj = objectMultiValueStack[i];
230-
objectMultiValueStack[i] = null;
231-
yield obj;
232-
}
233-
default -> throw WasmException.create(Failure.UNSPECIFIED_INTERNAL);
234-
};
235-
}
236-
return InteropArray.create(values);
237-
}
238157
}

wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/WasmLanguage.java

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
import org.graalvm.options.OptionDescriptors;
4949
import org.graalvm.options.OptionValues;
5050
import org.graalvm.polyglot.SandboxPolicy;
51+
import org.graalvm.wasm.api.InteropCallAdapterNode;
5152
import org.graalvm.wasm.api.JsConstants;
5253
import org.graalvm.wasm.api.WebAssembly;
5354
import org.graalvm.wasm.exception.WasmJsApiException;
@@ -56,6 +57,7 @@
5657
import org.graalvm.wasm.predefined.BuiltinModule;
5758

5859
import com.oracle.truffle.api.CallTarget;
60+
import com.oracle.truffle.api.CompilerAsserts;
5961
import com.oracle.truffle.api.CompilerDirectives.CompilationFinal;
6062
import com.oracle.truffle.api.ContextThreadLocal;
6163
import com.oracle.truffle.api.RootCallTarget;
@@ -96,8 +98,10 @@ public final class WasmLanguage extends TruffleLanguage<WasmContext> {
9698

9799
private final Map<SymbolTable.FunctionType, Integer> equivalenceClasses = new ConcurrentHashMap<>();
98100
private int nextEquivalenceClass = SymbolTable.FIRST_EQUIVALENCE_CLASS;
101+
private final Map<SymbolTable.FunctionType, CallTarget> interopCallAdapters = new ConcurrentHashMap<>();
99102

100103
public int equivalenceClassFor(SymbolTable.FunctionType type) {
104+
CompilerAsserts.neverPartOfCompilation();
101105
Integer equivalenceClass = equivalenceClasses.get(type);
102106
if (equivalenceClass == null) {
103107
synchronized (this) {
@@ -112,6 +116,20 @@ public int equivalenceClassFor(SymbolTable.FunctionType type) {
112116
return equivalenceClass;
113117
}
114118

119+
/**
120+
* Gets or creates the interop call adapter for a function type. Always returns the same call
121+
* target for any particular type.
122+
*/
123+
public CallTarget interopCallAdapterFor(SymbolTable.FunctionType type) {
124+
CompilerAsserts.neverPartOfCompilation();
125+
CallTarget callAdapter = interopCallAdapters.get(type);
126+
if (callAdapter == null) {
127+
callAdapter = interopCallAdapters.computeIfAbsent(type,
128+
k -> new InteropCallAdapterNode(this, k).getCallTarget());
129+
}
130+
return callAdapter;
131+
}
132+
115133
@Override
116134
protected WasmContext createContext(Env env) {
117135
WasmContext context = new WasmContext(env, this);
@@ -249,6 +267,11 @@ protected boolean areOptionsCompatible(OptionValues firstOptions, OptionValues n
249267
}
250268
}
251269

270+
@SuppressWarnings("unchecked")
271+
public static <E extends Throwable> RuntimeException rethrow(Throwable ex) throws E {
272+
throw (E) ex;
273+
}
274+
252275
public MultiValueStack multiValueStack() {
253276
return multiValueStackThreadLocal.get();
254277
}

0 commit comments

Comments
 (0)