Skip to content

Commit 0701996

Browse files
committed
Added unit tests for wasm bindings
1 parent 5378272 commit 0701996

File tree

8 files changed

+271
-39
lines changed

8 files changed

+271
-39
lines changed

extension/wasm/CMakeLists.txt

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11

2-
cmake_minimum_required(VERSION 3.24) # 3.24 is required for WHOLE_ARCHIVE
2+
cmake_minimum_required(VERSION 3.24)
33

44
project(executorch_wasm)
55

@@ -33,15 +33,12 @@ list(
3333
extension_runner_util
3434
)
3535

36-
add_executable(executorch_wasm wasm_bindings.cpp)
36+
add_library(executorch_wasm OBJECT wasm_bindings.cpp)
3737

3838
target_compile_options(executorch_wasm PUBLIC ${_common_compile_options})
3939
target_include_directories(executorch_wasm PUBLIC ${_common_include_directories})
4040
target_link_libraries(executorch_wasm PUBLIC ${link_libraries})
41-
target_link_options(executorch_wasm PUBLIC -sALLOW_MEMORY_GROWTH=1 --embed-file=${CMAKE_CURRENT_SOURCE_DIR}/[email protected])
4241

43-
add_custom_target(executorch_wasm_test ALL
44-
COMMAND cp ${CMAKE_CURRENT_SOURCE_DIR}/test_module.js ${CMAKE_CURRENT_BINARY_DIR}/test_module.js
45-
DEPENDS executorch_wasm
46-
COMMENT "Copying test_module.js to build output directory"
47-
)
42+
if(BUILD_TESTING)
43+
add_subdirectory(test)
44+
endif()

extension/wasm/build_wasm.sh

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,8 @@ emcmake cmake -DEXECUTORCH_BUILD_WASM=ON \
88
-DEXECUTORCH_BUILD_DEVTOOLS=ON \
99
-DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \
1010
-DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \
11-
-DEXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL=ON ..
12-
make executorch_wasm_test -j32
11+
-DEXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL=ON \
12+
-DBUILD_TESTING=ON \
13+
-DCMAKE_BUILD_TYPE=Release \
14+
..
15+
make executorch_wasm_tests -j32

extension/wasm/test/CMakeLists.txt

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
2+
set(MODELS_DIR ${CMAKE_CURRENT_BINARY_DIR}/models/)
3+
4+
add_custom_command(
5+
OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/models/add_mul.pte ${CMAKE_CURRENT_BINARY_DIR}/models/add.pte
6+
COMMAND ${CMAKE_COMMAND} -E make_directory "${MODELS_DIR}"
7+
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/../../..
8+
COMMAND python3 -m examples.portable.scripts.export --model_name="add_mul" --output_dir="${MODELS_DIR}"
9+
COMMAND python3 -m examples.portable.scripts.export --model_name="add" --output_dir="${MODELS_DIR}"
10+
)
11+
12+
add_custom_target(executorch_wasm_test_models DEPENDS ${MODELS_DIR}/add_mul.pte ${MODELS_DIR}/add.pte)
13+
14+
add_executable(executorch_wasm_test_lib)
15+
target_link_libraries(executorch_wasm_test_lib PUBLIC executorch_wasm)
16+
target_link_options(executorch_wasm_test_lib PUBLIC --embed-file "${MODELS_DIR}@/")
17+
add_dependencies(executorch_wasm_test_lib executorch_wasm_test_models)
18+
19+
add_custom_command(
20+
OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/executorch_wasm.test.js
21+
COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_SOURCE_DIR}/executorch_wasm.test.js ${CMAKE_CURRENT_BINARY_DIR}/executorch_wasm.test.js
22+
DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/executorch_wasm.test.js
23+
COMMENT "Copying executorch_wasm.test.js to build output directory"
24+
)
25+
26+
add_custom_command(
27+
OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/package.json
28+
COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_SOURCE_DIR}/package.json ${CMAKE_CURRENT_BINARY_DIR}/package.json
29+
DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/package.json
30+
COMMENT "Copying package.json to build output directory"
31+
)
32+
33+
add_custom_target(executorch_wasm_tests DEPENDS executorch_wasm_test_lib ${CMAKE_CURRENT_BINARY_DIR}/executorch_wasm.test.js ${CMAKE_CURRENT_BINARY_DIR}/package.json)
Lines changed: 211 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,211 @@
1+
2+
let et;
3+
beforeAll((done) => {
4+
et = require("./executorch_wasm_test_lib");
5+
et.onRuntimeInitialized = () => {
6+
done();
7+
}
8+
});
9+
10+
describe("Tensor", () => {
11+
test("ones", () => {
12+
const tensor = et.FloatTensor.ones([2, 2]);
13+
expect(tensor.getData()).toEqual([1, 1, 1, 1]);
14+
expect(tensor.getSizes()).toEqual([2, 2]);
15+
tensor.delete();
16+
});
17+
18+
test("zeros", () => {
19+
const tensor = et.FloatTensor.zeros([2, 2]);
20+
expect(tensor.getData()).toEqual([0, 0, 0, 0]);
21+
expect(tensor.getSizes()).toEqual([2, 2]);
22+
tensor.delete();
23+
});
24+
25+
test("fromArray", () => {
26+
const tensor = et.FloatTensor.fromArray([1, 2, 3, 4], [2, 2]);
27+
expect(tensor.getData()).toEqual([1, 2, 3, 4]);
28+
expect(tensor.getSizes()).toEqual([2, 2]);
29+
tensor.delete();
30+
});
31+
32+
test("fromArray wrong size", () => {
33+
expect(() => et.FloatTensor.fromArray([1, 2, 3, 4], [3, 2])).toThrow();
34+
});
35+
36+
test("full", () => {
37+
const tensor = et.FloatTensor.full([2, 2], 3);
38+
expect(tensor.getData()).toEqual([3, 3, 3, 3]);
39+
expect(tensor.getSizes()).toEqual([2, 2]);
40+
tensor.delete();
41+
});
42+
});
43+
44+
describe("Module", () => {
45+
test("getMethods has foward", () => {
46+
const module = et.Module.load("add.pte");
47+
const methods = module.getMethods();
48+
expect(methods).toEqual(["forward"]);
49+
module.delete();
50+
});
51+
52+
test("loadMethod forward", () => {
53+
const module = et.Module.load("add.pte");
54+
expect(() => module.loadMethod("forward")).not.toThrow();
55+
module.delete();
56+
});
57+
58+
test("loadMethod does not exist", () => {
59+
const module = et.Module.load("add.pte");
60+
expect(() => module.loadMethod("does_not_exist")).toThrow();
61+
module.delete();
62+
});
63+
64+
describe("MethodMeta", () => {
65+
test("name is forward", () => {
66+
const module = et.Module.load("add_mul.pte");
67+
const methodMeta = module.getMethodMeta("forward");
68+
expect(methodMeta.name).toEqual("forward");
69+
methodMeta.delete();
70+
module.delete();
71+
});
72+
73+
test("numInputs is 3", () => {
74+
const module = et.Module.load("add_mul.pte");
75+
const methodMeta = module.getMethodMeta("forward");
76+
expect(methodMeta.numInputs).toEqual(3);
77+
methodMeta.delete();
78+
module.delete();
79+
});
80+
81+
test("method does not exist", () => {
82+
const module = et.Module.load("add_mul.pte");
83+
expect(() => module.getMethodMeta("does_not_exist")).toThrow();
84+
module.delete();
85+
});
86+
87+
describe("TensorInfo", () => {
88+
test("sizes is 2x2", () => {
89+
const module = et.Module.load("add_mul.pte");
90+
const methodMeta = module.getMethodMeta("forward");
91+
for (var i = 0; i < methodMeta.numInputs; i++) {
92+
const tensorInfo = methodMeta.inputTensorMeta(i);
93+
expect(tensorInfo.sizes).toEqual([2, 2]);
94+
tensorInfo.delete();
95+
}
96+
methodMeta.delete();
97+
module.delete();
98+
});
99+
100+
test("out of range", () => {
101+
const module = et.Module.load("add_mul.pte");
102+
const methodMeta = module.getMethodMeta("forward");
103+
expect(() => methodMeta.inputTensorMeta(3)).toThrow();
104+
methodMeta.delete();
105+
module.delete();
106+
});
107+
});
108+
});
109+
110+
describe("execute", () => {
111+
test("add normally", () => {
112+
const module = et.Module.load("add.pte");
113+
const inputs = [et.FloatTensor.ones([1]), et.FloatTensor.ones([1])];
114+
const output = module.execute("forward", inputs);
115+
116+
expect(output.length).toEqual(1);
117+
expect(output[0].getData()).toEqual([2]);
118+
expect(output[0].getSizes()).toEqual([1]);
119+
120+
inputs.forEach((input) => input.delete());
121+
output.forEach((output) => output.delete());
122+
module.delete();
123+
});
124+
125+
test("add_mul normally", () => {
126+
const module = et.Module.load("add_mul.pte");
127+
const inputs = [et.FloatTensor.ones([2, 2]), et.FloatTensor.ones([2, 2]), et.FloatTensor.ones([2, 2])];
128+
const output = module.execute("forward", inputs);
129+
130+
expect(output.length).toEqual(1);
131+
expect(output[0].getData()).toEqual([3, 3, 3, 3]);
132+
expect(output[0].getSizes()).toEqual([2, 2]);
133+
134+
inputs.forEach((input) => input.delete());
135+
output.forEach((output) => output.delete());
136+
module.delete();
137+
});
138+
139+
test("forward directly", () => {
140+
const module = et.Module.load("add_mul.pte");
141+
const inputs = [et.FloatTensor.ones([2, 2]), et.FloatTensor.ones([2, 2]), et.FloatTensor.ones([2, 2])];
142+
const output = module.forward(inputs);
143+
144+
expect(output.length).toEqual(1);
145+
expect(output[0].getData()).toEqual([3, 3, 3, 3]);
146+
expect(output[0].getSizes()).toEqual([2, 2]);
147+
148+
inputs.forEach((input) => input.delete());
149+
output.forEach((output) => output.delete());
150+
module.delete();
151+
});
152+
153+
test("wrong number of inputs", () => {
154+
const module = et.Module.load("add_mul.pte");
155+
const inputs = [et.FloatTensor.ones([2, 2]), et.FloatTensor.ones([2, 2])];
156+
expect(() => module.execute("forward", inputs)).toThrow();
157+
158+
inputs.forEach((input) => input.delete());
159+
module.delete();
160+
});
161+
162+
test("wrong input size", () => {
163+
const module = et.Module.load("add.pte");
164+
const inputs = [et.FloatTensor.ones([2, 1]), et.FloatTensor.ones([2, 1])];
165+
expect(() => module.execute("forward", inputs)).toThrow();
166+
167+
inputs.forEach((input) => input.delete());
168+
module.delete();
169+
});
170+
171+
test("wrong input type", () => {
172+
const module = et.Module.load("add.pte");
173+
const inputs = [et.FloatTensor.ones([1]), et.IntTensor.ones([1])];
174+
expect(() => module.execute("forward", inputs)).toThrow();
175+
176+
inputs.forEach((input) => input.delete());
177+
module.delete();
178+
});
179+
180+
test("method does not exist", () => {
181+
const module = et.Module.load("add.pte");
182+
const inputs = [et.FloatTensor.ones([1]), et.FloatTensor.ones([1])];
183+
expect(() => module.execute("does_not_exist", inputs)).toThrow();
184+
185+
inputs.forEach((input) => input.delete());
186+
module.delete();
187+
});
188+
189+
test("output tensor can be reused", () => {
190+
const module = et.Module.load("add_mul.pte");
191+
const inputs = [et.FloatTensor.ones([2, 2]), et.FloatTensor.ones([2, 2]), et.FloatTensor.ones([2, 2])];
192+
const output = module.forward(inputs);
193+
194+
expect(output.length).toEqual(1);
195+
expect(output[0].getData()).toEqual([3, 3, 3, 3]);
196+
expect(output[0].getSizes()).toEqual([2, 2]);
197+
198+
const inputs2 = [output[0], output[0], output[0]];
199+
const output2 = module.forward(inputs2);
200+
201+
expect(output2.length).toEqual(1);
202+
expect(output2[0].getData()).toEqual([21, 21, 21, 21]);
203+
expect(output2[0].getSizes()).toEqual([2, 2]);
204+
205+
inputs.forEach((input) => input.delete());
206+
output.forEach((output) => output.delete());
207+
output2.forEach((output) => output.delete());
208+
module.delete();
209+
});
210+
});
211+
});

extension/wasm/test/package.json

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
{
2+
"scripts": {
3+
"test": "jest"
4+
}
5+
}

extension/wasm/test_module.js

Lines changed: 0 additions & 27 deletions
This file was deleted.

extension/wasm/wasm_bindings.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
} \
2121
})
2222

23-
/// Throws a JavaScript Error with the provided message if `error` is not `Ok`.
23+
/// Throws a JavaScript Error with the provided message if `cond` is not `true`.
2424
#define THROW_IF_FALSE(cond, message, ...) \
2525
({ \
2626
if ET_UNLIKELY (!(cond)) { \
@@ -62,6 +62,7 @@ inline void assert_valid_numel(
6262
data.size());
6363
}
6464

65+
// Base class for all JS Tensor types. Subclasses are not exposed to JS.
6566
class JsBaseTensor {
6667
public:
6768
virtual ~JsBaseTensor() = default;
@@ -86,6 +87,7 @@ class JsBaseTensor {
8687
}
8788
};
8889

90+
// Tensor that owns its own data. JS only has access to the static methods.
8991
template <typename T, aten::ScalarType S>
9092
class JsTensor final : public JsBaseTensor {
9193
public:
@@ -154,9 +156,10 @@ class JsTensor final : public JsBaseTensor {
154156

155157
#define JS_DECLARE_TENSOR_TYPE(T, NAME) \
156158
using Js##NAME##Tensor = JsTensor<T, aten::ScalarType::NAME>;
157-
158159
JS_FORALL_SUPPORTED_TENSOR_TYPES(JS_DECLARE_TENSOR_TYPE)
159160

161+
// Tensor that does not own its own data. It is a wrapper around a C++ Tensor.
162+
// This class is not exposed to JS.
160163
class JsOutputTensor final : public JsBaseTensor {
161164
public:
162165
JsOutputTensor() = delete;
@@ -180,6 +183,7 @@ class JsOutputTensor final : public JsBaseTensor {
180183
std::unique_ptr<Tensor> tensor_;
181184
};
182185

186+
// Converts JS value to EValue.
183187
EValue to_evalue(val v) {
184188
if (v.isNull()) {
185189
return EValue();
@@ -200,6 +204,7 @@ EValue to_evalue(val v) {
200204
}
201205
}
202206

207+
// Converts EValue to JS value.
203208
val to_val(EValue v) {
204209
if (v.isNone()) {
205210
return val::null();
@@ -221,6 +226,7 @@ val to_val(EValue v) {
221226
}
222227
}
223228

229+
// Wrapper around TensorInfo.
224230
class JsTensorInfo final {
225231
public:
226232
JsTensorInfo() = delete;
@@ -241,6 +247,7 @@ class JsTensorInfo final {
241247
std::unique_ptr<TensorInfo> tensor_info_;
242248
};
243249

250+
// Wrapper around MethodMeta.
244251
class JsMethodMeta final {
245252
public:
246253
JsMethodMeta() = delete;
@@ -275,6 +282,7 @@ class JsMethodMeta final {
275282
std::unique_ptr<MethodMeta> meta_;
276283
};
277284

285+
// Wrapper around extension/Module.
278286
class JsModule final {
279287
public:
280288
JsModule() = delete;

0 commit comments

Comments
 (0)