Skip to content

Commit 71e681d

Browse files
committed
Fix methodMeta error
1 parent d6b462d commit 71e681d

File tree

4 files changed

+88
-10
lines changed

4 files changed

+88
-10
lines changed

extension/wasm/test/CMakeLists.txt

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,13 @@
1111

1212
set(MODELS_DIR ${CMAKE_CURRENT_BINARY_DIR}/models/)
1313

14+
add_custom_command(
15+
OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/models/test.pte
16+
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/../../..
17+
COMMAND python3 -m extension.wasm.test.test_model
18+
${CMAKE_CURRENT_BINARY_DIR}/models/test.pte
19+
)
20+
1421
add_custom_command(
1522
OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/models/add_mul.pte
1623
${CMAKE_CURRENT_BINARY_DIR}/models/add.pte
@@ -23,8 +30,9 @@ add_custom_command(
2330
)
2431

2532
add_custom_target(
26-
executorch_wasm_test_models DEPENDS ${MODELS_DIR}/add_mul.pte
27-
${MODELS_DIR}/add.pte
33+
executorch_wasm_test_models
34+
DEPENDS ${MODELS_DIR}/add_mul.pte ${MODELS_DIR}/add.pte
35+
${MODELS_DIR}/test.pte
2836
)
2937

3038
add_custom_command(

extension/wasm/test/test_model.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import sys
2+
3+
import torch
4+
from executorch.exir import to_edge_transform_and_lower
5+
from torch.export import export
6+
7+
8+
class IndexModel(torch.nn.Module):
9+
def forward(self, x, n):
10+
return x[n]
11+
12+
13+
class AddAllModel(torch.nn.Module):
14+
def forward(self, x, n):
15+
return x, n, x + n
16+
17+
18+
if __name__ == "__main__":
19+
output_filepath = sys.argv[1] if len(sys.argv) > 1 else "test.pte"
20+
indexModel = IndexModel().eval()
21+
addAllModel = AddAllModel().eval()
22+
23+
exported_index = export(indexModel, (torch.randn([3]), 1))
24+
exported_add_all = export(addAllModel, (torch.randn([2, 2]), 1))
25+
edge = to_edge_transform_and_lower(
26+
{
27+
"forward": exported_index,
28+
"index": exported_index,
29+
"add_all": exported_add_all,
30+
}
31+
)
32+
et = edge.to_executorch()
33+
with open(output_filepath, "wb") as file:
34+
file.write(et.buffer)

extension/wasm/test/unittests.js

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,13 @@ describe("Module", () => {
105105
module.delete();
106106
});
107107

108+
test("multiple methods", () => {
109+
const module = et.Module.load("test.pte");
110+
const methods = module.getMethods();
111+
expect(methods).toEqual(expect.arrayContaining(["forward", "index", "add_all"]));
112+
module.delete();
113+
});
114+
108115
test("loadMethod forward", () => {
109116
const module = et.Module.load("add.pte");
110117
expect(() => module.loadMethod("forward")).not.toThrow();
@@ -224,6 +231,25 @@ describe("Module", () => {
224231
});
225232
module.delete();
226233
});
234+
235+
test("non-tensor in input", () => {
236+
const module = et.Module.load("test.pte");
237+
const methodMeta = module.getMethodMeta("add_all");
238+
expect(methodMeta.inputTags).toEqual([et.Tag.Tensor, et.Tag.Int]);
239+
expect(methodMeta.inputTensorMeta[0]).not.toBeUndefined();
240+
expect(methodMeta.inputTensorMeta[1]).toBeUndefined();
241+
module.delete();
242+
});
243+
244+
test("non-tensor in output", () => {
245+
const module = et.Module.load("test.pte");
246+
const methodMeta = module.getMethodMeta("add_all");
247+
expect(methodMeta.outputTags).toEqual([et.Tag.Tensor, et.Tag.Int, et.Tag.Tensor]);
248+
expect(methodMeta.outputTensorMeta[0]).not.toBeUndefined();
249+
expect(methodMeta.outputTensorMeta[1]).toBeUndefined();
250+
expect(methodMeta.outputTensorMeta[2]).not.toBeUndefined();
251+
module.delete();
252+
});
227253
});
228254
});
229255

extension/wasm/wasm_bindings.cpp

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -475,16 +475,26 @@ struct ET_EXPERIMENTAL JsMethodMeta {
475475
val::array(),
476476
meta.num_instructions()};
477477
for (int i = 0; i < meta.num_inputs(); i++) {
478-
js_array_push(new_meta.input_tags, meta.input_tag(i).get());
479-
js_array_push(
480-
new_meta.input_tensor_meta,
481-
JsTensorInfo::from_tensor_info(meta.input_tensor_meta(i).get()));
478+
Tag tag = meta.input_tag(i).get();
479+
js_array_push(new_meta.input_tags, tag);
480+
if (tag == Tag::Tensor) {
481+
js_array_push(
482+
new_meta.input_tensor_meta,
483+
JsTensorInfo::from_tensor_info(meta.input_tensor_meta(i).get()));
484+
} else {
485+
js_array_push(new_meta.input_tensor_meta, val::undefined());
486+
}
482487
}
483488
for (int i = 0; i < meta.num_outputs(); i++) {
484-
js_array_push(new_meta.output_tags, meta.output_tag(i).get());
485-
js_array_push(
486-
new_meta.output_tensor_meta,
487-
JsTensorInfo::from_tensor_info(meta.output_tensor_meta(i).get()));
489+
Tag tag = meta.output_tag(i).get();
490+
js_array_push(new_meta.output_tags, tag);
491+
if (tag == Tag::Tensor) {
492+
js_array_push(
493+
new_meta.output_tensor_meta,
494+
JsTensorInfo::from_tensor_info(meta.output_tensor_meta(i).get()));
495+
} else {
496+
js_array_push(new_meta.output_tensor_meta, val::undefined());
497+
}
488498
}
489499
for (int i = 0; i < meta.num_attributes(); i++) {
490500
js_array_push(

0 commit comments

Comments
 (0)