Skip to content

Commit ff8be27

Browse files
authored
Fix method meta error in wasm build (#13496)
### Summary An error would occur when trying to get the metadata for a method containing an input or output that isn't a tensor. This occurs because there wasn't a check for the input/output tag when generating the list of tensor metadata. Now the list of tensor metadata will have `undefined` if the input/output in that index is not a tensor. ### Test plan Added unit tests to test the method metadata for a method that has ints in its input and outputs. Added unit test to check if getMethod works if the module has multiple methods. Unit tests are in the CI but can be ran with ``` bash scripts/build_wasm_tests.sh cd cmake-out-wasm/extension/wasm/test/ npm test # after installing Jest ```
1 parent 7a98dd1 commit ff8be27

File tree

4 files changed

+88
-10
lines changed

4 files changed

+88
-10
lines changed

extension/wasm/test/CMakeLists.txt

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,13 @@
1111

1212
set(MODELS_DIR ${CMAKE_CURRENT_BINARY_DIR}/models/)
1313

14+
add_custom_command(
15+
OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/models/test.pte
16+
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/../../..
17+
COMMAND python3 -m extension.wasm.test.test_model
18+
${CMAKE_CURRENT_BINARY_DIR}/models/test.pte
19+
)
20+
1421
add_custom_command(
1522
OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/models/add_mul.pte
1623
${CMAKE_CURRENT_BINARY_DIR}/models/add.pte
@@ -23,8 +30,9 @@ add_custom_command(
2330
)
2431

2532
add_custom_target(
26-
executorch_wasm_test_models DEPENDS ${MODELS_DIR}/add_mul.pte
27-
${MODELS_DIR}/add.pte
33+
executorch_wasm_test_models
34+
DEPENDS ${MODELS_DIR}/add_mul.pte ${MODELS_DIR}/add.pte
35+
${MODELS_DIR}/test.pte
2836
)
2937

3038
add_custom_command(

extension/wasm/test/test_model.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import sys
2+
3+
import torch
4+
from executorch.exir import to_edge_transform_and_lower
5+
from torch.export import export
6+
7+
8+
class IndexModel(torch.nn.Module):
9+
def forward(self, x, n):
10+
return x[n]
11+
12+
13+
class AddAllModel(torch.nn.Module):
14+
def forward(self, x, n):
15+
return x, n, x + n
16+
17+
18+
if __name__ == "__main__":
19+
output_filepath = sys.argv[1] if len(sys.argv) > 1 else "test.pte"
20+
indexModel = IndexModel().eval()
21+
addAllModel = AddAllModel().eval()
22+
23+
exported_index = export(indexModel, (torch.randn([3]), 1))
24+
exported_add_all = export(addAllModel, (torch.randn([2, 2]), 1))
25+
edge = to_edge_transform_and_lower(
26+
{
27+
"forward": exported_index,
28+
"index": exported_index,
29+
"add_all": exported_add_all,
30+
}
31+
)
32+
et = edge.to_executorch()
33+
with open(output_filepath, "wb") as file:
34+
file.write(et.buffer)

extension/wasm/test/unittests.js

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,13 @@ describe("Module", () => {
105105
module.delete();
106106
});
107107

108+
test("multiple methods", () => {
109+
const module = et.Module.load("test.pte");
110+
const methods = module.getMethods();
111+
expect(methods).toEqual(expect.arrayContaining(["forward", "index", "add_all"]));
112+
module.delete();
113+
});
114+
108115
test("loadMethod forward", () => {
109116
const module = et.Module.load("add.pte");
110117
expect(() => module.loadMethod("forward")).not.toThrow();
@@ -224,6 +231,25 @@ describe("Module", () => {
224231
});
225232
module.delete();
226233
});
234+
235+
test("non-tensor in input", () => {
236+
const module = et.Module.load("test.pte");
237+
const methodMeta = module.getMethodMeta("add_all");
238+
expect(methodMeta.inputTags).toEqual([et.Tag.Tensor, et.Tag.Int]);
239+
expect(methodMeta.inputTensorMeta[0]).not.toBeUndefined();
240+
expect(methodMeta.inputTensorMeta[1]).toBeUndefined();
241+
module.delete();
242+
});
243+
244+
test("non-tensor in output", () => {
245+
const module = et.Module.load("test.pte");
246+
const methodMeta = module.getMethodMeta("add_all");
247+
expect(methodMeta.outputTags).toEqual([et.Tag.Tensor, et.Tag.Int, et.Tag.Tensor]);
248+
expect(methodMeta.outputTensorMeta[0]).not.toBeUndefined();
249+
expect(methodMeta.outputTensorMeta[1]).toBeUndefined();
250+
expect(methodMeta.outputTensorMeta[2]).not.toBeUndefined();
251+
module.delete();
252+
});
227253
});
228254
});
229255

extension/wasm/wasm_bindings.cpp

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -475,16 +475,26 @@ struct ET_EXPERIMENTAL JsMethodMeta {
475475
val::array(),
476476
meta.num_instructions()};
477477
for (int i = 0; i < meta.num_inputs(); i++) {
478-
js_array_push(new_meta.input_tags, meta.input_tag(i).get());
479-
js_array_push(
480-
new_meta.input_tensor_meta,
481-
JsTensorInfo::from_tensor_info(meta.input_tensor_meta(i).get()));
478+
Tag tag = meta.input_tag(i).get();
479+
js_array_push(new_meta.input_tags, tag);
480+
if (tag == Tag::Tensor) {
481+
js_array_push(
482+
new_meta.input_tensor_meta,
483+
JsTensorInfo::from_tensor_info(meta.input_tensor_meta(i).get()));
484+
} else {
485+
js_array_push(new_meta.input_tensor_meta, val::undefined());
486+
}
482487
}
483488
for (int i = 0; i < meta.num_outputs(); i++) {
484-
js_array_push(new_meta.output_tags, meta.output_tag(i).get());
485-
js_array_push(
486-
new_meta.output_tensor_meta,
487-
JsTensorInfo::from_tensor_info(meta.output_tensor_meta(i).get()));
489+
Tag tag = meta.output_tag(i).get();
490+
js_array_push(new_meta.output_tags, tag);
491+
if (tag == Tag::Tensor) {
492+
js_array_push(
493+
new_meta.output_tensor_meta,
494+
JsTensorInfo::from_tensor_info(meta.output_tensor_meta(i).get()));
495+
} else {
496+
js_array_push(new_meta.output_tensor_meta, val::undefined());
497+
}
488498
}
489499
for (int i = 0; i < meta.num_attributes(); i++) {
490500
js_array_push(

0 commit comments

Comments
 (0)