diff --git a/extension/module/module.cpp b/extension/module/module.cpp index 3212077d2ee..e6ead0014c2 100644 --- a/extension/module/module.cpp +++ b/extension/module/module.cpp @@ -309,6 +309,13 @@ runtime::Error Module::set_output( output_tensor.mutable_data_ptr(), output_tensor.nbytes(), output_index); } -} // namespace ET_MODULE_NAMESPACE +runtime::Error Module::update( + const std::string& method_name, + runtime::ArrayRef backend_options) { + ET_CHECK_OK_OR_RETURN_ERROR(load_method(method_name)); + auto& method = methods_.at(method_name).method; + return method->update(backend_options); +} + } // namespace extension } // namespace executorch diff --git a/extension/module/module.h b/extension/module/module.h index 080ae53f43a..6002ddd7c63 100644 --- a/extension/module/module.h +++ b/extension/module/module.h @@ -466,6 +466,32 @@ class Module { return set_output("forward", std::move(output_value), output_index); } + /** + * EXPERIMENTAL: Updates backend options for a specific method. + * Loads the program and method before updating if needed. + * + * @param[in] method_name The name of the method to update. + * @param[in] backend_options The backend options to update the method with. + * + * @returns An Error to indicate success or failure. + */ + ET_EXPERIMENTAL ET_NODISCARD runtime::Error update( + const std::string& method_name, + runtime::ArrayRef backend_options); + + /** + * EXPERIMENTAL: Updates backend options for the 'forward' method. + * Loads the program and method before updating if needed. + * + * @param[in] backend_options The backend options to update the method with. + * + * @returns An Error to indicate success or failure. + */ + ET_EXPERIMENTAL ET_NODISCARD inline runtime::Error update( + runtime::ArrayRef backend_options) { + return update("forward", backend_options); + } + /** * Retrieves the EventTracer instance being used by the Module. * EventTracer is used for tracking and logging events during the execution diff --git a/extension/module/test/module_test.cpp b/extension/module/test/module_test.cpp index e0444c2aefb..24476c4adab 100644 --- a/extension/module/test/module_test.cpp +++ b/extension/module/test/module_test.cpp @@ -16,9 +16,15 @@ #include #include #include +#include +#include +#include using namespace ::executorch::extension; using namespace ::executorch::runtime; +using executorch::runtime::BackendOptions; +using executorch::runtime::Entry; +using executorch::runtime::IntKey; class ModuleTest : public ::testing::Test { protected: @@ -26,11 +32,16 @@ class ModuleTest : public ::testing::Test { model_path_ = std::getenv("ET_MODULE_ADD_PATH"); add_mul_path_ = std::getenv("ET_MODULE_ADD_MUL_PROGRAM_PATH"); add_mul_data_path_ = std::getenv("ET_MODULE_ADD_MUL_DATA_PATH"); + stub_model_path_ = std::getenv("ET_MODULE_ADD_MUL_DELEGATED_PATH"); + + // Register the StubBackend for testing + StubBackend::register_singleton(); } static inline std::string model_path_; static inline std::string add_mul_path_; static inline std::string add_mul_data_path_; + static inline std::string stub_model_path_; }; TEST_F(ModuleTest, TestLoad) { @@ -466,3 +477,34 @@ TEST_F(ModuleTest, TestPTD) { auto tensor = make_tensor_ptr({2, 2}, {2.f, 3.f, 4.f, 2.f}); ASSERT_EQ(module.forward(tensor).error(), Error::Ok); } + +TEST_F(ModuleTest, TestUpdate) { + Module module(stub_model_path_); + + BackendOptionsMap<3> map; + BackendOptions<1> backend_options; + int new_num_threads = 4; + backend_options.set_option(IntKey("NumberOfThreads"), new_num_threads); + map.add("StubBackend", backend_options.view()); + + // Test update method with specific method name + const auto update_result = module.update("forward", map.entries()); + EXPECT_EQ(update_result, Error::Ok); + + ASSERT_EQ(StubBackend::singleton().num_threads(), new_num_threads); + +} + +TEST_F(ModuleTest, TestUpdateNonExistentMethod) { + Module module(stub_model_path_); + + BackendOptionsMap<3> map; + BackendOptions<1> backend_options; + int new_num_threads = 4; + backend_options.set_option(IntKey("NumberOfThreads"), new_num_threads); + map.add("StubBackend", backend_options.view()); + + // Test update method with non-existent method name + const auto update_result = module.update("nonexistent", map.entries()); + EXPECT_NE(update_result, Error::Ok); +} diff --git a/extension/module/test/targets.bzl b/extension/module/test/targets.bzl index e09b43e356d..0c4e8914dc8 100644 --- a/extension/module/test/targets.bzl +++ b/extension/module/test/targets.bzl @@ -19,6 +19,7 @@ def define_common_targets(is_fbcode=False): "ET_MODULE_ADD_PATH": "$(location fbcode//executorch/test/models:exported_programs[ModuleAdd.pte])", "ET_MODULE_ADD_MUL_PROGRAM_PATH": "$(location fbcode//executorch/test/models:exported_program_and_data[ModuleAddMul.pte])", "ET_MODULE_ADD_MUL_DATA_PATH": "$(location fbcode//executorch/test/models:exported_program_and_data[ModuleAddMul.ptd])", + "ET_MODULE_ADD_MUL_DELEGATED_PATH": "$(location fbcode//executorch/test/models:exported_delegated_add_mul[ModuleAddMul.pte])", } for aten_mode in get_aten_mode_options(): @@ -35,6 +36,7 @@ def define_common_targets(is_fbcode=False): "//executorch/extension/module:module" + aten_suffix, "//executorch/extension/tensor:tensor" + aten_suffix, "//executorch/runtime/core/exec_aten/testing_util:tensor_util" + aten_suffix, + "//executorch/runtime/executor/test:stub_backend", ], env = modules_env, platforms = [CXX, ANDROID], # Cannot bundle resources on Apple platform. diff --git a/runtime/executor/method.cpp b/runtime/executor/method.cpp index 4f1b6dd4b26..d9f294af9ff 100644 --- a/runtime/executor/method.cpp +++ b/runtime/executor/method.cpp @@ -1513,7 +1513,8 @@ Error Method::experimental_step() { return step(); } -Error Method::update(executorch::runtime::ArrayRef backend_option) { +Error Method::update( + executorch::runtime::ArrayRef backend_option) { for (const auto& entry : backend_option) { const char* backend_name = entry.backend_name; auto backend_options = entry.options; diff --git a/runtime/executor/method.h b/runtime/executor/method.h index 4564615be11..b7267e316d2 100644 --- a/runtime/executor/method.h +++ b/runtime/executor/method.h @@ -241,13 +241,14 @@ class Method final { /// DEPRECATED: Use `reset_execution()` instead. ET_DEPRECATED ET_NODISCARD Error experimental_reset_execution(); - /** + /** * EXPERIMENTAL: Update backend options, which will be dispatched to different backends. * * @retval Error::Ok step succeeded * @retval non-Ok Method update fails */ - ET_EXPERIMENTAL ET_NODISCARD Error update(executorch::runtime::ArrayRef backend_option); + ET_EXPERIMENTAL ET_NODISCARD Error update( + executorch::runtime::ArrayRef backend_option); /** * Returns the MethodMeta that corresponds to the calling Method. diff --git a/runtime/executor/test/method_update_test.cpp b/runtime/executor/test/method_update_test.cpp index 11c6f281953..a9c1d77eaf2 100644 --- a/runtime/executor/test/method_update_test.cpp +++ b/runtime/executor/test/method_update_test.cpp @@ -6,182 +6,103 @@ * LICENSE file in the root directory of this source tree. */ - #include - #include - - #include - #include - #include - #include - #include - #include - #include - #include - #include - #include - #include - #include - #include - #include +#include +#include - - using namespace ::testing; - using executorch::aten::ArrayRef; - using executorch::runtime::Error; - using executorch::runtime::EValue; - using executorch::runtime::Method; - using executorch::runtime::Program; - using executorch::runtime::Result; - using executorch::runtime::testing::ManagedMemoryManager; - using torch::executor::util::FileDataLoader; - using executorch::runtime::BackendExecutionContext; +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace ::testing; +using executorch::aten::ArrayRef; +using executorch::runtime::BackendExecutionContext; using executorch::runtime::BackendInitContext; using executorch::runtime::BackendInterface; -using executorch::runtime::BackendUpdateContext; using executorch::runtime::BackendOption; using executorch::runtime::BackendOptions; using executorch::runtime::BackendOptionsMap; +using executorch::runtime::BackendUpdateContext; using executorch::runtime::BoolKey; -using executorch::runtime::IntKey; -using executorch::runtime::Entry; using executorch::runtime::CompileSpec; using executorch::runtime::DataLoader; using executorch::runtime::DelegateHandle; +using executorch::runtime::Entry; +using executorch::runtime::Error; +using executorch::runtime::EValue; using executorch::runtime::FreeableBuffer; +using executorch::runtime::IntKey; +using executorch::runtime::Method; +using executorch::runtime::Program; +using executorch::runtime::Result; +using executorch::runtime::testing::ManagedMemoryManager; +using torch::executor::util::FileDataLoader; - constexpr size_t kDefaultNonConstMemBytes = 32 * 1024U; - constexpr size_t kDefaultRuntimeMemBytes = 32 * 1024U; - -/** - * A backend class whose methods can be overridden individually. - */ -class StubBackend final : public BackendInterface { - public: - - // Default name that this backend is registered as. - static constexpr char kName[] = "StubBackend"; - - bool is_available() const override { - return true; - } - - Result init( - BackendInitContext& context, - FreeableBuffer* processed, - ArrayRef compile_specs) const override { - return nullptr; - } - - Error execute( - BackendExecutionContext& context, - DelegateHandle* handle, - EValue** args) const override { - return Error::Ok; - } - - int num_threads() const { - return num_threads_; - } - - Error update( - BackendUpdateContext& context, - const executorch::runtime::ArrayRef& backend_options) const override { - int success_update = 0; - for (const auto& backend_option : backend_options) { - if (strcmp(backend_option.key, "NumberOfThreads") == 0) { - if (std::holds_alternative(backend_option.value)) { - num_threads_ = std::get(backend_option.value); - success_update++; - } - } - } - if (success_update == backend_options.size()) { - return Error::Ok; - } - return Error::InvalidArgument; - } +constexpr size_t kDefaultNonConstMemBytes = 32 * 1024U; +constexpr size_t kDefaultRuntimeMemBytes = 32 * 1024U; - /** - * Registers the singleton instance if not already registered. - * - * Note that this can be used to install the stub as the implementation for - * any export-time backend by passing in the right name, as long as no other - * backend with that name has been registered yet. - */ - static Error register_singleton(const char* name = kName) { - if (!registered_) { - registered_ = true; - return executorch::runtime::register_backend({name, &singleton_}); - } - return Error::Ok; - } - - /** - * Returns the instance that was added to the backend registry. - */ - static StubBackend& singleton() { - return singleton_; - } - - private: - static bool registered_; - static StubBackend singleton_; - mutable int num_threads_ = 1; - }; - - bool StubBackend::registered_ = false; - StubBackend StubBackend::singleton_; - - class MethodUpdateTest : public ::testing::Test { - protected: - void load_program() { +class MethodUpdateTest : public ::testing::Test { + protected: + void load_program() { // Since these tests cause ET_LOG to be called, the PAL must be initialized // first. executorch::runtime::runtime_init(); // Create a loader for the serialized program. - ASSERT_EQ(StubBackend::register_singleton(), Error::Ok); + ASSERT_EQ(StubBackend::register_singleton(), Error::Ok); + + auto loader_res = + FileDataLoader::from(std::getenv("ET_MODULE_ADD_MUL_DELEGATED_PATH")); + ASSERT_EQ(loader_res.error(), Error::Ok); + loader_ = std::make_unique(std::move(loader_res.get())); + + // Use it to load the program. + auto program_res = Program::load(loader_.get()); + ASSERT_EQ(program_res.error(), Error::Ok); + program_ = std::make_unique(std::move(program_res.get())); + } + + void SetUp() override { + executorch::runtime::runtime_init(); + + load_program(); + } - auto loader_res = FileDataLoader::from(std::getenv("ET_MODULE_ADD_MUL_DELEGATED_PATH")); - ASSERT_EQ(loader_res.error(), Error::Ok); - loader_ = std::make_unique(std::move(loader_res.get())); - - // Use it to load the program. - auto program_res = Program::load(loader_.get()); - ASSERT_EQ(program_res.error(), Error::Ok); - program_ = std::make_unique(std::move(program_res.get())); - } + private: + std::unique_ptr loader_; - void SetUp() override { - executorch::runtime::runtime_init(); - - load_program(); - } - - private: - std::unique_ptr loader_; + protected: + std::unique_ptr program_; +}; - protected: - std::unique_ptr program_; - }; - - TEST_F(MethodUpdateTest, MoveTest) { +TEST_F(MethodUpdateTest, MoveTest) { BackendInterface* backend = - executorch::runtime::get_backend_class(StubBackend::kName); + executorch::runtime::get_backend_class(StubBackend::kName); ASSERT_EQ(backend, &StubBackend::singleton()); - ManagedMemoryManager mmm(kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes); - Result method = program_->load_method("forward", &mmm.get()); + ManagedMemoryManager mmm(kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes); + Result method = program_->load_method("forward", &mmm.get()); // Check that the default number of threads is 1. - ASSERT_EQ(StubBackend::singleton().num_threads(), 1); - ASSERT_EQ(method.error(), Error::Ok); + ASSERT_EQ(StubBackend::singleton().num_threads(), 1); + ASSERT_EQ(method.error(), Error::Ok); - BackendOptionsMap<3> map; - BackendOptions<1> backend_options; - int new_num_threads = 4; - backend_options.set_option(IntKey("NumberOfThreads"), new_num_threads); - map.add("StubBackend", backend_options.view()); - Error update_result = method->update(map.entries()); - ASSERT_EQ(update_result, Error::Ok); - ASSERT_EQ(StubBackend::singleton().num_threads(), new_num_threads); + BackendOptionsMap<3> map; + BackendOptions<1> backend_options; + int new_num_threads = 4; + backend_options.set_option(IntKey("NumberOfThreads"), new_num_threads); + map.add("StubBackend", backend_options.view()); + Error update_result = method->update(map.entries()); + ASSERT_EQ(update_result, Error::Ok); + ASSERT_EQ(StubBackend::singleton().num_threads(), new_num_threads); } diff --git a/runtime/executor/test/stub_backend.h b/runtime/executor/test/stub_backend.h new file mode 100644 index 00000000000..dfbf2536963 --- /dev/null +++ b/runtime/executor/test/stub_backend.h @@ -0,0 +1,110 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +using executorch::aten::ArrayRef; +using executorch::runtime::BackendExecutionContext; +using executorch::runtime::BackendInitContext; +using executorch::runtime::BackendInterface; +using executorch::runtime::BackendOption; +using executorch::runtime::BackendUpdateContext; +using executorch::runtime::CompileSpec; +using executorch::runtime::DelegateHandle; +using executorch::runtime::Error; +using executorch::runtime::EValue; +using executorch::runtime::FreeableBuffer; +using executorch::runtime::Result; + +/** + * A backend class whose methods can be overridden individually. + */ + class StubBackend final : public BackendInterface { + public: + + // Default name that this backend is registered as. + static constexpr char kName[] = "StubBackend"; + + bool is_available() const override { + return true; + } + + Result init( + BackendInitContext& context, + FreeableBuffer* processed, + ArrayRef compile_specs) const override { + return nullptr; + } + + Error execute( + BackendExecutionContext& context, + DelegateHandle* handle, + EValue** args) const override { + return Error::Ok; + } + + int num_threads() const { + return num_threads_; + } + + Error update( + BackendUpdateContext& context, + const executorch::runtime::ArrayRef& backend_options) const override { + int success_update = 0; + for (const auto& backend_option : backend_options) { + if (strcmp(backend_option.key, "NumberOfThreads") == 0) { + if (std::holds_alternative(backend_option.value)) { + num_threads_ = std::get(backend_option.value); + success_update++; + } + } + } + if (success_update == backend_options.size()) { + return Error::Ok; + } + return Error::InvalidArgument; + } + + /** + * Registers the singleton instance if not already registered. + * + * Note that this can be used to install the stub as the implementation for + * any export-time backend by passing in the right name, as long as no other + * backend with that name has been registered yet. + */ + static Error register_singleton(const char* name = kName) { + if (!registered_) { + registered_ = true; + return executorch::runtime::register_backend({name, &singleton_}); + } + return Error::Ok; + } + + /** + * Returns the instance that was added to the backend registry. + */ + static StubBackend& singleton() { + return singleton_; + } + + private: + static bool registered_; + static StubBackend singleton_; + mutable int num_threads_ = 1; + }; + +// Static member definitions +bool StubBackend::registered_ = false; +StubBackend StubBackend::singleton_; diff --git a/runtime/executor/test/targets.bzl b/runtime/executor/test/targets.bzl index b075e5b6b62..e70774d5ce1 100644 --- a/runtime/executor/test/targets.bzl +++ b/runtime/executor/test/targets.bzl @@ -92,6 +92,23 @@ def define_common_targets(is_fbcode = False): ], ) + runtime.cxx_library( + name = "stub_backend", + srcs = [], + exported_headers = [ + "stub_backend.h", + ], + visibility = [ + "//executorch/runtime/executor/test/...", + "//executorch/extension/module/test/...", + "//executorch/test/...", + "@EXECUTORCH_CLIENTS", + ], + exported_deps = [ + "//executorch/runtime/backend:interface", + ], + ) + runtime.cxx_test( name = "pte_data_map_test", srcs = [ @@ -178,6 +195,7 @@ def define_common_targets(is_fbcode = False): ], deps = [ ":managed_memory_manager", + ":stub_backend", "//executorch/runtime/backend:interface", "//executorch/runtime/executor:program", "//executorch/extension/data_loader:buffer_data_loader", @@ -185,7 +203,8 @@ def define_common_targets(is_fbcode = False): ], env = { "ET_MODULE_ADD_MUL_DELEGATED_PATH": "$(location fbcode//executorch/test/models:exported_delegated_add_mul[ModuleAddMul.pte])", - }, ) + }, + ) runtime.cxx_test( name = "program_test", diff --git a/test/models/targets.bzl b/test/models/targets.bzl index d3fa3230468..601d2d7ff5d 100644 --- a/test/models/targets.bzl +++ b/test/models/targets.bzl @@ -187,6 +187,7 @@ def define_common_targets(): visibility = [ "//executorch/runtime/executor/test/...", "//executorch/test/...", + "//executorch/extension/module/test/...", ], )