From ec8e11c96ecf56c8a96da58cc47d22ffb3236ecd Mon Sep 17 00:00:00 2001 From: Chen Lai Date: Mon, 23 Jun 2025 21:20:29 -0700 Subject: [PATCH 1/3] [1/N] Add BackendOptions class Pull Request resolved: https://github.com/pytorch/executorch/pull/11389 Introduce backend option as discussed in https://github.com/pytorch/executorch/discussions/10216 Step 1: Introducd Backend Option class In later stage, it will be plugged in with the rest of the stack. BackendOptions is pretty much a list of BackendOption, and backend option is a key value pair. The key is a string, and the value can be 3 different types, including bool, string and int. ghstack-source-id: 292257885 Differential Revision: [D75993712](https://our.internmc.facebook.com/intern/diff/D75993712/) --- runtime/backend/options.h | 208 ++++++++++++++++++ runtime/backend/targets.bzl | 1 + runtime/backend/test/backend_options_test.cpp | 164 ++++++++++++++ runtime/backend/test/targets.bzl | 12 +- 4 files changed, 384 insertions(+), 1 deletion(-) create mode 100644 runtime/backend/options.h create mode 100644 runtime/backend/test/backend_options_test.cpp diff --git a/runtime/backend/options.h b/runtime/backend/options.h new file mode 100644 index 00000000000..7ff742f4e8f --- /dev/null +++ b/runtime/backend/options.h @@ -0,0 +1,208 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once +#include +#include +#include +#include +#include +#include + +namespace executorch { +namespace runtime { + +static constexpr size_t kMaxOptionKeyLength = 64; +static constexpr size_t kMaxOptionValueLength = 256; + +// All string keys and values must have static storage duration (string +// literals, static const char arrays, or global constants). The BackendOptions +// class does NOT take ownership of strings. +using OptionValue = + std::variant>; + +struct BackendOption { + // key is the name of the backend option, like num_threads, enable_profiling, + // etc + char key[kMaxOptionKeyLength]{}; + // value is the value of the backend option, like 4, true, etc + OptionValue value; +}; + +/** + * A template class for storing and managing backend-specific configuration + * options. + * + * This class provides a type-safe way to store key-value pairs for backend + * configuration, with compile-time capacity limits and runtime type checking. + * It supports bool, int, and const char* value types. + * + * @tparam MaxCapacity The maximum number of options that can be stored + */ +template +class BackendOptions { + public: + /** + * Copy constructor + */ + BackendOptions(const BackendOptions& other) : size_(other.size_) { + for (size_t i = 0; i < size_; ++i) { + options_[i] = other.options_[i]; + } + } + + /** + * Copy assignment operator + */ + BackendOptions& operator=(const BackendOptions& other) { + if (this != &other) { + size_ = other.size_; + for (size_t i = 0; i < size_; ++i) { + options_[i] = other.options_[i]; + } + } + return *this; + } + + /** + * Default constructor - initializes with zero options. + */ + BackendOptions() : size_(0) {} + + /** + * Returns a mutable view of all stored options as a Span. + * + * @return A mutable Span containing all BackendOption entries + */ + executorch::runtime::Span view() { + return executorch::runtime::Span(options_, size_); + } + + /** + * Sets a boolean option value for the given key. + * If the key already exists, updates its value. Otherwise, adds a new option. + * + * @tparam N The length of the key string (automatically deduced) + * @param key The option key (must be a string literal or array) + * @param value The boolean value to set + * @return Error::Ok on success, Error::InvalidArgument if storage is full + */ + template + Error set_option(const char (&key)[N], bool value) noexcept { + static_assert(N <= kMaxOptionKeyLength, "Option key is too long"); + return set_option_impl(key, value); + } + + /** + * Sets an integer option value for the given key. + * If the key already exists, updates its value. Otherwise, adds a new option. + * + * @tparam N The length of the key string (automatically deduced) + * @param key The option key (must be a string literal or array) + * @param value The integer value to set + * @return Error::Ok on success, Error::InvalidArgument if storage is full + */ + template + Error set_option(const char (&key)[N], int value) noexcept { + static_assert(N <= kMaxOptionKeyLength, "Option key is too long"); + return set_option_impl(key, value); + } + + /** + * Sets a string option value for the given key. + * If the key already exists, updates its value. Otherwise, adds a new option. + * + * Note: The string value must have static storage duration. This class does + * NOT take ownership of the string - it only stores the pointer. + * + * @tparam N The length of the key string (automatically deduced) + * @param key The option key (must be a string literal or array) + * @param value The string value to set (must have static storage duration) + * @return Error::Ok on success, Error::InvalidArgument if storage is full + */ + template + Error set_option(const char (&key)[N], const char* value) noexcept { + static_assert(N <= kMaxOptionKeyLength, "Option key is too long"); + // Create a fixed-size array and copy the string + std::array arr{}; + strncpy(arr.data(), value, kMaxOptionValueLength - 1); + arr[kMaxOptionValueLength - 1] = '\0'; // Ensure null termination + return set_option_impl(key, arr); + } + /** + * Retrieves an option value by key and type. + * + * @tparam T The expected type of the option value (bool, int, or const char*) + * @tparam KeyLen The length of the key string (automatically deduced) + * @param key The option key to look up + * @param out Reference to store the retrieved value + * @return Error::Ok if found and type matches, Error::NotFound if key doesn't + * exist, Error::InvalidArgument if type doesn't match + */ + template + Error get_option(const char (&key)[KeyLen], T& out) const { + static_assert(KeyLen <= kMaxOptionKeyLength, "Option key is too long"); + for (size_t i = 0; i < size_; ++i) { + if (std::strcmp(options_[i].key, key) == 0) { + // Special handling for string (convert array to const char*) + if constexpr (std::is_same_v) { + if (auto* arr = std::get_if>( + &options_[i].value)) { + out = arr->data(); // Return pointer to stored array + return Error::Ok; + } + } + // Default handling for bool/int + else if (auto* val = std::get_if(&options_[i].value)) { + out = *val; + return Error::Ok; + } + return Error::InvalidArgument; + } + } + return Error::NotFound; + } + + private: + BackendOption options_[MaxCapacity]{}; // Storage for backend options + size_t size_; // Current number of options + + /** + * Internal implementation for setting option values. + * Handles both updating existing options and adding new ones. + * + * @tparam T The type of the value (bool, int, or const char*) + * @param key The option key + * @param value The value to set + * @return Error::Ok on success, Error::InvalidArgument if storage is full + */ + template + Error set_option_impl(const char* key, T value) { + // Update existing if found + for (size_t i = 0; i < size_; ++i) { + if (strcmp(options_[i].key, key) == 0) { + options_[i].value = value; + return Error::Ok; + } + } + if (size_ < MaxCapacity) { + BackendOption new_option; + const size_t key_len = std::strlen(key); + const size_t copy_len = std::min(key_len, kMaxOptionKeyLength - 1); + std::memcpy(new_option.key, key, copy_len); + new_option.key[copy_len] = '\0'; + new_option.value = value; // Restored value assignment + options_[size_++] = new_option; // Store option and increment size + return Error::Ok; + } + return Error::InvalidArgument; + } +}; + +} // namespace runtime +} // namespace executorch diff --git a/runtime/backend/targets.bzl b/runtime/backend/targets.bzl index d2187afb5fc..93bc85d014f 100644 --- a/runtime/backend/targets.bzl +++ b/runtime/backend/targets.bzl @@ -17,6 +17,7 @@ def define_common_targets(): exported_headers = [ "backend_execution_context.h", "backend_init_context.h", + "options.h", "interface.h", ], preprocessor_flags = ["-DUSE_ATEN_LIB"] if aten_mode else [], diff --git a/runtime/backend/test/backend_options_test.cpp b/runtime/backend/test/backend_options_test.cpp new file mode 100644 index 00000000000..313cac6f143 --- /dev/null +++ b/runtime/backend/test/backend_options_test.cpp @@ -0,0 +1,164 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include + +#include + +using namespace ::testing; +using executorch::runtime::BackendOptions; +using executorch::runtime::Error; + +class BackendOptionsTest : public ::testing::Test { + protected: + void SetUp() override { + // Since these tests cause ET_LOG to be called, the PAL must be initialized + // first. + executorch::runtime::runtime_init(); + } + BackendOptions<5> options; // Capacity of 5 for testing limits +}; + +// Test basic string functionality +TEST_F(BackendOptionsTest, HandlesStringOptions) { + // Set and retrieve valid string + options.set_option("backend_type", "GPU"); + const char* result = nullptr; + EXPECT_EQ(options.get_option("backend_type", result), Error::Ok); + EXPECT_STREQ(result, "GPU"); + + // Update existing key + options.set_option("backend_type", "CPU"); + EXPECT_EQ(options.get_option("backend_type", result), Error::Ok); + EXPECT_STREQ(result, "CPU"); +} + +// Test boolean options +TEST_F(BackendOptionsTest, HandlesBoolOptions) { + options.set_option("debug", true); + bool debug = false; + EXPECT_EQ(options.get_option("debug", debug), Error::Ok); + EXPECT_TRUE(debug); + + // Test false value + options.set_option("verbose", false); + EXPECT_EQ(options.get_option("verbose", debug), Error::Ok); + EXPECT_FALSE(debug); +} + +// Test integer options +TEST_F(BackendOptionsTest, HandlesIntOptions) { + options.set_option("num_threads", 256); + int num_threads = 0; + EXPECT_EQ(options.get_option("num_threads", num_threads), Error::Ok); + EXPECT_EQ(num_threads, 256); +} + +// Test error conditions +TEST_F(BackendOptionsTest, HandlesErrors) { + // Non-existent key + bool dummy_bool; + EXPECT_EQ(options.get_option("missing", dummy_bool), Error::NotFound); + + // Type mismatch + options.set_option("threshold", 100); + const char* dummy_str = nullptr; + EXPECT_EQ(options.get_option("threshold", dummy_str), Error::InvalidArgument); + + // Null value handling, should expect failure + ET_EXPECT_DEATH( + options.set_option("nullable", static_cast(nullptr)), ""); +} + +// Test type-specific keys +TEST_F(BackendOptionsTest, EnforcesKeyTypes) { + // Same key name - later set operations overwrite earlier ones + options.set_option("flag", true); + options.set_option("flag", 123); // Overwrites the boolean entry + + bool bval; + int ival; + + // Boolean get should fail - type was overwritten to INT + EXPECT_EQ(options.get_option("flag", bval), Error::InvalidArgument); + + // Integer get should succeed with correct value + EXPECT_EQ(options.get_option("flag", ival), Error::Ok); + EXPECT_EQ(ival, 123); +} + +TEST_F(BackendOptionsTest, MutableOption) { + int ival; + options.set_option("flag", 0); + // Integer get should succeed with correct value + EXPECT_EQ(options.get_option("flag", ival), Error::Ok); + EXPECT_EQ(ival, 0); + + options.view()[0].value = 123; // Overwrites the entry + + // Integer get should succeed with the updated value + EXPECT_EQ(options.get_option("flag", ival), Error::Ok); + EXPECT_EQ(ival, 123); +} + +// Test copy constructor +TEST_F(BackendOptionsTest, CopyConstructor) { + // Set up original option + options.set_option("debug", true); + + // Create copy using copy constructor + BackendOptions<5> copied_options(options); + + // Verify option was copied correctly + bool debug_val; + EXPECT_EQ(copied_options.get_option("debug", debug_val), Error::Ok); + EXPECT_TRUE(debug_val); + + // Verify independence - modifying original doesn't affect copy + options.set_option("debug", false); + EXPECT_EQ(copied_options.get_option("debug", debug_val), Error::Ok); + EXPECT_TRUE(debug_val); // Should still be true in copy + + // Verify independence - modifying copy doesn't affect original + copied_options.set_option("debug", false); + EXPECT_EQ(options.get_option("debug", debug_val), Error::Ok); + EXPECT_FALSE(debug_val); // Should be false in original +} + +// Test copy assignment operator +TEST_F(BackendOptionsTest, CopyAssignmentOperator) { + // Set up original option + options.set_option("enable_profiling", true); + + // Create another options object and assign to it + BackendOptions<5> assigned_options; + assigned_options.set_option("temp_option", false); // Add something first + + assigned_options = options; + + // Verify option was copied correctly + bool profiling_val; + EXPECT_EQ( + assigned_options.get_option("enable_profiling", profiling_val), + Error::Ok); + EXPECT_TRUE(profiling_val); + + // Verify the temp_option was overwritten (not present in assigned object) + bool temp_val; + EXPECT_EQ( + assigned_options.get_option("temp_option", temp_val), Error::NotFound); + + // Verify independence - modifying original doesn't affect assigned copy + options.set_option("enable_profiling", false); + EXPECT_EQ( + assigned_options.get_option("enable_profiling", profiling_val), + Error::Ok); + EXPECT_TRUE(profiling_val); // Should still be true in assigned copy +} diff --git a/runtime/backend/test/targets.bzl b/runtime/backend/test/targets.bzl index 9ea585f650c..916fa3a3b98 100644 --- a/runtime/backend/test/targets.bzl +++ b/runtime/backend/test/targets.bzl @@ -1,7 +1,17 @@ +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") + def define_common_targets(): """Defines targets that should be shared between fbcode and xplat. The directory containing this targets.bzl file should also contain both TARGETS and BUCK files that call this function. """ - pass + runtime.cxx_test( + name = "backend_options_test", + srcs = ["backend_options_test.cpp"], + deps = [ + "//executorch/runtime/core:core", + "//executorch/runtime/backend:interface", + "//executorch/test/utils:utils", + ], + ) From a95918ab77bf974f44b65620ae175ef7b816d11a Mon Sep 17 00:00:00 2001 From: Chen Lai Date: Mon, 23 Jun 2025 21:20:30 -0700 Subject: [PATCH 2/3] [2/N] Add option context Pull Request resolved: https://github.com/pytorch/executorch/pull/11390 For future needs without breacking API BC, in case we need to pass more information to the update API ghstack-source-id: 292257886 Differential Revision: [D75919212](https://our.internmc.facebook.com/intern/diff/D75919212/) --- runtime/backend/backend_option_context.h | 34 ++++++++++++++++++++++++ runtime/backend/targets.bzl | 1 + 2 files changed, 35 insertions(+) create mode 100644 runtime/backend/backend_option_context.h diff --git a/runtime/backend/backend_option_context.h b/runtime/backend/backend_option_context.h new file mode 100644 index 00000000000..b3266e60732 --- /dev/null +++ b/runtime/backend/backend_option_context.h @@ -0,0 +1,34 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once +#include +#include +#include + +namespace executorch { +namespace ET_RUNTIME_NAMESPACE { +/** + * BackendOptionContext will be used to inject runtime info for to initialize + * delegate. + */ +class BackendOptionContext final { + public: + explicit BackendOptionContext() {} +}; + +} // namespace ET_RUNTIME_NAMESPACE +} // namespace executorch + +namespace torch { +namespace executor { +// TODO(T197294990): Remove these deprecated aliases once all users have moved +// to the new `::executorch` namespaces. +using ::executorch::ET_RUNTIME_NAMESPACE::BackendOptionContext; +} // namespace executor +} // namespace torch diff --git a/runtime/backend/targets.bzl b/runtime/backend/targets.bzl index 93bc85d014f..1d1f95c6c97 100644 --- a/runtime/backend/targets.bzl +++ b/runtime/backend/targets.bzl @@ -17,6 +17,7 @@ def define_common_targets(): exported_headers = [ "backend_execution_context.h", "backend_init_context.h", + "backend_option_context.h", "options.h", "interface.h", ], From d74d2e79bcb1a4e7f2fb4be2e48a9cf2dfedefc3 Mon Sep 17 00:00:00 2001 From: Chen Lai Date: Mon, 23 Jun 2025 21:20:32 -0700 Subject: [PATCH 3/3] [3/N] Add get_option/set_option function in backend interface Pull Request resolved: https://github.com/pytorch/executorch/pull/11391 Add update function in backend interface class. The update function will receive the backend options from dispatched by the ET runtime. ET runtime's logic: loop over each backend and it's corresponding backend options, dispatch the backend options to the corresponding backend Next step, will add update API in the method and then module ghstack-source-id: 292257883 @exported-using-ghexport Differential Revision: [D75919242](https://our.internmc.facebook.com/intern/diff/D75919242/) --- runtime/backend/interface.h | 33 ++ .../test/backend_interface_update_test.cpp | 287 ++++++++++++++++++ runtime/backend/test/targets.bzl | 9 + 3 files changed, 329 insertions(+) create mode 100644 runtime/backend/test/backend_interface_update_test.cpp diff --git a/runtime/backend/interface.h b/runtime/backend/interface.h index 95705d48f92..e6a4c2fb8e9 100644 --- a/runtime/backend/interface.h +++ b/runtime/backend/interface.h @@ -12,6 +12,8 @@ #include #include +#include +#include #include #include #include @@ -99,6 +101,37 @@ class BackendInterface { DelegateHandle* handle, EValue** args) const = 0; + /** + * Responsible update the backend status, if any. The backend options are + * passed in by users, and the backend can update its internal status based on + * the options. + * + * @param[in] context Runtime context if any. Currently it's not used. + * @param[in] args A list of BackendOptions passed in by users. + * @retval Error::Ok if successful. + */ + ET_NODISCARD virtual Error set_option( + __ET_UNUSED BackendOptionContext& context, + const executorch::runtime::Span& backend_options) { + return Error::Ok; + }; + + /** + * Responsible update the backend status, if any. The backend options are + * passed in by users, and the backend can update its internal status based on + * the options. + * + * @param[in] context Runtime context if any. Currently it's not used. + * @param[in] args A list of BackendOptions passed in by users, that will be + * filled by the backend + * @retval Error::Ok if successful. + */ + ET_NODISCARD virtual Error get_option( + __ET_UNUSED BackendOptionContext& context, + executorch::runtime::Span& backend_options) { + return Error::Ok; + }; + /** * Responsible for destroying a handle, if it's required for some backend. * It may be needed for some backends. For example, resources associated with diff --git a/runtime/backend/test/backend_interface_update_test.cpp b/runtime/backend/test/backend_interface_update_test.cpp new file mode 100644 index 00000000000..27dc284af5e --- /dev/null +++ b/runtime/backend/test/backend_interface_update_test.cpp @@ -0,0 +1,287 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +#include + +using namespace ::testing; +using executorch::runtime::ArrayRef; +using executorch::runtime::Backend; +using executorch::runtime::BackendExecutionContext; +using executorch::runtime::BackendInitContext; +using executorch::runtime::BackendInterface; +using executorch::runtime::BackendOption; +using executorch::runtime::BackendOptionContext; +using executorch::runtime::BackendOptions; +using executorch::runtime::CompileSpec; +using executorch::runtime::DelegateHandle; +using executorch::runtime::Error; +using executorch::runtime::EValue; +using executorch::runtime::FreeableBuffer; +using executorch::runtime::get_backend_class; +using executorch::runtime::MemoryAllocator; +using executorch::runtime::Result; + +class MockBackend : public BackendInterface { + public: + ~MockBackend() override = default; + + bool is_available() const override { + return true; + } + + Result init( + __ET_UNUSED BackendInitContext& context, + __ET_UNUSED FreeableBuffer* processed, + __ET_UNUSED ArrayRef compile_specs) const override { + init_called = true; + return nullptr; + } + + Error execute( + __ET_UNUSED BackendExecutionContext& context, + __ET_UNUSED DelegateHandle* handle, + __ET_UNUSED EValue** args) const override { + execute_count++; + return Error::Ok; + } + + Error set_option( + __ET_UNUSED BackendOptionContext& context, + const executorch::runtime::Span& backend_options) + override { + set_option_count++; + int success_update = 0; + for (const auto& backend_option : backend_options) { + if (strcmp(backend_option.key, "Backend") == 0) { + if (std::holds_alternative>( + backend_option.value)) { + // Store the value in our member variable + const auto& arr = + std::get>(backend_option.value); + target_backend = std::string(arr.data()); + success_update++; + } + } else if (strcmp(backend_option.key, "NumberOfThreads") == 0) { + if (std::holds_alternative(backend_option.value)) { + num_threads = std::get(backend_option.value); + success_update++; + } + } else if (strcmp(backend_option.key, "Debug") == 0) { + if (std::holds_alternative(backend_option.value)) { + debug = std::get(backend_option.value); + success_update++; + } + } + } + if (success_update == backend_options.size()) { + return Error::Ok; + } + return Error::InvalidArgument; + } + + // Mutable allows modification in const methods + mutable std::optional target_backend; + mutable int num_threads = 0; + mutable bool debug = false; + + // State tracking + mutable bool init_called = false; + mutable int execute_count = 0; + mutable int set_option_count = 0; +}; + +class BackendInterfaceUpdateTest : public ::testing::Test { + protected: + void SetUp() override { + // Since these tests cause ET_LOG to be called, the PAL must be initialized + // first. + executorch::runtime::runtime_init(); + mock_backend = std::make_unique(); + // static Error register_success = register_executor_backend(); + } + + std::unique_ptr mock_backend; + BackendOptions<5> options; +}; + +TEST_F(BackendInterfaceUpdateTest, HandlesInvalidOption) { + BackendOptionContext context; + + // Test invalid key case + std::array value_array{"None"}; + BackendOption invalid_option{"InvalidKey", value_array}; + + Error err = mock_backend->set_option(context, invalid_option); + EXPECT_EQ(err, Error::InvalidArgument); +} + +TEST_F(BackendInterfaceUpdateTest, HandlesStringOption) { + BackendOptionContext context; + options.set_option("Backend", "GPU"); + // // Create a backend option to pass to update + + EXPECT_EQ(mock_backend->target_backend, std::nullopt); + + // Test successful update + Error err = mock_backend->set_option(context, options.view()); + EXPECT_EQ(err, Error::Ok); + + EXPECT_EQ(mock_backend->target_backend, "GPU"); +} + +TEST_F(BackendInterfaceUpdateTest, HandlesIntOption) { + // Check the default num_threads value is 0 + EXPECT_EQ(mock_backend->debug, false); + // Create a mock context (needs to be defined or mocked) + BackendOptionContext context; + + int expected_num_threads = 4; + + // Create a backend option to pass to update + options.set_option("NumberOfThreads", expected_num_threads); + + // Test successful update + Error err = mock_backend->set_option(context, options.view()); + EXPECT_EQ(err, Error::Ok); + EXPECT_EQ(mock_backend->num_threads, expected_num_threads); +} + +TEST_F(BackendInterfaceUpdateTest, HandlesBoolOption) { + // Check the default num_threads value is 0 + EXPECT_EQ(mock_backend->debug, false); + // Create a mock context (needs to be defined or mocked) + BackendOptionContext context; + + options.set_option("Debug", true); + + // Test successful update + Error err = mock_backend->set_option(context, options.view()); + EXPECT_EQ(err, Error::Ok); + + EXPECT_EQ(mock_backend->debug, true); +} + +TEST_F(BackendInterfaceUpdateTest, HandlesMultipleOptions) { + // Check the default num_threads value is 0 + EXPECT_EQ(mock_backend->debug, false); + // Create a mock context (needs to be defined or mocked) + BackendOptionContext context; + + options.set_option("Debug", true); + options.set_option("NumberOfThreads", 4); + options.set_option("Backend", "GPU"); + + // Test successful update + Error err = mock_backend->set_option(context, options.view()); + EXPECT_EQ(err, Error::Ok); + + EXPECT_EQ(mock_backend->debug, true); + EXPECT_EQ(mock_backend->num_threads, 4); + EXPECT_EQ(mock_backend->target_backend, "GPU"); +} + +TEST_F(BackendInterfaceUpdateTest, UpdateBeforeInit) { + BackendOptionContext option_context; + MemoryAllocator memory_allocator{MemoryAllocator(0, nullptr)}; + + BackendInitContext init_context(&memory_allocator); + + // Create backend option + options.set_option("Backend", "GPU"); + + // Update before init + Error err = mock_backend->set_option(option_context, options.view()); + EXPECT_EQ(err, Error::Ok); + + // Now call init + FreeableBuffer* processed = nullptr; // Not used in mock + ArrayRef compile_specs; // Empty + auto handle_or_error = + mock_backend->init(init_context, processed, compile_specs); + EXPECT_EQ(handle_or_error.error(), Error::Ok); + + // Verify state + EXPECT_TRUE(mock_backend->init_called); + EXPECT_EQ(mock_backend->set_option_count, 1); + EXPECT_EQ(mock_backend->execute_count, 0); + ASSERT_TRUE(mock_backend->target_backend.has_value()); + EXPECT_STREQ(mock_backend->target_backend.value().c_str(), "GPU"); +} + +TEST_F(BackendInterfaceUpdateTest, UpdateAfterInitBeforeExecute) { + BackendOptionContext option_context; + MemoryAllocator init_memory_allocator{MemoryAllocator(0, nullptr)}; + BackendInitContext init_context(&init_memory_allocator); + BackendExecutionContext execute_context; + + // First call init + FreeableBuffer* processed = nullptr; + ArrayRef compile_specs; + auto handle_or_error = + mock_backend->init(init_context, processed, compile_specs); + EXPECT_TRUE(handle_or_error.ok()); + + // Verify init called but execute not called + EXPECT_TRUE(mock_backend->init_called); + EXPECT_EQ(mock_backend->execute_count, 0); + + // Now update + options.set_option("Backend", "CPU"); + Error err = mock_backend->set_option(option_context, options.view()); + EXPECT_EQ(err, Error::Ok); + + // Now execute + DelegateHandle* handle = handle_or_error.get(); + EValue** args = nullptr; // Not used in mock + err = mock_backend->execute(execute_context, handle, args); + EXPECT_EQ(err, Error::Ok); + + // Verify state + EXPECT_EQ(mock_backend->set_option_count, 1); + EXPECT_EQ(mock_backend->execute_count, 1); + ASSERT_TRUE(mock_backend->target_backend.has_value()); + EXPECT_STREQ(mock_backend->target_backend.value().c_str(), "CPU"); +} + +TEST_F(BackendInterfaceUpdateTest, UpdateBetweenExecutes) { + BackendOptionContext option_context; + MemoryAllocator init_memory_allocator{MemoryAllocator(0, nullptr)}; + BackendInitContext init_context(&init_memory_allocator); + BackendExecutionContext execute_context; + + // Initialize + FreeableBuffer* processed = nullptr; + ArrayRef compile_specs; + auto handle_or_error = + mock_backend->init(init_context, processed, compile_specs); + EXPECT_TRUE(handle_or_error.ok()); + DelegateHandle* handle = handle_or_error.get(); + + // First execute + EValue** args = nullptr; + Error err = mock_backend->execute(execute_context, handle, args); + EXPECT_EQ(err, Error::Ok); + + // Update between executes + options.set_option("Backend", "NPU"); + err = mock_backend->set_option(option_context, options.view()); + EXPECT_EQ(err, Error::Ok); + + // Second execute + err = mock_backend->execute(execute_context, handle, args); + EXPECT_EQ(err, Error::Ok); + + // Verify state + EXPECT_EQ(mock_backend->set_option_count, 1); + EXPECT_EQ(mock_backend->execute_count, 2); + ASSERT_TRUE(mock_backend->target_backend.has_value()); + EXPECT_STREQ(mock_backend->target_backend.value().c_str(), "NPU"); +} diff --git a/runtime/backend/test/targets.bzl b/runtime/backend/test/targets.bzl index 916fa3a3b98..f9e5c1e0de2 100644 --- a/runtime/backend/test/targets.bzl +++ b/runtime/backend/test/targets.bzl @@ -15,3 +15,12 @@ def define_common_targets(): "//executorch/test/utils:utils", ], ) + + runtime.cxx_test( + name = "backend_interface_update_test", + srcs = ["backend_interface_update_test.cpp"], + deps = [ + "//executorch/runtime/core:core", + "//executorch/runtime/backend:interface", + ], + )