Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions extension/wasm/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,13 @@

set(MODELS_DIR ${CMAKE_CURRENT_BINARY_DIR}/models/)

# Generate models/test.pte (the multi-method test model) by running the
# Python exporter module from the repository root.
add_custom_command(
  OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/models/test.pte
  COMMAND python3 -m extension.wasm.test.test_model
          ${CMAKE_CURRENT_BINARY_DIR}/models/test.pte
  WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/../../..
)

add_custom_command(
OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/models/add_mul.pte
${CMAKE_CURRENT_BINARY_DIR}/models/add.pte
Expand All @@ -23,8 +30,9 @@ add_custom_command(
)

add_custom_target(
executorch_wasm_test_models DEPENDS ${MODELS_DIR}/add_mul.pte
${MODELS_DIR}/add.pte
executorch_wasm_test_models
DEPENDS ${MODELS_DIR}/add_mul.pte ${MODELS_DIR}/add.pte
${MODELS_DIR}/test.pte
)

add_custom_command(
Expand Down
34 changes: 34 additions & 0 deletions extension/wasm/test/test_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import sys

import torch
from executorch.exir import to_edge_transform_and_lower
from torch.export import export


class IndexModel(torch.nn.Module):
    """Returns the element/slice of ``x`` at position ``n``."""

    def forward(self, x, n):
        # Plain tensor indexing: x[n].
        selected = x[n]
        return selected


class AddAllModel(torch.nn.Module):
    """Echoes both inputs back and additionally returns their sum."""

    def forward(self, x, n):
        total = x + n
        return x, n, total


if __name__ == "__main__":
    # Destination for the serialized program; defaults to ./test.pte when
    # no CLI argument is supplied.
    output_filepath = sys.argv[1] if len(sys.argv) > 1 else "test.pte"

    # snake_case locals for PEP 8 consistency with `output_filepath` above
    # (originally camelCase).
    index_model = IndexModel().eval()
    add_all_model = AddAllModel().eval()

    # Export each module with representative example inputs; `n` is traced
    # from a plain Python int example.
    exported_index = export(index_model, (torch.randn([3]), 1))
    exported_add_all = export(add_all_model, (torch.randn([2, 2]), 1))

    # Bundle several named methods into one program; "forward" and "index"
    # deliberately share the same exported graph.
    edge = to_edge_transform_and_lower(
        {
            "forward": exported_index,
            "index": exported_index,
            "add_all": exported_add_all,
        }
    )
    et = edge.to_executorch()
    with open(output_filepath, "wb") as file:
        file.write(et.buffer)
26 changes: 26 additions & 0 deletions extension/wasm/test/unittests.js
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,13 @@ describe("Module", () => {
module.delete();
});

test("multiple methods", () => {
  // test.pte bundles three named methods; all must be discoverable.
  const mod = et.Module.load("test.pte");
  const methods = mod.getMethods();
  expect(methods).toEqual(
    expect.arrayContaining(["forward", "index", "add_all"]));
  mod.delete();
});

test("loadMethod forward", () => {
const module = et.Module.load("add.pte");
expect(() => module.loadMethod("forward")).not.toThrow();
Expand Down Expand Up @@ -224,6 +231,25 @@ describe("Module", () => {
});
module.delete();
});

test("non-tensor in input", () => {
  // add_all takes (Tensor, int); only tensor slots carry tensor metadata,
  // non-tensor slots are padded with undefined.
  const mod = et.Module.load("test.pte");
  const meta = mod.getMethodMeta("add_all");
  expect(meta.inputTags).toEqual([et.Tag.Tensor, et.Tag.Int]);
  expect(meta.inputTensorMeta[0]).not.toBeUndefined();
  expect(meta.inputTensorMeta[1]).toBeUndefined();
  mod.delete();
});

test("non-tensor in output", () => {
  // add_all returns (Tensor, int, Tensor); tensor metadata must line up
  // positionally with the output tags, with undefined for the int slot.
  const mod = et.Module.load("test.pte");
  const meta = mod.getMethodMeta("add_all");
  expect(meta.outputTags).toEqual(
    [et.Tag.Tensor, et.Tag.Int, et.Tag.Tensor]);
  expect(meta.outputTensorMeta[0]).not.toBeUndefined();
  expect(meta.outputTensorMeta[1]).toBeUndefined();
  expect(meta.outputTensorMeta[2]).not.toBeUndefined();
  mod.delete();
});
});
});

Expand Down
26 changes: 18 additions & 8 deletions extension/wasm/wasm_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -475,16 +475,26 @@ struct ET_EXPERIMENTAL JsMethodMeta {
val::array(),
meta.num_instructions()};
for (int i = 0; i < meta.num_inputs(); i++) {
js_array_push(new_meta.input_tags, meta.input_tag(i).get());
js_array_push(
new_meta.input_tensor_meta,
JsTensorInfo::from_tensor_info(meta.input_tensor_meta(i).get()));
Tag tag = meta.input_tag(i).get();
js_array_push(new_meta.input_tags, tag);
if (tag == Tag::Tensor) {
js_array_push(
new_meta.input_tensor_meta,
JsTensorInfo::from_tensor_info(meta.input_tensor_meta(i).get()));
} else {
js_array_push(new_meta.input_tensor_meta, val::undefined());
}
}
for (int i = 0; i < meta.num_outputs(); i++) {
js_array_push(new_meta.output_tags, meta.output_tag(i).get());
js_array_push(
new_meta.output_tensor_meta,
JsTensorInfo::from_tensor_info(meta.output_tensor_meta(i).get()));
Tag tag = meta.output_tag(i).get();
js_array_push(new_meta.output_tags, tag);
if (tag == Tag::Tensor) {
js_array_push(
new_meta.output_tensor_meta,
JsTensorInfo::from_tensor_info(meta.output_tensor_meta(i).get()));
} else {
js_array_push(new_meta.output_tensor_meta, val::undefined());
}
}
for (int i = 0; i < meta.num_attributes(); i++) {
js_array_push(
Expand Down
Loading