diff --git a/.ci/scripts/setup-emscripten.sh b/.ci/scripts/setup-emscripten.sh
index 313072616f8..a4f4fd1a078 100644
--- a/.ci/scripts/setup-emscripten.sh
+++ b/.ci/scripts/setup-emscripten.sh
@@ -7,6 +7,13 @@ set -ex
 
+# Node.js >= 17 is required; install Node 22 via nvm.
+install_node() {
+  curl -o- https://raw.githubusercontent.com/nvm-sh/nvm/v0.40.3/install.sh | bash
+  source "$HOME/.nvm/nvm.sh"
+  nvm install 22
+}
+
 install_emscripten() {
   git clone https://github.com/emscripten-core/emsdk.git
   pushd emsdk || return
@@ -16,4 +23,5 @@ install_emscripten() {
   popd || return
 }
 
+install_node
 install_emscripten
diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml
index e221987f3ff..b697b4166e0 100644
--- a/.github/workflows/pull.yml
+++ b/.github/workflows/pull.yml
@@ -764,6 +764,41 @@ jobs:
       # Test selective build
       PYTHON_EXECUTABLE=python bash examples/wasm/test_build_wasm.sh
 
+  unittest-wasm-bindings:
+    name: unittest-wasm-bindings
+    uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
+    permissions:
+      id-token: write
+      contents: read
+    strategy:
+      fail-fast: false
+    with:
+      runner: linux.2xlarge
+      docker-image: ci-image:executorch-ubuntu-22.04-clang12
+      submodules: 'recursive'
+      ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
+      timeout: 90
+      script: |
+        # The generic Linux job chooses to use the base env, not the one set up by the image
+        CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]")
+        conda activate "${CONDA_ENV}"
+
+        BUILD_TOOL="cmake"
+        PYTHON_EXECUTABLE=python bash .ci/scripts/setup-linux.sh --build-tool "${BUILD_TOOL}"
+
+        # Install Node.js and Emscripten
+        source .ci/scripts/setup-emscripten.sh
+
+        # Build the Wasm bindings and their tests
+        bash scripts/build_wasm_tests.sh
+
+        # Install Jest
+        cd cmake-out-wasm/extension/wasm/test
+        npm install --save-dev jest
+
+        # Run the unit tests
+        npm test
+
   unittest-nxp-neutron:
     uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
     permissions:
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 5c3787b3863..c66b7c8ca68 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -766,6 +766,10 @@ if(EXECUTORCH_BUILD_PYBIND)
   )
 endif()
 
+if(EXECUTORCH_BUILD_WASM)
+  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/wasm)
+endif()
+
 if(EXECUTORCH_BUILD_EXTENSION_TRAINING)
   add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/training)
   list(APPEND _executorch_extensions extension_training)
diff --git a/extension/wasm/CMakeLists.txt b/extension/wasm/CMakeLists.txt
new file mode 100644
index 00000000000..f6095c144ec
--- /dev/null
+++ b/extension/wasm/CMakeLists.txt
@@ -0,0 +1,56 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Please keep this file formatted by running:
+# ~~~
+# cmake-format -i CMakeLists.txt
+# ~~~
+
+cmake_minimum_required(VERSION 3.29)
+
+project(executorch_wasm)
+
+if(NOT CMAKE_CXX_STANDARD)
+  set(CMAKE_CXX_STANDARD 17)
+endif()
+
+if(NOT EMSCRIPTEN)
+  message(FATAL_ERROR "Emscripten is required to build this target")
+endif()
+
+# Source root directory for executorch.
+if(NOT EXECUTORCH_ROOT)
+  set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../..)
+endif()
+
+include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake)
+
+set(_common_compile_options -Wno-deprecated-declarations -fPIC -Wall -Werror)
+set(_common_include_directories ${EXECUTORCH_ROOT}/..)
+
+set(link_libraries)
+list(
+  APPEND
+  link_libraries
+  embind
+  executorch_core
+  extension_data_loader
+  portable_ops_lib
+  extension_module_static
+  extension_tensor
+  extension_runner_util
+)
+
+add_library(executorch_wasm OBJECT wasm_bindings.cpp)
+
+target_compile_options(executorch_wasm PUBLIC ${_common_compile_options})
+target_include_directories(
+  executorch_wasm PUBLIC ${_common_include_directories}
+)
+target_link_libraries(executorch_wasm PUBLIC ${link_libraries})
+
+if(EXECUTORCH_BUILD_WASM_TESTS)
+  add_subdirectory(test)
+endif()
diff --git a/extension/wasm/test/CMakeLists.txt b/extension/wasm/test/CMakeLists.txt
new file mode 100644
index 00000000000..02e4cb444a3
--- /dev/null
+++ b/extension/wasm/test/CMakeLists.txt
@@ -0,0 +1,65 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Please keep this file formatted by running:
+# ~~~
+# cmake-format -i CMakeLists.txt
+# ~~~
+
+set(MODELS_DIR ${CMAKE_CURRENT_BINARY_DIR}/models/)
+
+add_custom_command(
+  OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/models/add_mul.pte
+         ${CMAKE_CURRENT_BINARY_DIR}/models/add.pte
+  COMMAND ${CMAKE_COMMAND} -E make_directory "${MODELS_DIR}"
+  WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/../../..
+  COMMAND python3 -m examples.portable.scripts.export --model_name="add_mul"
+          --output_dir="${MODELS_DIR}"
+  COMMAND python3 -m examples.portable.scripts.export --model_name="add"
+          --output_dir="${MODELS_DIR}"
+)
+
+add_custom_target(
+  executorch_wasm_test_models DEPENDS ${MODELS_DIR}/add_mul.pte
+                                      ${MODELS_DIR}/add.pte
+)
+
+add_custom_command(
+  OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/package.json
+  COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_SOURCE_DIR}/package.json
+          ${CMAKE_CURRENT_BINARY_DIR}/package.json
+  DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/package.json
+  COMMENT "Copying package.json to build output directory"
+)
+
+add_custom_target(
+  executorch_wasm_test_package_json
+  DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/package.json
+)
+
+add_executable(executorch_wasm_tests)
+target_link_libraries(executorch_wasm_tests PUBLIC executorch_wasm)
+target_link_options(
+  executorch_wasm_tests
+  PUBLIC
+  --embed-file
+  "${MODELS_DIR}@/"
+  --post-js
+  ${CMAKE_CURRENT_SOURCE_DIR}/unittests.js
+  -sASSERTIONS=2
+)
+set_target_properties(
+  executorch_wasm_tests PROPERTIES OUTPUT_NAME "executorch_wasm.test"
+)
+set_property(
+  TARGET executorch_wasm_tests
+  APPEND
+  PROPERTY LINK_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/unittests.js
+)
+add_dependencies(
+  executorch_wasm_tests executorch_wasm_test_models
+  executorch_wasm_test_package_json
+)
diff --git a/extension/wasm/test/package.json b/extension/wasm/test/package.json
new file mode 100644
index 00000000000..a25522fa51b
--- /dev/null
+++ b/extension/wasm/test/package.json
@@ -0,0 +1,5 @@
+{
+  "scripts": {
+    "test": "jest"
+  }
+}
diff --git a/extension/wasm/test/unittests.js b/extension/wasm/test/unittests.js
new file mode 100644
index 00000000000..1eeadd193d8
--- /dev/null
+++ b/extension/wasm/test/unittests.js
@@ -0,0 +1,339 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+let et;
+beforeAll((done) => {
+  et = Module;
+  et.onRuntimeInitialized = () => {
+    done();
+  };
+});
+
+describe("Tensor", () => {
+  test("ones", () => {
+    const tensor = et.Tensor.ones([2, 2]);
+    expect(tensor.data).toEqual(new Float32Array([1, 1, 1, 1]));
+    expect(tensor.sizes).toEqual([2, 2]);
+    tensor.delete();
+  });
+
+  test("zeros", () => {
+    const tensor = et.Tensor.zeros([2, 2]);
+    expect(tensor.data).toEqual(new Float32Array([0, 0, 0, 0]));
+    expect(tensor.sizes).toEqual([2, 2]);
+    tensor.delete();
+  });
+
+  test("fromArray", () => {
+    const tensor = et.Tensor.fromArray([2, 2], [1, 2, 3, 4]);
+    expect(tensor.data).toEqual(new Float32Array([1, 2, 3, 4]));
+    expect(tensor.sizes).toEqual([2, 2]);
+    tensor.delete();
+  });
+
+  test("fromGenerator", () => {
+    function* generator() {
+      yield* [1, 2, 3, 4];
+    }
+    const tensor = et.Tensor.fromIter([2, 2], generator());
+    expect(tensor.data).toEqual(new Float32Array([1, 2, 3, 4]));
+    expect(tensor.sizes).toEqual([2, 2]);
+    tensor.delete();
+  });
+
+  test("fromArray wrong size", () => {
+    expect(() => et.Tensor.fromArray([3, 2], [1, 2, 3, 4])).toThrow();
+  });
+
+  test("full", () => {
+    const tensor = et.Tensor.full([2, 2], 3);
+    expect(tensor.data).toEqual(new Float32Array([3, 3, 3, 3]));
+    expect(tensor.sizes).toEqual([2, 2]);
+    tensor.delete();
+  });
+
+  test("scalar type", () => {
+    const tensor = et.Tensor.ones([2, 2]);
+    expect(tensor.scalarType).toEqual(et.ScalarType.Float);
+    tensor.delete();
+  });
+
+  test("long tensor", () => {
+    const tensor = et.Tensor.ones([2, 2], et.ScalarType.Long);
+    expect(tensor.data).toEqual(new BigInt64Array([1n, 1n, 1n, 1n]));
+    expect(tensor.sizes).toEqual([2, 2]);
+    expect(tensor.scalarType).toEqual(et.ScalarType.Long);
+    tensor.delete();
+  });
+
+  test("infer long tensor", () => {
+    // Number cannot be converted to Long, so we use BigInt instead.
+    const tensor = et.Tensor.fromArray([2, 2], [1n, 2n, 3n, 4n]);
+    expect(tensor.data).toEqual(new BigInt64Array([1n, 2n, 3n, 4n]));
+    expect(tensor.sizes).toEqual([2, 2]);
+    expect(tensor.scalarType).toEqual(et.ScalarType.Long);
+    tensor.delete();
+  });
+
+  test("with dim order and strides", () => {
+    const tensor = et.Tensor.fromArray([2, 2], [1, 2, 3, 4], et.ScalarType.Float, [0, 1], [2, 1]);
+    expect(tensor.data).toEqual(new Float32Array([1, 2, 3, 4]));
+    expect(tensor.sizes).toEqual([2, 2]);
+    tensor.delete();
+  });
+
+  test("incorrect dim order", () => {
+    expect(() => et.Tensor.fromArray([2, 2], [1, 2, 3, 4], et.ScalarType.Float, [1])).toThrow();
+    expect(() => et.Tensor.fromArray([2, 2], [1, 2, 3, 4], et.ScalarType.Float, [1, 2])).toThrow();
+  });
+
+  test("incorrect strides", () => {
+    expect(() => et.Tensor.fromArray([2, 2], [1, 2, 3, 4], et.ScalarType.Float, [1, 1], [2, 1])).toThrow();
+  });
+});
+
+describe("Module", () => {
+  test("getMethods has forward", () => {
+    const module = et.Module.load("add.pte");
+    const methods = module.getMethods();
+    expect(methods).toEqual(["forward"]);
+    module.delete();
+  });
+
+  test("loadMethod forward", () => {
+    const module = et.Module.load("add.pte");
+    expect(() => module.loadMethod("forward")).not.toThrow();
+    module.delete();
+  });
+
+  test("loadMethod does not exist", () => {
+    const module = et.Module.load("add.pte");
+    expect(() => module.loadMethod("does_not_exist")).toThrow();
+    module.delete();
+  });
+
+  test("load from Uint8Array", () => {
+    const data = FS.readFile('add.pte');
+    const module = et.Module.load(data);
+    const methods = module.getMethods();
+    expect(methods).toEqual(["forward"]);
+    module.delete();
+  });
+
+  test("load from ArrayBuffer", () => {
+    const data = FS.readFile('add.pte');
+    const module = et.Module.load(data.buffer);
+    const methods = module.getMethods();
+    expect(methods).toEqual(["forward"]);
+    module.delete();
+  });
+
+  describe("MethodMeta", () => {
+    test("name is forward", () => {
+      const module = et.Module.load("add_mul.pte");
+      const methodMeta = module.getMethodMeta("forward");
+      expect(methodMeta.name).toEqual("forward");
+      module.delete();
+    });
+
+    test("inputs are tensors", () => {
+      const module = et.Module.load("add_mul.pte");
+      const methodMeta = module.getMethodMeta("forward");
+      expect(methodMeta.inputTags.length).toEqual(3);
+      expect(methodMeta.inputTags).toEqual([et.Tag.Tensor, et.Tag.Tensor, et.Tag.Tensor]);
+      module.delete();
+    });
+
+    test("outputs are tensors", () => {
+      const module = et.Module.load("add_mul.pte");
+      const methodMeta = module.getMethodMeta("forward");
+      expect(methodMeta.outputTags.length).toEqual(1);
+      expect(methodMeta.outputTags).toEqual([et.Tag.Tensor]);
+      module.delete();
+    });
+
+    test("num instructions is 2", () => {
+      const module = et.Module.load("add_mul.pte");
+      const methodMeta = module.getMethodMeta("forward");
+      expect(methodMeta.numInstructions).toEqual(2);
+      module.delete();
+    });
+
+    test("method does not exist", () => {
+      const module = et.Module.load("add_mul.pte");
+      expect(() => module.getMethodMeta("does_not_exist")).toThrow();
+      module.delete();
+    });
+
+    describe("TensorInfo", () => {
+      test("input sizes is 2x2", () => {
+        const module = et.Module.load("add_mul.pte");
+        const methodMeta = module.getMethodMeta("forward");
+        expect(methodMeta.inputTensorMeta.length).toEqual(3);
+        methodMeta.inputTensorMeta.forEach((tensorInfo) => {
+          expect(tensorInfo.sizes).toEqual([2, 2]);
+        });
+        module.delete();
+      });
+
+      test("output sizes is 2x2", () => {
+        const module = et.Module.load("add_mul.pte");
+        const methodMeta = module.getMethodMeta("forward");
+        expect(methodMeta.outputTensorMeta.length).toEqual(1);
+        expect(methodMeta.outputTensorMeta[0].sizes).toEqual([2, 2]);
+        module.delete();
+      });
+
+      test("dim order is contiguous", () => {
+        const module = et.Module.load("add_mul.pte");
+        const methodMeta = module.getMethodMeta("forward");
+        methodMeta.inputTensorMeta.forEach((tensorInfo) => {
+          expect(tensorInfo.dimOrder).toEqual([0, 1]);
+        });
+        module.delete();
+      });
+
+      test("scalar type is float", () => {
+        const module = et.Module.load("add_mul.pte");
+        const methodMeta = module.getMethodMeta("forward");
+        methodMeta.inputTensorMeta.forEach((tensorInfo) => {
+          expect(tensorInfo.scalarType).toEqual(et.ScalarType.Float);
+        });
+        module.delete();
+      });
+
+      test("memory planned", () => {
+        const module = et.Module.load("add_mul.pte");
+        const methodMeta = module.getMethodMeta("forward");
+        methodMeta.inputTensorMeta.forEach((tensorInfo) => {
+          expect(tensorInfo.isMemoryPlanned).toBe(true);
+        });
+        module.delete();
+      });
+
+      test("nbytes is 16", () => {
+        const module = et.Module.load("add_mul.pte");
+        const methodMeta = module.getMethodMeta("forward");
+        methodMeta.inputTensorMeta.forEach((tensorInfo) => {
+          expect(tensorInfo.nbytes).toEqual(16);
+        });
+        module.delete();
+      });
+    });
+  });
+
+  describe("execute", () => {
+    test("add normally", () => {
+      const module = et.Module.load("add.pte");
+      const inputs = [et.Tensor.ones([1]), et.Tensor.ones([1])];
+      const output = module.execute("forward", inputs);
+
+      expect(output.length).toEqual(1);
+      expect(output[0].data).toEqual(new Float32Array([2]));
+      expect(output[0].sizes).toEqual([1]);
+
+      inputs.forEach((input) => input.delete());
+      output.forEach((output) => output.delete());
+      module.delete();
+    });
+
+    test("add_mul normally", () => {
+      const module = et.Module.load("add_mul.pte");
+      const inputs = [et.Tensor.ones([2, 2]), et.Tensor.ones([2, 2]), et.Tensor.ones([2, 2])];
+      const output = module.execute("forward", inputs);
+
+      expect(output.length).toEqual(1);
+      expect(output[0].data).toEqual(new Float32Array([3, 3, 3, 3]));
+      expect(output[0].sizes).toEqual([2, 2]);
+
+      inputs.forEach((input) => input.delete());
+      output.forEach((output) => output.delete());
+      module.delete();
+    });
+
+    test("forward directly", () => {
+      const module = et.Module.load("add_mul.pte");
+      const inputs = [et.Tensor.ones([2, 2]), et.Tensor.ones([2, 2]), et.Tensor.ones([2, 2])];
+      const output = module.forward(inputs);
+
+      expect(output.length).toEqual(1);
+      expect(output[0].data).toEqual(new Float32Array([3, 3, 3, 3]));
+      expect(output[0].sizes).toEqual([2, 2]);
+
+      inputs.forEach((input) => input.delete());
+      output.forEach((output) => output.delete());
+      module.delete();
+    });
+
+    test("wrong number of inputs", () => {
+      const module = et.Module.load("add_mul.pte");
+      const inputs = [et.Tensor.ones([2, 2]), et.Tensor.ones([2, 2])];
+      expect(() => module.execute("forward", inputs)).toThrow();
+
+      inputs.forEach((input) => input.delete());
+      module.delete();
+    });
+
+    test("wrong input size", () => {
+      const module = et.Module.load("add.pte");
+      const inputs = [et.Tensor.ones([2, 1]), et.Tensor.ones([2, 1])];
+      expect(() => module.execute("forward", inputs)).toThrow();
+
+      inputs.forEach((input) => input.delete());
+      module.delete();
+    });
+
+    test("wrong input type", () => {
+      const module = et.Module.load("add.pte");
+      const inputs = [et.Tensor.ones([1]), et.Tensor.ones([1], et.ScalarType.Long)];
+      expect(() => module.execute("forward", inputs)).toThrow();
+
+      inputs.forEach((input) => input.delete());
+      module.delete();
+    });
+
+    test("method does not exist", () => {
+      const module = et.Module.load("add.pte");
+      const inputs = [et.Tensor.ones([1]), et.Tensor.ones([1])];
+      expect(() => module.execute("does_not_exist", inputs)).toThrow();
+
+      inputs.forEach((input) => input.delete());
+      module.delete();
+    });
+
+    test("output tensor can be reused", () => {
+      const module = et.Module.load("add_mul.pte");
+      const inputs = [et.Tensor.ones([2, 2]), et.Tensor.ones([2, 2]), et.Tensor.ones([2, 2])];
+      const output = module.forward(inputs);
+
+      expect(output.length).toEqual(1);
+      expect(output[0].data).toEqual(new Float32Array([3, 3, 3, 3]));
+      expect(output[0].sizes).toEqual([2, 2]);
+
+      const inputs2 = [output[0], output[0], output[0]];
+      const output2 = module.forward(inputs2);
+
+      expect(output2.length).toEqual(1);
+      expect(output2[0].data).toEqual(new Float32Array([21, 21, 21, 21]));
+      expect(output2[0].sizes).toEqual([2, 2]);
+
+      inputs.forEach((input) => input.delete());
+      output.forEach((output) => output.delete());
+      output2.forEach((output) => output.delete());
+      module.delete();
+    });
+  });
+});
+
+describe("sanity", () => {
+  // Emscripten enums are equal by default for some reason.
+  test("different enums are not equal", () => {
+    expect(et.ScalarType.Float).not.toEqual(et.ScalarType.Long);
+    expect(et.Tag.Int).not.toEqual(et.Tag.Double);
+  });
+});
diff --git a/extension/wasm/wasm_bindings.cpp b/extension/wasm/wasm_bindings.cpp
new file mode 100644
index 00000000000..6ba41236868
--- /dev/null
+++ b/extension/wasm/wasm_bindings.cpp
@@ -0,0 +1,713 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <algorithm>
+#include <cinttypes>
+#include <cstdio>
+#include <numeric>
+
+#include <emscripten.h>
+#include <emscripten/bind.h>
+
+#include <executorch/extension/data_loader/buffer_data_loader.h>
+#include <executorch/extension/module/module.h>
+#include <executorch/extension/tensor/tensor.h>
+
+#define THROW_JS_ERROR(errorType, message, ...)                            \
+  ({                                                                       \
+    char msg_buf[256];                                                     \
+    int len = snprintf(msg_buf, sizeof(msg_buf), message, ##__VA_ARGS__);  \
+    if (len < sizeof(msg_buf)) {                                           \
+      EM_ASM(throw new errorType(UTF8ToString($0)), msg_buf);              \
+    } else {                                                               \
+      std::string msg;                                                     \
+      msg.resize(len);                                                     \
+      snprintf(&msg[0], len + 1, message, ##__VA_ARGS__);                  \
+      EM_ASM(throw new errorType(UTF8ToString($0)), msg.c_str());          \
+    }                                                                      \
+    __builtin_unreachable();                                               \
+  })
+
+/// Throws a JavaScript Error with the provided message if `error` is not `Ok`.
+#define THROW_IF_ERROR(error, message, ...)            \
+  ({                                                   \
+    if ET_UNLIKELY ((error) != Error::Ok) {            \
+      THROW_JS_ERROR(Error, message, ##__VA_ARGS__);   \
+    }                                                  \
+  })
+
+/// Throws a JavaScript Error with the provided message if `cond` is not `true`.
+#define THROW_IF_FALSE(cond, message, ...)             \
+  ({                                                   \
+    if ET_UNLIKELY (!(cond)) {                         \
+      THROW_JS_ERROR(Error, message, ##__VA_ARGS__);   \
+    }                                                  \
+  })
+
+using namespace emscripten;
+using executorch::aten::ScalarType;
+using executorch::aten::Tensor;
+using ::executorch::extension::BufferDataLoader;
+using ::executorch::runtime::Error;
+using ::executorch::runtime::EValue;
+using ::executorch::runtime::Result;
+using ::executorch::runtime::Tag;
+using ::executorch::runtime::TensorInfo;
+
+namespace executorch {
+namespace extension {
+namespace wasm {
+
+namespace {
+
+// val represents all JS values. Using val_array to specify that we
+// specifically want an array.
+template <typename T>
+using val_array = val;
+
+template <typename T>
+inline void js_array_push(val_array<T>& array, const T& value) {
+  array.call<void>("push", value);
+}
+
+#define JS_FORALL_SUPPORTED_TENSOR_TYPES(_) \
+  _(float, Float)                           \
+  _(int64_t, Long)
+
+inline ssize_t compute_expected_numel(
+    const std::vector<executorch::aten::SizesType>& sizes) {
+  return executorch::aten::compute_numel(sizes.data(), sizes.size());
+}
+
+template <typename T>
+inline void assert_valid_numel(
+    const std::vector<T>& data,
+    const std::vector<executorch::aten::SizesType>& sizes) {
+  auto computed_numel = compute_expected_numel(sizes);
+  THROW_IF_FALSE(
+      data.size() >= computed_numel,
+      "Required %ld elements, given %ld",
+      computed_numel,
+      data.size());
+}
+
+constexpr size_t MAX_ELEMENTS = 8 * 1024 * 1024;
+
+template <typename T>
+std::vector<T> convertJSGeneratorToNumberVector(val generator) {
+  std::vector<T> data;
+  while (true) {
+    val next = generator.call<val>("next");
+    if (next["done"].as<bool>()) {
+      break;
+    }
+    data.push_back(next["value"].as<T>());
+    if (data.size() >= MAX_ELEMENTS) {
+      THROW_JS_ERROR(
+          RangeError,
+          "Generator exceeded maximum element count of %zu",
+          MAX_ELEMENTS);
+    }
+  }
+  return data;
+}
+
+// make_tensor_ptr() assertions will abort the program if they fail.
+// These checks will throw a JS error instead.
+void assert_dim_order_and_strides_valid(
+    const std::vector<executorch::aten::SizesType>& sizes,
+    std::vector<executorch::aten::DimOrderType>& dim_order,
+    std::vector<executorch::aten::StridesType>& strides) {
+  THROW_IF_FALSE(
+      dim_order.size() == 0 || dim_order.size() == sizes.size(),
+      "dim_order size must match sizes or be empty.");
+  THROW_IF_FALSE(
+      strides.size() == 0 || strides.size() == sizes.size(),
+      "strides size must match sizes or be empty.");
+
+  if (dim_order.empty()) {
+    dim_order.resize(sizes.size());
+    std::iota(dim_order.begin(), dim_order.end(), 0);
+    if (!strides.empty()) {
+      // Infer the dim order from the strides: larger stride = outer dim.
+      std::sort(dim_order.begin(), dim_order.end(), [&](size_t a, size_t b) {
+        return strides[a] > strides[b];
+      });
+    }
+  }
+  std::vector<executorch::aten::StridesType> computed_strides(sizes.size());
+
+  auto error = runtime::dim_order_to_stride(
+      sizes.data(), dim_order.data(), sizes.size(), computed_strides.data());
+  THROW_IF_ERROR(error, "Failed to compute strides.");
+
+  if (!strides.empty()) {
+    for (size_t i = 0; i < sizes.size(); i++) {
+      THROW_IF_FALSE(
+          strides[i] == computed_strides[i] || sizes[i] == 1,
+          "invalid strides for dim %zu: %" ET_PRI_SIZES_AND_STRIDES
+          " != %" ET_PRI_SIZES_AND_STRIDES
+          " while its size is %" ET_PRI_SIZES_AND_STRIDES " != 1",
+          i,
+          strides[i],
+          computed_strides[i],
+          sizes[i]);
+    }
+  }
+
+  strides = std::move(computed_strides);
+}
+
+/**
+ * EXPERIMENTAL: JavaScript wrapper for ExecuTorch Tensor.
+ */
+class ET_EXPERIMENTAL JsTensor {
+ public:
+  JsTensor() = delete;
+  JsTensor(const JsTensor&) = delete;
+  JsTensor& operator=(const JsTensor&) = delete;
+  JsTensor(JsTensor&&) = default;
+  JsTensor& operator=(JsTensor&&) = default;
+
+  explicit JsTensor(TensorPtr tensor) : tensor_(std::move(tensor)) {}
+  explicit JsTensor(Tensor&& tensor)
+      : tensor_(std::make_shared<Tensor>(tensor)) {}
+
+  const Tensor& get_tensor() const {
+    THROW_IF_FALSE(tensor_, "Tensor is null");
+    return *tensor_;
+  }
+
+  ScalarType get_scalar_type() const {
+    THROW_IF_FALSE(tensor_, "Tensor is null");
+    return tensor_->scalar_type();
+  }
+
+  val get_data() const {
+    switch (get_scalar_type()) {
+#define JS_CASE_TENSOR_TO_VAL_TYPE(T, NAME)                                  \
+  case ScalarType::NAME:                                                     \
+    THROW_IF_FALSE(tensor_->data_ptr<T>(), "Tensor data is null");           \
+    return val(typed_memory_view(tensor_->numel(), tensor_->data_ptr<T>()));
+      JS_FORALL_SUPPORTED_TENSOR_TYPES(JS_CASE_TENSOR_TO_VAL_TYPE)
+      default:
+        THROW_JS_ERROR(
+            TypeError, "Unsupported Tensor type: %d", get_scalar_type());
+    }
+  }
+
+  val_array<int> get_sizes() const {
+    return val::array(
+        get_tensor().sizes().begin(), get_tensor().sizes().end());
+  }
+
+  static std::unique_ptr<JsTensor> full(val_array<int> sizes, val fill_value) {
+    // If type is unspecified, infer the type from the fill value.
+    // Assume it is a Bigint if not Number.
+    return full(
+        sizes,
+        fill_value,
+        fill_value.isNumber() ? ScalarType::Float : ScalarType::Long);
+  }
+
+  static std::unique_ptr<JsTensor>
+  full(val_array<int> sizes, val fill_value, ScalarType type) {
+    auto sizes_vec =
+        convertJSArrayToNumberVector<executorch::aten::SizesType>(sizes);
+    switch (type) {
+#define JS_CASE_FULL_VECTOR_TYPE(T, NAME)                                  \
+  case ScalarType::NAME: {                                                 \
+    TensorPtr tensor =                                                     \
+        extension::full(sizes_vec, fill_value.as<T>(), ScalarType::NAME);  \
+    return std::make_unique<JsTensor>(std::move(tensor));                  \
+  }
+      JS_FORALL_SUPPORTED_TENSOR_TYPES(JS_CASE_FULL_VECTOR_TYPE)
+      default:
+        THROW_JS_ERROR(TypeError, "Unsupported Tensor type: %d", type);
+    }
+  }
+
+  static std::unique_ptr<JsTensor> zeros(val_array<int> sizes) {
+    return zeros(sizes, ScalarType::Float);
+  }
+
+  static std::unique_ptr<JsTensor> zeros(
+      val_array<int> sizes,
+      ScalarType type) {
+    auto sizes_vec =
+        convertJSArrayToNumberVector<executorch::aten::SizesType>(sizes);
+    TensorPtr tensor = extension::zeros(sizes_vec, type);
+    return std::make_unique<JsTensor>(std::move(tensor));
+  }
+
+  static std::unique_ptr<JsTensor> ones(val_array<int> sizes) {
+    return ones(sizes, ScalarType::Float);
+  }
+
+  static std::unique_ptr<JsTensor> ones(
+      val_array<int> sizes,
+      ScalarType type) {
+    auto sizes_vec =
+        convertJSArrayToNumberVector<executorch::aten::SizesType>(sizes);
+    TensorPtr tensor = extension::ones(sizes_vec, type);
+    return std::make_unique<JsTensor>(std::move(tensor));
+  }
+
+  static std::unique_ptr<JsTensor> from_array(
+      val_array<int> sizes,
+      val_array<val> data) {
+    // If type is unspecified, infer the type from the data.
+    // Assume it is a Bigint if not Number.
+    return from_array(
+        sizes,
+        data,
+        data["length"].as<size_t>() == 0 || data[0].isNumber()
+            ? ScalarType::Float
+            : ScalarType::Long);
+  }
+
+  static std::unique_ptr<JsTensor>
+  from_array(val_array<int> sizes, val_array<val> data, ScalarType type) {
+    return from_array(sizes, data, type, val::array());
+  }
+
+  static std::unique_ptr<JsTensor> from_array(
+      val_array<int> sizes,
+      val_array<val> data,
+      ScalarType type,
+      val_array<int> dim_order) {
+    return from_array(sizes, data, type, dim_order, val::array());
+  }
+
+  static std::unique_ptr<JsTensor> from_array(
+      val_array<int> sizes,
+      val_array<val> data,
+      ScalarType type,
+      val_array<int> dim_order,
+      val_array<int> strides) {
+    auto sizes_vec =
+        convertJSArrayToNumberVector<executorch::aten::SizesType>(sizes);
+
+    auto dim_order_vec =
+        convertJSArrayToNumberVector<executorch::aten::DimOrderType>(
+            dim_order);
+    auto strides_vec =
+        convertJSArrayToNumberVector<executorch::aten::StridesType>(strides);
+
+    assert_dim_order_and_strides_valid(sizes_vec, dim_order_vec, strides_vec);
+    switch (type) {
+#define JS_CASE_FROM_ARRAY_VECTOR_TYPE(T, NAME)              \
+  case ScalarType::NAME: {                                   \
+    auto data_vec = convertJSArrayToNumberVector<T>(data);   \
+    assert_valid_numel(data_vec, sizes_vec);                 \
+    TensorPtr tensor = make_tensor_ptr(                      \
+        std::move(sizes_vec),                                \
+        std::move(data_vec),                                 \
+        std::move(dim_order_vec),                            \
+        std::move(strides_vec),                              \
+        ScalarType::NAME);                                   \
+    return std::make_unique<JsTensor>(std::move(tensor));    \
+  }
+      JS_FORALL_SUPPORTED_TENSOR_TYPES(JS_CASE_FROM_ARRAY_VECTOR_TYPE)
+      default:
+        THROW_JS_ERROR(TypeError, "Unsupported Tensor type: %d", type);
+    }
+  }
+
+  static std::unique_ptr<JsTensor> from_iter(
+      val_array<int> sizes,
+      val_array<val> data) {
+    return from_iter(sizes, data, ScalarType::Float);
+  }
+
+  static std::unique_ptr<JsTensor>
+  from_iter(val_array<int> sizes, val_array<val> data, ScalarType type) {
+    return from_iter(sizes, data, type, val::array());
+  }
+
+  static std::unique_ptr<JsTensor> from_iter(
+      val_array<int> sizes,
+      val_array<val> data,
+      ScalarType type,
+      val_array<int> dim_order) {
+    return from_iter(sizes, data, type, dim_order, val::array());
+  }
+
+  static std::unique_ptr<JsTensor> from_iter(
+      val_array<int> sizes,
+      val_array<val> data,
+      ScalarType type,
+      val_array<int> dim_order,
+      val_array<int> strides) {
+    auto sizes_vec =
+        convertJSArrayToNumberVector<executorch::aten::SizesType>(sizes);
+    auto dim_order_vec =
+        convertJSArrayToNumberVector<executorch::aten::DimOrderType>(
+            dim_order);
+    auto strides_vec =
+        convertJSArrayToNumberVector<executorch::aten::StridesType>(strides);
+
+    assert_dim_order_and_strides_valid(sizes_vec, dim_order_vec, strides_vec);
+
+    switch (type) {
+#define JS_CASE_FROM_ITER_VECTOR_TYPE(T, NAME)                   \
+  case ScalarType::NAME: {                                       \
+    auto data_vec = convertJSGeneratorToNumberVector<T>(data);   \
+    assert_valid_numel(data_vec, sizes_vec);                     \
+    TensorPtr tensor = make_tensor_ptr(                          \
+        std::move(sizes_vec),                                    \
+        std::move(data_vec),                                     \
+        std::move(dim_order_vec),                                \
+        std::move(strides_vec),                                  \
+        ScalarType::NAME);                                       \
+    return std::make_unique<JsTensor>(std::move(tensor));        \
+  }
+      JS_FORALL_SUPPORTED_TENSOR_TYPES(JS_CASE_FROM_ITER_VECTOR_TYPE)
+      default:
+        THROW_JS_ERROR(TypeError, "Unsupported Tensor type: %d", type);
+    }
+  }
+
+ private:
+  TensorPtr tensor_;
+};
+
+// Converts JS value to EValue.
+EValue to_evalue(val v) {
+  if (v.isUndefined()) {
+    THROW_JS_ERROR(TypeError, "Value cannot be undefined");
+  }
+  if (v.isNull()) {
+    return EValue();
+  } else if (v.isNumber()) {
+    return EValue(v.as<double>());
+  } else if (v.isTrue()) {
+    return EValue(true);
+  } else if (v.isFalse()) {
+    return EValue(false);
+  } else {
+    const std::string& type_str = v.typeOf().as<std::string>();
+    if (type_str == "bigint") {
+      return EValue(v.as<int64_t>());
+    } else if (type_str == "object") {
+      // If it is an object, assume it is a tensor.
+      THROW_IF_FALSE(
+          v.instanceof (val::module_property("Tensor")),
+          "Received non-tensor object: %s",
+          val::global("JSON").call<std::string>("stringify", v).c_str());
+      return EValue(v.as<JsTensor&>().get_tensor());
+    }
+    THROW_JS_ERROR(
+        TypeError, "Unsupported JavaScript type: %s", type_str.c_str());
+  }
+}
+
+// Converts EValue to JS value.
+val to_val(EValue&& v) {
+  if (v.isNone()) {
+    return val::null();
+  } else if (v.isInt()) {
+    return val(v.toInt());
+  } else if (v.isDouble()) {
+    return val(v.toDouble());
+  } else if (v.isBool()) {
+    return val(v.toBool());
+  } else if (v.isTensor()) {
+    Tensor tensor = std::move(v).toTensor();
+    std::unique_ptr<JsTensor> wrapper =
+        std::make_unique<JsTensor>(std::move(tensor));
+    return val(std::move(wrapper));
+  } else {
+    char tag_buf[32];
+    runtime::tag_to_string(v.tag, tag_buf, sizeof(tag_buf));
+    THROW_JS_ERROR(TypeError, "Unsupported EValue type: %s", tag_buf);
+  }
+}
+
+/**
+ * EXPERIMENTAL: JavaScript object containing tensor metadata.
+ */
+struct ET_EXPERIMENTAL JsTensorInfo {
+  val_array<int> sizes;
+  val_array<int> dim_order;
+  ScalarType scalar_type;
+  bool is_memory_planned;
+  size_t nbytes;
+  std::string name;
+
+  static JsTensorInfo from_tensor_info(const TensorInfo& info) {
+    return {
+        val::array(info.sizes().begin(), info.sizes().end()),
+        val::array(info.dim_order().begin(), info.dim_order().end()),
+        info.scalar_type(),
+        info.is_memory_planned(),
+        info.nbytes(),
+        std::string(info.name())};
+  }
+};
+
+/**
+ * EXPERIMENTAL: JavaScript object containing method metadata.
+ */
+struct ET_EXPERIMENTAL JsMethodMeta {
+  std::string name;
+  val_array<Tag> input_tags;
+  val_array<JsTensorInfo> input_tensor_meta;
+  val_array<Tag> output_tags;
+  val_array<JsTensorInfo> output_tensor_meta;
+  val_array<JsTensorInfo> attribute_tensor_meta;
+  val_array<size_t> memory_planned_buffer_sizes;
+  val_array<val> backends;
+  ET_DEPRECATED size_t num_instructions;
+
+  static JsMethodMeta from_method_meta(const MethodMeta& meta) {
+    JsMethodMeta new_meta{
+        meta.name(),
+        val::array(),
+        val::array(),
+        val::array(),
+        val::array(),
+        val::array(),
+        val::array(),
+        val::array(),
+        meta.num_instructions()};
+    for (int i = 0; i < meta.num_inputs(); i++) {
+      js_array_push(new_meta.input_tags, meta.input_tag(i).get());
+      js_array_push(
+          new_meta.input_tensor_meta,
+          JsTensorInfo::from_tensor_info(meta.input_tensor_meta(i).get()));
+    }
+    for (int i = 0; i < meta.num_outputs(); i++) {
+      js_array_push(new_meta.output_tags, meta.output_tag(i).get());
+      js_array_push(
+          new_meta.output_tensor_meta,
+          JsTensorInfo::from_tensor_info(meta.output_tensor_meta(i).get()));
+    }
+    for (int i = 0; i < meta.num_attributes(); i++) {
+      js_array_push(
+          new_meta.attribute_tensor_meta,
+          JsTensorInfo::from_tensor_info(meta.attribute_tensor_meta(i).get()));
+    }
+    for (int i = 0; i < meta.num_memory_planned_buffers(); i++) {
+      js_array_push(
+          new_meta.memory_planned_buffer_sizes,
+          meta.memory_planned_buffer_size(i).get());
+    }
+    for (int i = 0; i < meta.num_backends(); i++) {
+      js_array_push(
+          new_meta.backends, val::u8string(meta.get_backend_name(i).get()));
+    }
+    return new_meta;
+  }
+};
+
+/**
+ * EXPERIMENTAL: Wrapper around extension/Module for JavaScript.
+ */
+class ET_EXPERIMENTAL JsModule final {
+ public:
+  JsModule() = delete;
+  JsModule(const JsModule&) = delete;
+  JsModule& operator=(const JsModule&) = delete;
+  JsModule(JsModule&&) = default;
+  JsModule& operator=(JsModule&&) = default;
+
+  explicit JsModule(std::unique_ptr<Module> module)
+      : buffer_(0), module_(std::move(module)) {}
+
+  explicit JsModule(
+      std::vector<uint8_t> buffer,
+      std::unique_ptr<Module> module)
+      : buffer_(std::move(buffer)), module_(std::move(module)) {}
+
+  static std::unique_ptr<JsModule> load_from_uint8_array(val data) {
+    size_t length = data["length"].as<size_t>();
+    std::vector<uint8_t> buffer(length);
+    // Copy the JS Uint8Array into the Wasm heap.
+    val memory_view = val(typed_memory_view(length, buffer.data()));
+    memory_view.call<void>("set", data);
+    auto loader = std::make_unique<BufferDataLoader>(buffer.data(), length);
+    return std::make_unique<JsModule>(
+        std::move(buffer), std::make_unique<Module>(std::move(loader)));
+  }
+
+  static std::unique_ptr<JsModule> load(val data) {
+    if (data.isNull() || data.isUndefined()) {
+      THROW_JS_ERROR(TypeError, "Data cannot be null or undefined");
+    }
+    if (data.isString()) {
+      return std::make_unique<JsModule>(
+          std::make_unique<Module>(data.as<std::string>()));
+    } else if (data.instanceof (val::global("Uint8Array"))) {
+      return load_from_uint8_array(data);
+    } else if (data.instanceof (val::global("ArrayBuffer"))) {
+      return load_from_uint8_array(val::global("Uint8Array").new_(data));
+    } else {
+      THROW_JS_ERROR(
+          TypeError,
+          "Unsupported data type: %s",
+          data.typeOf().as<std::string>().c_str());
+    }
+  }
+
+  val get_methods() {
+    auto res = module_->method_names();
+    THROW_IF_ERROR(
+        res.error(),
+        "Failed to get methods, error: 0x%" PRIx32,
+        static_cast<uint32_t>(res.error()));
+    return val::array(res.get().begin(), res.get().end());
+  }
+
+  void load_method(const std::string& method_name) {
+    Error res = module_->load_method(method_name);
+    THROW_IF_ERROR(
+        res,
+        "Failed to load method %s, error: 0x%" PRIx32,
+        method_name.c_str(),
+        static_cast<uint32_t>(res));
+  }
+
+  JsMethodMeta get_method_meta(const std::string& method_name) {
+    auto res = module_->method_meta(method_name);
+    THROW_IF_ERROR(
+        res.error(),
+        "Failed to get method meta for %s, error: 0x%" PRIx32,
+        method_name.c_str(),
+        static_cast<uint32_t>(res.error()));
+    return JsMethodMeta::from_method_meta(res.get());
+  }
+
+  val_array<val> execute(const std::string& method, val js_inputs) {
+    std::vector<EValue> inputs;
+    if (js_inputs.isArray()) {
+      inputs.reserve(js_inputs["length"].as<size_t>());
+      for (val v : js_inputs) {
+        inputs.push_back(to_evalue(v));
+      }
+    } else {
+      inputs.push_back(to_evalue(js_inputs));
+    }
+    auto res = module_->execute(method, inputs);
+    THROW_IF_ERROR(
+        res.error(),
+        "Failed to execute method %s, error: 0x%" PRIx32,
+        method.c_str(),
+        static_cast<uint32_t>(res.error()));
+    std::vector<EValue> outputs = res.get();
+    val js_outputs = val::array();
+    for (auto& output : outputs) {
+      js_array_push(js_outputs, to_val(std::move(output)));
+    }
+    return js_outputs;
+  }
+
+  val_array<val> forward(val inputs) {
+    return execute("forward", inputs);
+  }
+
+ private:
+  // If loaded from a buffer, keeps it alive for the lifetime of the module.
+  std::vector<uint8_t> buffer_;
+  std::unique_ptr<Module> module_;
+};
+
+} // namespace
+
+EMSCRIPTEN_BINDINGS(WasmBindings) {
+  enum_<ScalarType>("ScalarType")
+#define JS_DECLARE_SCALAR_TYPE(T, NAME) .value(#NAME, ScalarType::NAME)
+      JS_FORALL_SUPPORTED_TENSOR_TYPES(JS_DECLARE_SCALAR_TYPE);
+  enum_<Tag>("Tag")
+#define JS_DECLARE_TAG(NAME) .value(#NAME, Tag::NAME)
+      EXECUTORCH_FORALL_TAGS(JS_DECLARE_TAG);
+
+  class_<JsModule>("Module")
+      .class_function("load", &JsModule::load)
+      .function("getMethods", &JsModule::get_methods)
+      .function("loadMethod", &JsModule::load_method)
+      .function("getMethodMeta", &JsModule::get_method_meta)
+      .function("execute", &JsModule::execute)
+      .function("forward", &JsModule::forward);
+  class_<JsTensor>("Tensor")
+      .class_function(
+          "zeros",
+          select_overload<std::unique_ptr<JsTensor>(val)>(&JsTensor::zeros))
+      .class_function(
+          "zeros",
+          select_overload<std::unique_ptr<JsTensor>(val, ScalarType)>(
+              &JsTensor::zeros))
+      .class_function(
+          "ones",
+          select_overload<std::unique_ptr<JsTensor>(val)>(&JsTensor::ones))
+      .class_function(
+          "ones",
+          select_overload<std::unique_ptr<JsTensor>(val, ScalarType)>(
+              &JsTensor::ones))
+      .class_function(
+          "full",
+          select_overload<std::unique_ptr<JsTensor>(val, val)>(
+              &JsTensor::full))
+      .class_function(
+          "full",
+          select_overload<std::unique_ptr<JsTensor>(val, val, ScalarType)>(
+              &JsTensor::full))
+      .class_function(
+          "fromArray",
+          select_overload<std::unique_ptr<JsTensor>(val, val)>(
+              &JsTensor::from_array))
+      .class_function(
+          "fromArray",
+          select_overload<std::unique_ptr<JsTensor>(val, val, ScalarType)>(
+              &JsTensor::from_array))
+      .class_function(
+          "fromArray",
+          select_overload<std::unique_ptr<JsTensor>(val, val, ScalarType, val)>(
+              &JsTensor::from_array))
+      .class_function(
+          "fromArray",
+          select_overload<std::unique_ptr<JsTensor>(
+              val, val, ScalarType, val, val)>(&JsTensor::from_array))
+      .class_function(
+          "fromIter",
+          select_overload<std::unique_ptr<JsTensor>(val, val)>(
+              &JsTensor::from_iter))
+      .class_function(
+          "fromIter",
+          select_overload<std::unique_ptr<JsTensor>(val, val, ScalarType)>(
+              &JsTensor::from_iter))
+      .class_function(
+          "fromIter",
+          select_overload<std::unique_ptr<JsTensor>(val, val, ScalarType, val)>(
+              &JsTensor::from_iter))
+      .class_function(
+          "fromIter",
+          select_overload<std::unique_ptr<JsTensor>(
+              val, val, ScalarType, val, val)>(&JsTensor::from_iter))
+      .property("scalarType", &JsTensor::get_scalar_type)
+      .property("data", &JsTensor::get_data)
+      .property("sizes", &JsTensor::get_sizes);
+  value_object<JsTensorInfo>("TensorInfo")
+      .field("sizes", &JsTensorInfo::sizes)
+      .field("dimOrder", &JsTensorInfo::dim_order)
+      .field("scalarType", &JsTensorInfo::scalar_type)
+      .field("isMemoryPlanned", &JsTensorInfo::is_memory_planned)
+      .field("nbytes", &JsTensorInfo::nbytes)
+      .field("name", &JsTensorInfo::name);
+  value_object<JsMethodMeta>("MethodMeta")
+      .field("name", &JsMethodMeta::name)
+      .field("inputTags", &JsMethodMeta::input_tags)
+      .field("inputTensorMeta", &JsMethodMeta::input_tensor_meta)
+      .field("outputTags", &JsMethodMeta::output_tags)
+      .field("outputTensorMeta", &JsMethodMeta::output_tensor_meta)
+      .field("attributeTensorMeta", &JsMethodMeta::attribute_tensor_meta)
+      .field(
+          "memoryPlannedBufferSizes",
+          &JsMethodMeta::memory_planned_buffer_sizes)
+      .field("backends", &JsMethodMeta::backends)
+      .field("numInstructions", &JsMethodMeta::num_instructions);
+
+// For some reason Embind doesn't make it easy to get the names of enums.
+// Additionally, different values of the same enum type are considered equal.
+// Assigning the name field fixes both of these issues.
+#define JS_ASSIGN_SCALAR_TYPE_NAME(T, NAME) \
+  val::module_property("ScalarType")[#NAME].set("name", #NAME);
+  JS_FORALL_SUPPORTED_TENSOR_TYPES(JS_ASSIGN_SCALAR_TYPE_NAME)
+#define JS_ASSIGN_TAG_NAME(NAME) \
+  val::module_property("Tag")[#NAME].set("name", #NAME);
+  EXECUTORCH_FORALL_TAGS(JS_ASSIGN_TAG_NAME)
+}
+
+} // namespace wasm
+} // namespace extension
+} // namespace executorch
diff --git a/scripts/build_wasm_tests.sh b/scripts/build_wasm_tests.sh
new file mode 100644
index 00000000000..0a6b6f0b243
--- /dev/null
+++ b/scripts/build_wasm_tests.sh
@@ -0,0 +1,28 @@
+#!/usr/bin/env bash
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+CMAKE_OUT=cmake-out-wasm
+
+cd "$(dirname "${BASH_SOURCE[0]}")/../"
+emcmake cmake . -DEXECUTORCH_BUILD_WASM=ON \
+  -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \
+  -DEXECUTORCH_BUILD_EXTENSION_FLAT_TENSOR=ON \
+  -DEXECUTORCH_BUILD_DEVTOOLS=ON \
+  -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \
+  -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \
+  -DEXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL=ON \
+  -DEXECUTORCH_BUILD_WASM_TESTS=ON \
+  -DCMAKE_BUILD_TYPE=Release \
+  -B"${CMAKE_OUT}"
+
+if [ "$(uname)" == "Darwin" ]; then
+  CMAKE_JOBS=$(( $(sysctl -n hw.ncpu) - 1 ))
+else
+  CMAKE_JOBS=$(( $(nproc) - 1 ))
+fi
+
+cmake --build ${CMAKE_OUT} --target executorch_wasm_tests -j ${CMAKE_JOBS}
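
Usage sketch (reviewer note, not part of the patch): the bindings register `Module` and `Tensor` on the Emscripten module object. A minimal Node.js example, assuming an application build that links `executorch_wasm` into a MODULARIZE'd artifact named `executorch.js` — both the build setup and the file name are hypothetical; this patch only produces the Jest test build `executorch_wasm.test.js`:

    // Hypothetical loader; the artifact name depends on the embedding build.
    const loadExecuTorch = require("./executorch.js");

    loadExecuTorch().then((et) => {
      // Load a .pte placed in the Emscripten virtual FS (or pass a Uint8Array).
      const module = et.Module.load("add.pte");

      // Two 1-element float tensors in, one tensor out.
      const inputs = [et.Tensor.ones([1]), et.Tensor.ones([1])];
      const outputs = module.forward(inputs);
      console.log(outputs[0].data); // Float32Array [2]

      // Embind objects are not garbage collected; free them explicitly.
      inputs.forEach((t) => t.delete());
      outputs.forEach((t) => t.delete());
      module.delete();
    });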