Skip to content

Commit 91683d9

Browse files
committed
Limit loop unrolling for argument and return values.
1 parent bea7880 commit 91683d9

File tree

1 file changed

+81
-49
lines changed

1 file changed

+81
-49
lines changed

wasm/src/org.graalvm.wasm/src/org/graalvm/wasm/api/InteropCallAdapterNode.java

Lines changed: 81 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353

5454
import com.oracle.truffle.api.CallTarget;
5555
import com.oracle.truffle.api.CompilerAsserts;
56+
import com.oracle.truffle.api.CompilerDirectives;
5657
import com.oracle.truffle.api.frame.VirtualFrame;
5758
import com.oracle.truffle.api.interop.ArityException;
5859
import com.oracle.truffle.api.interop.UnsupportedTypeException;
@@ -71,6 +72,9 @@
7172
* the call target to be reused by different functions of the same type (equivalence class).
7273
*/
7374
public final class InteropCallAdapterNode extends RootNode {
75+
76+
private static final int MAX_UNROLL = 32;
77+
7478
private final SymbolTable.FunctionType functionType;
7579
private final BranchProfile errorBranch = BranchProfile.create();
7680
@Child private WasmIndirectCallNode callNode;
@@ -105,59 +109,72 @@ public Object execute(VirtualFrame frame) {
105109
}
106110
}
107111

108-
@ExplodeLoop
109112
private Object[] validateArguments(Object[] arguments, int offset) throws ArityException, UnsupportedTypeException {
110113
final byte[] paramTypes = functionType.paramTypes();
111114
final int paramCount = paramTypes.length;
112115
CompilerAsserts.partialEvaluationConstant(paramCount);
113116
if (arguments.length - offset != paramCount) {
114117
throw ArityException.create(paramCount, paramCount, arguments.length - offset);
115118
}
119+
if (CompilerDirectives.inCompiledCode() && paramCount <= MAX_UNROLL) {
120+
validateArgumentsUnroll(arguments, offset, paramTypes, paramCount);
121+
} else {
122+
for (int i = 0; i < paramCount; i++) {
123+
validateArgument(arguments, offset, paramTypes, i);
124+
}
125+
}
126+
return arguments;
127+
}
128+
129+
@ExplodeLoop
130+
private static void validateArgumentsUnroll(Object[] arguments, int offset, byte[] paramTypes, int paramCount) throws UnsupportedTypeException {
116131
for (int i = 0; i < paramCount; i++) {
117-
byte paramType = paramTypes[i];
118-
Object value = arguments[i + offset];
119-
switch (paramType) {
120-
case WasmType.I32_TYPE -> {
121-
if (value instanceof Integer) {
122-
continue;
123-
}
124-
}
125-
case WasmType.I64_TYPE -> {
126-
if (value instanceof Long) {
127-
continue;
128-
}
132+
validateArgument(arguments, offset, paramTypes, i);
133+
}
134+
}
135+
136+
private static void validateArgument(Object[] arguments, int offset, byte[] paramTypes, int i) throws UnsupportedTypeException {
137+
byte paramType = paramTypes[i];
138+
Object value = arguments[i + offset];
139+
switch (paramType) {
140+
case WasmType.I32_TYPE -> {
141+
if (value instanceof Integer) {
142+
return;
129143
}
130-
case WasmType.F32_TYPE -> {
131-
if (value instanceof Float) {
132-
continue;
133-
}
144+
}
145+
case WasmType.I64_TYPE -> {
146+
if (value instanceof Long) {
147+
return;
134148
}
135-
case WasmType.F64_TYPE -> {
136-
if (value instanceof Double) {
137-
continue;
138-
}
149+
}
150+
case WasmType.F32_TYPE -> {
151+
if (value instanceof Float) {
152+
return;
139153
}
140-
case WasmType.V128_TYPE -> {
141-
if (value instanceof Vector128) {
142-
continue;
143-
}
154+
}
155+
case WasmType.F64_TYPE -> {
156+
if (value instanceof Double) {
157+
return;
144158
}
145-
case WasmType.FUNCREF_TYPE -> {
146-
if (value instanceof WasmFunctionInstance || value == WasmConstant.NULL) {
147-
continue;
148-
}
159+
}
160+
case WasmType.V128_TYPE -> {
161+
if (value instanceof Vector128) {
162+
return;
149163
}
150-
case WasmType.EXTERNREF_TYPE -> {
151-
continue;
164+
}
165+
case WasmType.FUNCREF_TYPE -> {
166+
if (value instanceof WasmFunctionInstance || value == WasmConstant.NULL) {
167+
return;
152168
}
153-
default -> throw WasmException.create(Failure.UNKNOWN_TYPE);
154169
}
155-
throw UnsupportedTypeException.create(arguments);
170+
case WasmType.EXTERNREF_TYPE -> {
171+
return;
172+
}
173+
default -> throw WasmException.create(Failure.UNKNOWN_TYPE);
156174
}
157-
return arguments;
175+
throw UnsupportedTypeException.create(arguments);
158176
}
159177

160-
@ExplodeLoop
161178
private Object multiValueStackAsArray(WasmLanguage language) {
162179
final var multiValueStack = language.multiValueStack();
163180
final long[] primitiveMultiValueStack = multiValueStack.primitiveStack();
@@ -168,24 +185,39 @@ private Object multiValueStackAsArray(WasmLanguage language) {
168185
assert objectMultiValueStack.length >= resultCount;
169186
final Object[] values = new Object[resultCount];
170187
CompilerAsserts.partialEvaluationConstant(resultCount);
171-
for (int i = 0; i < resultCount; i++) {
172-
byte resultType = resultTypes[i];
173-
values[i] = switch (resultType) {
174-
case WasmType.I32_TYPE -> (int) primitiveMultiValueStack[i];
175-
case WasmType.I64_TYPE -> primitiveMultiValueStack[i];
176-
case WasmType.F32_TYPE -> Float.intBitsToFloat((int) primitiveMultiValueStack[i]);
177-
case WasmType.F64_TYPE -> Double.longBitsToDouble(primitiveMultiValueStack[i]);
178-
case WasmType.V128_TYPE, WasmType.FUNCREF_TYPE, WasmType.EXTERNREF_TYPE -> {
179-
Object obj = objectMultiValueStack[i];
180-
objectMultiValueStack[i] = null;
181-
yield obj;
182-
}
183-
default -> throw WasmException.create(Failure.UNSPECIFIED_INTERNAL);
184-
};
188+
if (CompilerDirectives.inCompiledCode() && resultCount <= MAX_UNROLL) {
189+
popMultiValueResultUnroll(values, primitiveMultiValueStack, objectMultiValueStack, resultTypes, resultCount);
190+
} else {
191+
for (int i = 0; i < resultCount; i++) {
192+
values[i] = popMultiValueResult(primitiveMultiValueStack, objectMultiValueStack, resultTypes, i);
193+
}
185194
}
186195
return InteropArray.create(values);
187196
}
188197

198+
@ExplodeLoop
199+
private static void popMultiValueResultUnroll(Object[] values, long[] primitiveMultiValueStack, Object[] objectMultiValueStack, byte[] resultTypes, int resultCount) {
200+
for (int i = 0; i < resultCount; i++) {
201+
values[i] = popMultiValueResult(primitiveMultiValueStack, objectMultiValueStack, resultTypes, i);
202+
}
203+
}
204+
205+
private static Object popMultiValueResult(long[] primitiveMultiValueStack, Object[] objectMultiValueStack, byte[] resultTypes, int i) {
206+
final byte resultType = resultTypes[i];
207+
return switch (resultType) {
208+
case WasmType.I32_TYPE -> (int) primitiveMultiValueStack[i];
209+
case WasmType.I64_TYPE -> primitiveMultiValueStack[i];
210+
case WasmType.F32_TYPE -> Float.intBitsToFloat((int) primitiveMultiValueStack[i]);
211+
case WasmType.F64_TYPE -> Double.longBitsToDouble(primitiveMultiValueStack[i]);
212+
case WasmType.V128_TYPE, WasmType.FUNCREF_TYPE, WasmType.EXTERNREF_TYPE -> {
213+
Object obj = objectMultiValueStack[i];
214+
objectMultiValueStack[i] = null;
215+
yield obj;
216+
}
217+
default -> throw WasmException.create(Failure.UNSPECIFIED_INTERNAL);
218+
};
219+
}
220+
189221
// TODO: Do we need the 3 overrides below?
190222
@Override
191223
public String getName() {

0 commit comments

Comments
 (0)