diff --git a/extension/wasm/test/CMakeLists.txt b/extension/wasm/test/CMakeLists.txt index fad2ab038cb..24e43500cbe 100644 --- a/extension/wasm/test/CMakeLists.txt +++ b/extension/wasm/test/CMakeLists.txt @@ -11,6 +11,13 @@ set(MODELS_DIR ${CMAKE_CURRENT_BINARY_DIR}/models/) +add_custom_command( + OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/models/test.pte + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/../../.. + COMMAND python3 -m extension.wasm.test.test_model + ${CMAKE_CURRENT_BINARY_DIR}/models/test.pte +) + add_custom_command( OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/models/add_mul.pte ${CMAKE_CURRENT_BINARY_DIR}/models/add.pte @@ -23,8 +30,9 @@ add_custom_command( ) add_custom_target( - executorch_wasm_test_models DEPENDS ${MODELS_DIR}/add_mul.pte - ${MODELS_DIR}/add.pte + executorch_wasm_test_models + DEPENDS ${MODELS_DIR}/add_mul.pte ${MODELS_DIR}/add.pte + ${MODELS_DIR}/test.pte ) add_custom_command( diff --git a/extension/wasm/test/test_model.py b/extension/wasm/test/test_model.py new file mode 100644 index 00000000000..11c50aa424b --- /dev/null +++ b/extension/wasm/test/test_model.py @@ -0,0 +1,34 @@ +import sys + +import torch +from executorch.exir import to_edge_transform_and_lower +from torch.export import export + + +class IndexModel(torch.nn.Module): + def forward(self, x, n): + return x[n] + + +class AddAllModel(torch.nn.Module): + def forward(self, x, n): + return x, n, x + n + + +if __name__ == "__main__": + output_filepath = sys.argv[1] if len(sys.argv) > 1 else "test.pte" + indexModel = IndexModel().eval() + addAllModel = AddAllModel().eval() + + exported_index = export(indexModel, (torch.randn([3]), 1)) + exported_add_all = export(addAllModel, (torch.randn([2, 2]), 1)) + edge = to_edge_transform_and_lower( + { + "forward": exported_index, + "index": exported_index, + "add_all": exported_add_all, + } + ) + et = edge.to_executorch() + with open(output_filepath, "wb") as file: + file.write(et.buffer) diff --git a/extension/wasm/test/unittests.js b/extension/wasm/test/unittests.js index 69dd899ce46..3d485c2e8b2 100644 --- a/extension/wasm/test/unittests.js +++ b/extension/wasm/test/unittests.js @@ -105,6 +105,13 @@ describe("Module", () => { module.delete(); }); + test("multiple methods", () => { + const module = et.Module.load("test.pte"); + const methods = module.getMethods(); + expect(methods).toEqual(expect.arrayContaining(["forward", "index", "add_all"])); + module.delete(); + }); + test("loadMethod forward", () => { const module = et.Module.load("add.pte"); expect(() => module.loadMethod("forward")).not.toThrow(); @@ -224,6 +231,25 @@ describe("Module", () => { }); module.delete(); }); + + test("non-tensor in input", () => { + const module = et.Module.load("test.pte"); + const methodMeta = module.getMethodMeta("add_all"); + expect(methodMeta.inputTags).toEqual([et.Tag.Tensor, et.Tag.Int]); + expect(methodMeta.inputTensorMeta[0]).not.toBeUndefined(); + expect(methodMeta.inputTensorMeta[1]).toBeUndefined(); + module.delete(); + }); + + test("non-tensor in output", () => { + const module = et.Module.load("test.pte"); + const methodMeta = module.getMethodMeta("add_all"); + expect(methodMeta.outputTags).toEqual([et.Tag.Tensor, et.Tag.Int, et.Tag.Tensor]); + expect(methodMeta.outputTensorMeta[0]).not.toBeUndefined(); + expect(methodMeta.outputTensorMeta[1]).toBeUndefined(); + expect(methodMeta.outputTensorMeta[2]).not.toBeUndefined(); + module.delete(); + }); }); }); diff --git a/extension/wasm/wasm_bindings.cpp b/extension/wasm/wasm_bindings.cpp index c1cadacddc0..1317c7cf294 100644 --- a/extension/wasm/wasm_bindings.cpp +++ b/extension/wasm/wasm_bindings.cpp @@ -475,16 +475,26 @@ struct ET_EXPERIMENTAL JsMethodMeta { val::array(), meta.num_instructions()}; for (int i = 0; i < meta.num_inputs(); i++) { - js_array_push(new_meta.input_tags, meta.input_tag(i).get()); - js_array_push( - new_meta.input_tensor_meta, - JsTensorInfo::from_tensor_info(meta.input_tensor_meta(i).get())); + Tag tag = meta.input_tag(i).get(); + js_array_push(new_meta.input_tags, tag); + if (tag == Tag::Tensor) { + js_array_push( + new_meta.input_tensor_meta, + JsTensorInfo::from_tensor_info(meta.input_tensor_meta(i).get())); + } else { + js_array_push(new_meta.input_tensor_meta, val::undefined()); + } } for (int i = 0; i < meta.num_outputs(); i++) { - js_array_push(new_meta.output_tags, meta.output_tag(i).get()); - js_array_push( - new_meta.output_tensor_meta, - JsTensorInfo::from_tensor_info(meta.output_tensor_meta(i).get())); + Tag tag = meta.output_tag(i).get(); + js_array_push(new_meta.output_tags, tag); + if (tag == Tag::Tensor) { + js_array_push( + new_meta.output_tensor_meta, + JsTensorInfo::from_tensor_info(meta.output_tensor_meta(i).get())); + } else { + js_array_push(new_meta.output_tensor_meta, val::undefined()); + } } for (int i = 0; i < meta.num_attributes(); i++) { js_array_push(