diff --git a/backends/qualcomm/CMakeLists.txt b/backends/qualcomm/CMakeLists.txt
index 33f150413a3..f2e40f92caf 100644
--- a/backends/qualcomm/CMakeLists.txt
+++ b/backends/qualcomm/CMakeLists.txt
@@ -116,6 +116,7 @@ add_library(qcir INTERFACE qcir_schema_output)
 add_library(qcir_utils STATIC)
 add_library(qnn_backend STATIC)
 add_library(qnn_backend_cache STATIC)
+add_library(qnn_backend_options STATIC)
 add_library(qnn_context STATIC)
 add_library(qnn_custom_protocol STATIC)
 add_library(qnn_dlc_manager STATIC)
@@ -159,6 +160,7 @@ target_link_libraries(
   qnn_backend PRIVATE qnn_implementation qnn_logger qnn_op_package_manager
 )
 target_link_libraries(qnn_custom_protocol PRIVATE qnn_logger)
+target_link_libraries(qnn_backend_options PRIVATE qnn_schema)
 target_link_libraries(
   qnn_device PRIVATE qnn_executorch_logging qnn_implementation qnn_logger
 )
@@ -197,7 +199,7 @@ target_link_libraries(
 )
 target_link_libraries(
   qnn_executorch_backend PRIVATE qnn_executorch_header qnn_schema qnn_manager
-                                 executorch_core extension_tensor
+                                 executorch_core extension_tensor qnn_backend_options
 )
 set_target_properties(
   qnn_executorch_backend PROPERTIES LINK_FLAGS "-Wl,-rpath='$ORIGIN'"
@@ -261,6 +263,7 @@ if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "x86_64")
       qnn_executorch_header
       executorch
       extension_tensor
+      qnn_backend_options
   )
   target_link_libraries(
     PyQnnWrapperAdaptor PRIVATE pybind11::module pybind11::lto wrappers
diff --git a/backends/qualcomm/_passes/convert_conv1d_to_conv2d.py b/backends/qualcomm/_passes/convert_conv1d_to_conv2d.py
index 1ee71d42bd4..6c29924defa 100644
--- a/backends/qualcomm/_passes/convert_conv1d_to_conv2d.py
+++ b/backends/qualcomm/_passes/convert_conv1d_to_conv2d.py
@@ -105,7 +105,7 @@ def call(self, graph_module: torch.fx.GraphModule):
                 padding = [0] + node.args[4] if num_args > 4 else [0, 0]
                 if node.target == torch.ops.aten.conv1d.default:
                     dilation = [1] + node.args[5] if num_args > 5 else [1, 1]
-                    groups = node.args[6] if num_args > 5 else 1
+                    groups = node.args[6] if num_args > 6 else 1
                     conv_args = (
                         qdq_node_after_unsqueeze,
                         node.args[1],
diff --git a/backends/qualcomm/runtime/CMakeLists.txt b/backends/qualcomm/runtime/CMakeLists.txt
index eb31bee7a53..1a35ec8366f 100644
--- a/backends/qualcomm/runtime/CMakeLists.txt
+++ b/backends/qualcomm/runtime/CMakeLists.txt
@@ -28,6 +28,13 @@ target_sources(
   PRIVATE ${CMAKE_CURRENT_LIST_DIR}/QnnManager.cpp
 )
 
+# qnn_backend_options
+target_sources(
+  qnn_backend_options
+  INTERFACE ${CMAKE_CURRENT_LIST_DIR}/QnnBackendOptions.h
+  PRIVATE ${CMAKE_CURRENT_LIST_DIR}/QnnBackendOptions.cpp
+)
+
 # logging
 target_sources(
   qnn_executorch_logging
diff --git a/backends/qualcomm/runtime/QnnBackendOptions.cpp b/backends/qualcomm/runtime/QnnBackendOptions.cpp
new file mode 100644
index 00000000000..17e9975008d
--- /dev/null
+++ b/backends/qualcomm/runtime/QnnBackendOptions.cpp
@@ -0,0 +1,50 @@
+/*
+ * Copyright (c) Qualcomm Innovation Center, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include <executorch/backends/qualcomm/runtime/QnnBackendOptions.h>
+#include <executorch/runtime/backend/interface.h>
+#include <type_traits>
+
+namespace executorch {
+namespace backends {
+namespace qnn {
+
+using namespace qnn_delegate;
+
+template <typename T>
+T get_option(T aot_option) {
+  executorch::runtime::Error status;
+  executorch::runtime::BackendOption backend_option;
+
+  if constexpr (std::is_same_v<T, QnnExecuTorchLogLevel>) {
+    backend_option = {QNN_RUNTIME_LOG_LEVEL, -1};
+  } else if constexpr (std::is_same_v<T, QnnExecuTorchHtpPerformanceMode>) {
+    backend_option = {QNN_RUNTIME_HTP_PERFORMANCE_MODE, -1};
+  } else if constexpr (std::is_same_v<T, QnnExecuTorchProfileLevel>) {
+    backend_option = {QNN_RUNTIME_PROFILE_LEVEL, -1};
+  }
+  // This calls get_option on the runtime backend interface.
+  status = get_option(QNN_BACKEND, backend_option);
+
+  if (status != executorch::runtime::Error::Ok) {
+    return aot_option;
+  } else {
+    return static_cast<T>(std::get<int>(backend_option.value));
+  }
+}
+
+// Explicit instantiations
+template QnnExecuTorchLogLevel get_option<QnnExecuTorchLogLevel>(
+    QnnExecuTorchLogLevel);
+template QnnExecuTorchHtpPerformanceMode get_option<
+    QnnExecuTorchHtpPerformanceMode>(QnnExecuTorchHtpPerformanceMode);
+template QnnExecuTorchProfileLevel get_option<QnnExecuTorchProfileLevel>(
+    QnnExecuTorchProfileLevel);
+
+} // namespace qnn
+} // namespace backends
+} // namespace executorch
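Reviewer note, not part of the patch: the template above makes the runtime value an override and the AOT flatbuffer value the fallback. A minimal sketch of the resulting precedence, with the numeric-to-enum mapping taken from the tests below and the surrounding code hypothetical:

```cpp
// Assume the .pte was exported with log_level = kLogLevelWarn.
QnnExecuTorchLogLevel aot_level = QnnExecuTorchLogLevel::kLogLevelWarn;

// Case 1: the application never called set_option(). The runtime lookup
// fails and get_option() falls back to the AOT value, so `effective`
// is kLogLevelWarn.
// Case 2: the application earlier stored {QNN_RUNTIME_LOG_LEVEL, 4} via
// set_option(). The lookup succeeds and the int 4 is cast back to the
// enum, so `effective` is kLogLevelVerbose and the AOT value is ignored.
QnnExecuTorchLogLevel effective = get_option(aot_level);
```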
diff --git a/backends/qualcomm/runtime/QnnBackendOptions.h b/backends/qualcomm/runtime/QnnBackendOptions.h
new file mode 100644
index 00000000000..a601a4202c0
--- /dev/null
+++ b/backends/qualcomm/runtime/QnnBackendOptions.h
@@ -0,0 +1,41 @@
+/*
+ * Copyright (c) Qualcomm Innovation Center, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#pragma once
+
+#include <executorch/backends/qualcomm/qc_compiler_spec_generated.h>
+#include <executorch/backends/qualcomm/runtime/QnnExecuTorch.h>
+#include <executorch/runtime/backend/options.h>
+
+namespace executorch {
+namespace backends {
+namespace qnn {
+
+/**
+ * @brief Stores a runtime option value.
+ * @param is_set True once the user has set the option via the set_option
+ * API, false otherwise.
+ */
+struct RuntimeOption {
+  bool is_set;
+  executorch::runtime::OptionValue value;
+};
+
+/**
+ * @brief
+ * Get the backend option.
+ * This method checks both the AOT option and the runtime option.
+ * If a runtime option has been set, it takes priority.
+ *
+ * @param aot_option The AOT option from the flatbuffer defined in
+ * qc_compiler_spec.fbs.
+ */
+template <typename T>
+T get_option(T aot_option);
+
+} // namespace qnn
+} // namespace backends
+} // namespace executorch
diff --git a/backends/qualcomm/runtime/QnnExecuTorch.h b/backends/qualcomm/runtime/QnnExecuTorch.h
index 2ca0cd61cd5..889ac516a36 100644
--- a/backends/qualcomm/runtime/QnnExecuTorch.h
+++ b/backends/qualcomm/runtime/QnnExecuTorch.h
@@ -16,6 +16,11 @@
 #include <...>
 #endif
 
+#define QNN_BACKEND "QnnBackend"
+#define QNN_RUNTIME_LOG_LEVEL "qnn_runtime_log_level"
+#define QNN_RUNTIME_HTP_PERFORMANCE_MODE "qnn_runtime_htp_performance_mode"
+#define QNN_RUNTIME_PROFILE_LEVEL "qnn_runtime_profile_level"
+
 #ifdef __cplusplus
 extern "C" {
 #endif // __cplusplus
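These macros are the keys an application passes to the generic backend-option API. A minimal usage sketch; the helper function is hypothetical, the header providing `BackendOptions` is assumed, and the calls and numeric values mirror the executor-runner and test changes later in this patch:

```cpp
#include <executorch/backends/qualcomm/runtime/QnnExecuTorch.h>
#include <executorch/runtime/backend/interface.h> // assumed to provide BackendOptions/set_option

using executorch::runtime::BackendOptions;
using executorch::runtime::Error;

// Hypothetical helper: apply QNN runtime overrides before loading the method.
Error apply_qnn_runtime_overrides() {
  BackendOptions<3> opts;
  opts.set_option(QNN_RUNTIME_LOG_LEVEL, 4);            // kLogLevelVerbose
  opts.set_option(QNN_RUNTIME_HTP_PERFORMANCE_MODE, 2); // kHtpBurst
  opts.set_option(QNN_RUNTIME_PROFILE_LEVEL, 2);        // kProfileDetailed
  // Routed to QnnExecuTorchBackend::set_option() through the backend registry.
  return executorch::runtime::set_option(QNN_BACKEND, opts.view());
}
```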
diff --git a/backends/qualcomm/runtime/QnnExecuTorchBackend.cpp b/backends/qualcomm/runtime/QnnExecuTorchBackend.cpp
index 01bf13603d6..b905f9e46c3 100644
--- a/backends/qualcomm/runtime/QnnExecuTorchBackend.cpp
+++ b/backends/qualcomm/runtime/QnnExecuTorchBackend.cpp
@@ -8,10 +8,12 @@
 #include <...>
 #include <...>
+#include <executorch/backends/qualcomm/runtime/QnnBackendOptions.h>
 #include <...>
 #include <...>
 #include <...>
-
+#include <cstring>
+#include <variant>
 namespace executorch {
 namespace backends {
 namespace qnn {
@@ -189,6 +191,77 @@ void QnnExecuTorchBackend::destroy(DelegateHandle* handle) const {
   }
 }
 
+executorch::runtime::Error QnnExecuTorchBackend::set_option(
+    executorch::runtime::BackendOptionContext& context,
+    const executorch::runtime::Span<executorch::runtime::BackendOption>&
+        backend_options) {
+  std::lock_guard<std::mutex> guard(runtime_option_mutex_);
+  size_t matches = backend_options.size();
+  for (const auto& option : backend_options) {
+    if (strcmp(option.key, QNN_RUNTIME_LOG_LEVEL) == 0) {
+      if (auto* val = std::get_if<int>(&option.value)) {
+        qnn_runtime_log_level_.value = *val;
+        qnn_runtime_log_level_.is_set = true;
+      }
+    } else if (strcmp(option.key, QNN_RUNTIME_HTP_PERFORMANCE_MODE) == 0) {
+      if (auto* val = std::get_if<int>(&option.value)) {
+        qnn_runtime_performance_mode_.value = *val;
+        qnn_runtime_performance_mode_.is_set = true;
+      }
+    } else if (strcmp(option.key, QNN_RUNTIME_PROFILE_LEVEL) == 0) {
+      if (auto* val = std::get_if<int>(&option.value)) {
+        qnn_runtime_profile_level_.value = *val;
+        qnn_runtime_profile_level_.is_set = true;
+      }
+    } else {
+      ET_LOG(
+          Error,
+          "Unable to set the following runtime option for QnnExecuTorchBackend: %s.",
+          option.key);
+      matches--;
+    }
+  }
+
+  ET_CHECK_OR_RETURN_ERROR(
+      matches == backend_options.size(),
+      Internal,
+      "Some set options are not supported by QnnExecuTorchBackend. "
+      "%zu options provided but only %zu are supported.",
+      backend_options.size(),
+      matches);
+
+  return Error::Ok;
+}
+
+executorch::runtime::Error QnnExecuTorchBackend::get_option(
+    executorch::runtime::BackendOptionContext& context,
+    executorch::runtime::Span<executorch::runtime::BackendOption>&
+        backend_options) {
+  size_t matches = backend_options.size();
+  for (size_t i = 0; i < backend_options.size(); ++i) {
+    // Set the value to what was stored by set_option.
+    if (strcmp(backend_options[i].key, QNN_RUNTIME_LOG_LEVEL) == 0 &&
+        qnn_runtime_log_level_.is_set) {
+      backend_options[i].value = qnn_runtime_log_level_.value;
+    } else if (
+        strcmp(backend_options[i].key, QNN_RUNTIME_HTP_PERFORMANCE_MODE) == 0 &&
+        qnn_runtime_performance_mode_.is_set) {
+      backend_options[i].value = qnn_runtime_performance_mode_.value;
+    } else if (
+        strcmp(backend_options[i].key, QNN_RUNTIME_PROFILE_LEVEL) == 0 &&
+        qnn_runtime_profile_level_.is_set) {
+      backend_options[i].value = qnn_runtime_profile_level_.value;
+    } else {
+      // Either the runtime never called set_option, or the key does not exist.
+      matches--;
+    }
+  }
+
+  if (matches != backend_options.size()) {
+    return Error::Internal;
+  }
+  return Error::Ok;
+}
+
 bool QnnExecuTorchBackend::is_available() const {
   return true;
 }
@@ -214,7 +287,7 @@ void QnnExecuTorchBackend::erase_cached_delegate(
 
 namespace {
 auto cls = QnnExecuTorchBackend();
-executorch::runtime::Backend backend{"QnnBackend", &cls};
+executorch::runtime::Backend backend{QNN_BACKEND, &cls};
 static auto success_with_compiler = register_backend(backend);
 } // namespace
 } // namespace qnn
diff --git a/backends/qualcomm/runtime/QnnExecuTorchBackend.h b/backends/qualcomm/runtime/QnnExecuTorchBackend.h
index e83ec6b13b0..f25230045a6 100644
--- a/backends/qualcomm/runtime/QnnExecuTorchBackend.h
+++ b/backends/qualcomm/runtime/QnnExecuTorchBackend.h
@@ -7,6 +7,7 @@
  */
 #pragma once
 
+#include <executorch/backends/qualcomm/runtime/QnnBackendOptions.h>
 #include <...>
 #include <...>
 #include <...>
@@ -34,6 +35,16 @@ class QnnExecuTorchBackend final
       executorch::runtime::DelegateHandle* handle,
       executorch::runtime::EValue** args) const override;
 
+  ET_NODISCARD executorch::runtime::Error set_option(
+      executorch::runtime::BackendOptionContext& context,
+      const executorch::runtime::Span<executorch::runtime::BackendOption>&
+          backend_options) override;
+
+  executorch::runtime::Error get_option(
+      executorch::runtime::BackendOptionContext& context,
+      executorch::runtime::Span<executorch::runtime::BackendOption>&
+          backend_options) override;
+
   void destroy(executorch::runtime::DelegateHandle* handle) const override;
 
   bool is_available() const override;
@@ -45,10 +56,15 @@ class QnnExecuTorchBackend final
   void erase_cached_delegate(executorch::runtime::DelegateHandle* handle) const;
 
   mutable std::mutex mutex_;
+  mutable std::mutex runtime_option_mutex_;
   mutable std::unordered_map<std::string, executorch::runtime::DelegateHandle*>
       delegate_map_;
   mutable std::unordered_map<executorch::runtime::DelegateHandle*, std::string>
       delegate_map_rev_;
+
+  RuntimeOption qnn_runtime_log_level_{false, 0};
+  RuntimeOption qnn_runtime_performance_mode_{false, 0};
+  RuntimeOption qnn_runtime_profile_level_{false, 0};
 };
 } // namespace qnn
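Sketch of the read side (reviewer illustration, function name hypothetical): `get_option` only reports values that an earlier `set_option` stored; otherwise it returns `Error::Internal`, which is exactly the failure the AOT fallback in `QnnBackendOptions.cpp` relies on. This mirrors the single-option call used there:

```cpp
#include <executorch/backends/qualcomm/runtime/QnnExecuTorch.h>
#include <executorch/runtime/backend/interface.h> // assumed to provide get_option

#include <variant>

// Returns the runtime-set QNN log level, or `fallback` if none was stored.
int query_qnn_log_level(int fallback) {
  executorch::runtime::BackendOption option{QNN_RUNTIME_LOG_LEVEL, -1};
  if (executorch::runtime::get_option(QNN_BACKEND, option) !=
      executorch::runtime::Error::Ok) {
    return fallback; // set_option() was never called for this key
  }
  return std::get<int>(option.value);
}
```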
diff --git a/backends/qualcomm/runtime/QnnManager.cpp b/backends/qualcomm/runtime/QnnManager.cpp
index 0dd0470a2b0..be9e5fcd58f 100644
--- a/backends/qualcomm/runtime/QnnManager.cpp
+++ b/backends/qualcomm/runtime/QnnManager.cpp
@@ -6,6 +6,7 @@
  * LICENSE file in the root directory of this source tree.
  */
+#include <executorch/backends/qualcomm/runtime/QnnBackendOptions.h>
 #include <...>
 #include <...>
 #include <...>
@@ -63,7 +64,8 @@ QnnManager::QnnManager(
       options->backend_options()->backend_type();
   std::string library_path = options->library_path()->str();
 
-  if (options->log_level() >= QnnExecuTorchLogLevel::kLogLevelInfo) {
+  if (get_option(options_->log_level()) >=
+      QnnExecuTorchLogLevel::kLogLevelInfo) {
     QNN_EXECUTORCH_LOG_INFO(
         "soc_model in soc_info: %s",
         EnumNameQcomChipset(options_->soc_info()->soc_model()));
@@ -75,10 +77,12 @@
     QNN_EXECUTORCH_LOG_INFO("library_path: %s", library_path.c_str());
     QNN_EXECUTORCH_LOG_INFO("dump intermediate outputs: %s", IsTensorDump());
     QNN_EXECUTORCH_LOG_INFO(
-        "log_level: %s", EnumNameQnnExecuTorchLogLevel(options_->log_level()));
+        "log_level: %s",
+        EnumNameQnnExecuTorchLogLevel(get_option(options_->log_level())));
     QNN_EXECUTORCH_LOG_INFO(
         "profile_level: %s",
-        EnumNameQnnExecuTorchProfileLevel(options_->profile_level()));
+        EnumNameQnnExecuTorchProfileLevel(
+            get_option(options_->profile_level())));
     QNN_EXECUTORCH_LOG_INFO(
         "the size of qnn context binary: %d",
         qnn_executorch_context_binary.nbytes);
@@ -202,7 +206,8 @@ Error QnnManager::RegisterIonMem(
     return Error::Internal;
   } else if (backend_params_ptr_->qnn_mem_manager_ptr_->IsRegistered(
                  tensor_wrapper->GetMemHandle(), data_ptr)) {
-    if (options_->log_level() >= QnnExecuTorchLogLevel::kLogLevelInfo)
+    if (get_option(options_->log_level()) >=
+        QnnExecuTorchLogLevel::kLogLevelInfo)
       QNN_EXECUTORCH_LOG_INFO(
           "Tensor name %s has been registered shared memory.",
           tensor_wrapper->GetName().c_str());
@@ -231,7 +236,8 @@ Error QnnManager::RegisterCustomMem(
     const std::shared_ptr<TensorWrapper>& tensor_wrapper) {
   if (backend_params_ptr_->qnn_mem_manager_ptr_->IsRegistered(
           tensor_wrapper->GetMemHandle(), data_ptr)) {
-    if (options_->log_level() >= QnnExecuTorchLogLevel::kLogLevelInfo)
+    if (get_option(options_->log_level()) >=
+        QnnExecuTorchLogLevel::kLogLevelInfo)
       QNN_EXECUTORCH_LOG_INFO(
           "Tensor name %s has been registered shared memory.",
           tensor_wrapper->GetName().c_str());
@@ -251,7 +257,8 @@
   Qnn_MemHandle_t pre_registered_handle =
       backend_params_ptr_->qnn_mem_manager_ptr_->GetPreRegisteredHandle(info);
   if (pre_registered_handle != nullptr) {
-    if (options_->log_level() >= QnnExecuTorchLogLevel::kLogLevelInfo) {
+    if (get_option(options_->log_level()) >=
+        QnnExecuTorchLogLevel::kLogLevelInfo) {
       QNN_EXECUTORCH_LOG_INFO(
           "Tensor name %s found a pre-registered memHandle.",
           tensor_wrapper->GetName().c_str());
@@ -295,7 +302,7 @@ Error QnnManager::Init() {
   ET_CHECK_OR_RETURN_ERROR(
       LoadQnnLibrary() == Error::Ok, Internal, "Fail to load Qnn library");
   logger_ = std::make_unique<QnnLogger>(
-      qnn_loaded_backend_, LoggingCallback, options_->log_level());
+      qnn_loaded_backend_, LoggingCallback, get_option(options_->log_level()));
   std::vector<std::string> graph_names;
   for (auto name : *options_->graph_name()) {
     graph_names.emplace_back(name->str());
@@ -492,7 +499,8 @@ Error QnnManager::ProfileExecuteData(
     const std::string& graph_name,
     executorch::runtime::EventTracer* event_tracer) {
   Qnn_ErrorHandle_t error = QNN_SUCCESS;
-  if (options_->profile_level() != QnnExecuTorchProfileLevel::kProfileOff) {
+  if (get_option(options_->profile_level()) !=
+      QnnExecuTorchProfileLevel::kProfileOff) {
     error = backend_params_ptr_->qnn_graph_ptr_->ProfileExecuteData(
         graph_name, event_tracer);
     if (error != QNN_SUCCESS) {
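The pattern in the hunks above is deliberate: call sites re-resolve the option through `get_option()` at every use instead of caching the AOT value once, so an override is observed the next time the option is consulted. Note that some reads still happen at init/load time (logger creation, profiler setup), which is presumably why the executor runner below applies overrides before loading the program. Condensed sketch:

```cpp
// Before: the AOT value was read directly and could never change.
//   if (options_->profile_level() != QnnExecuTorchProfileLevel::kProfileOff) ...
//
// After: every read goes through the override-aware helper.
if (get_option(options_->profile_level()) !=
    QnnExecuTorchProfileLevel::kProfileOff) {
  // Taken when either the AOT spec or a runtime set_option() enabled profiling.
}
```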
diff --git a/backends/qualcomm/runtime/backends/QnnBackendFactory.cpp b/backends/qualcomm/runtime/backends/QnnBackendFactory.cpp
index 2fbb2243d8d..e7e9db6fed8 100644
--- a/backends/qualcomm/runtime/backends/QnnBackendFactory.cpp
+++ b/backends/qualcomm/runtime/backends/QnnBackendFactory.cpp
@@ -6,6 +6,7 @@
  * LICENSE file in the root directory of this source tree.
 */
 #include <...>
+#include <executorch/backends/qualcomm/runtime/QnnBackendOptions.h>
 #include <...>
 #include <...>
 namespace executorch {
@@ -30,7 +31,8 @@ std::unique_ptr<BackendConfigParameters> QnnBackendFactory::Create(
       if (!skel_library_dir.empty()) {
         setenv("ADSP_LIBRARY_PATH", skel_library_dir.c_str(), /*overwrite=*/1);
       }
-      if (options->log_level() >= QnnExecuTorchLogLevel::kLogLevelInfo) {
+      if (get_option(options->log_level()) >=
+          QnnExecuTorchLogLevel::kLogLevelInfo) {
         QNN_EXECUTORCH_LOG_INFO(
             "skel_library_dir: %s", skel_library_dir.c_str());
         QNN_EXECUTORCH_LOG_INFO(
@@ -42,7 +44,7 @@
         QNN_EXECUTORCH_LOG_INFO(
             "performance_mode in htp_options: %s",
             EnumNameQnnExecuTorchHtpPerformanceMode(
-                htp_options->performance_mode()));
+                get_option(htp_options->performance_mode())));
         QNN_EXECUTORCH_LOG_INFO(
             "precision in htp_options: %s",
             EnumNameQnnExecuTorchHtpPrecision(htp_options->precision()));
@@ -75,13 +77,13 @@
             implementation,
             backend_params->qnn_backend_ptr_.get(),
             backend_params->qnn_context_ptr_.get(),
-            options->profile_level(),
+            get_option(options->profile_level()),
             options->soc_info(),
             htp_options);
         backend_params->qnn_mem_manager_ptr_ = std::make_unique<QnnMemManager>(
             implementation,
             backend_params->qnn_context_ptr_.get(),
-            options->log_level());
+            get_option(options->log_level()));
         backend_params->backend_init_state_ =
             BackendInitializeState::INITIALIZED;
       } break;
     case QnnExecuTorchBackendType::kGpuBackend:
diff --git a/backends/qualcomm/runtime/backends/htpbackend/HtpDevice.cpp b/backends/qualcomm/runtime/backends/htpbackend/HtpDevice.cpp
index 46ba3117269..35a20048fc5 100644
--- a/backends/qualcomm/runtime/backends/htpbackend/HtpDevice.cpp
+++ b/backends/qualcomm/runtime/backends/htpbackend/HtpDevice.cpp
@@ -396,11 +396,10 @@ Error HtpDevice::AfterCreateDevice() {
         QNN_GET_ERROR_CODE(error));
     return Error::Internal;
   }
-  // Set vector of PowerConfigs and map it to a vector of pointers.
   perf_power_configs_ = SetVotePowerConfig(
       powerconfig_client_id_,
-      htp_options_->performance_mode(),
+      get_option(htp_options_->performance_mode()),
       PerformanceModeVoteType::kUpVote);
   perf_power_configs_ptr_ = ObtainNullTermPtrVector(perf_power_configs_);
@@ -416,7 +415,7 @@
 
   // Set Rpc polling mode
   rpc_power_configs_ =
-      SetRpcPollingPowerConfig(htp_options_->performance_mode());
+      SetRpcPollingPowerConfig(get_option(htp_options_->performance_mode()));
   rpc_power_configs_ptr_ = ObtainNullTermPtrVector(rpc_power_configs_);
 
   htp_perf_infra_->setPowerConfig(
diff --git a/backends/qualcomm/runtime/backends/htpbackend/HtpDevice.h b/backends/qualcomm/runtime/backends/htpbackend/HtpDevice.h
index f75e15fc77c..9052deb6b52 100644
--- a/backends/qualcomm/runtime/backends/htpbackend/HtpDevice.h
+++ b/backends/qualcomm/runtime/backends/htpbackend/HtpDevice.h
@@ -7,6 +7,7 @@
  */
 #pragma once
 
+#include <executorch/backends/qualcomm/runtime/QnnBackendOptions.h>
 #include <...>
 #include <...>
 #include <...>
@@ -55,7 +56,7 @@ class HtpDevice : public QnnDevice {
   void ReleasePerformanceVote();
 
   inline bool IsPerfModeEnabled() {
-    return htp_options_->performance_mode() !=
+    return get_option(htp_options_->performance_mode()) !=
         QnnExecuTorchHtpPerformanceMode::kHtpDefault;
   }
diff --git a/backends/qualcomm/runtime/backends/irbackend/x86_64/QnnDlcManager.cpp b/backends/qualcomm/runtime/backends/irbackend/x86_64/QnnDlcManager.cpp
index 050a679e62a..280751cf160 100644
--- a/backends/qualcomm/runtime/backends/irbackend/x86_64/QnnDlcManager.cpp
+++ b/backends/qualcomm/runtime/backends/irbackend/x86_64/QnnDlcManager.cpp
@@ -5,6 +5,7 @@
  * This source code is licensed under the BSD-style license found in the
  * LICENSE file in the root directory of this source tree.
  */
+#include <executorch/backends/qualcomm/runtime/QnnBackendOptions.h>
 #include <...>
 #include <...>
 
@@ -51,7 +52,7 @@ Error QnnDlcManager::Create() {
       qnn_loaded_backend_,
       backend_params_ptr_->qnn_backend_ptr_.get(),
       backend_params_ptr_->qnn_context_ptr_.get(),
-      options_->profile_level());
+      get_option(options_->profile_level()));
   backend_params_ptr_->backend_init_state_ =
       BackendInitializeState::INITIALIZED;
   return backend_params_ptr_->qnn_backend_ptr_->VerifyQNNSDKVersion();
@@ -105,7 +106,7 @@ Error QnnDlcManager::SetUpDlcEnvironment(const Qnn_Version_t& coreApiVersion) {
       "Fail to Load Qnn IR library.");
 
   logger_ = std::make_unique<QnnLogger>(
-      qnn_loaded_backend_, LoggingCallback, options_->log_level());
+      qnn_loaded_backend_, LoggingCallback, get_option(options_->log_level()));
 
   ET_CHECK_OR_RETURN_ERROR(
       Create() == Error::Ok, Internal, "Failed to load Qnn IR backend.");
diff --git a/backends/qualcomm/runtime/targets.bzl b/backends/qualcomm/runtime/targets.bzl
index 1bd82f8f913..6837bece6eb 100644
--- a/backends/qualcomm/runtime/targets.bzl
+++ b/backends/qualcomm/runtime/targets.bzl
@@ -75,11 +75,11 @@ def define_common_targets():
             "//executorch/backends/qualcomm:schema",
             "//executorch/backends/qualcomm/aot/ir:qcir_utils",
             "//executorch/backends/qualcomm/aot/wrappers:wrappers",
-            "//executorch/runtime/backend:interface",
             "//executorch/runtime/core:core",
             "//executorch/extension/tensor:tensor",
         ],
         exported_deps = [
+            "//executorch/runtime/backend:interface",
            "//executorch/runtime/core/exec_aten/util:scalar_type_util",
            "//executorch/runtime/core:event_tracer",
         ],
diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py
index d4eb3e4eac3..4ee343c19e9 100644
--- a/backends/qualcomm/tests/test_qnn_delegate.py
+++ b/backends/qualcomm/tests/test_qnn_delegate.py
@@ -9,6 +9,7 @@
 import sys
 import tempfile
 import unittest
+from functools import partial
 from multiprocessing.connection import Listener
 from pathlib import Path
@@ -3054,6 +3055,104 @@ def test_qnn_backend_profile_op(self):
             expected_profile_events=30,
         )
 
+    def test_qnn_backend_runtime_option_htp_performance(self):
+        backend_options = generate_htp_compiler_spec(use_fp16=True)
+        TestQNN.compiler_specs = generate_qnn_executorch_compiler_spec(
+            soc_model=self.chipset_table[TestQNN.model],
+            backend_options=backend_options,
+        )
+        module = SimpleModel()  # noqa: F405
+        sample_input = (torch.ones(1, 32, 28, 28), torch.ones(1, 32, 28, 28))
+
+        def output_callback(log_msg, is_burst):
+            msg = log_msg.stdout
+            # Refer to HtpDevice.cpp for the following values.
+            min_voltage = (
+                "coreVoltageCornerMin 160" if is_burst else "coreVoltageCornerMin 80"
+            )
+            self.assertTrue(min_voltage in msg, f"Expecting '{min_voltage}' in log")
+
+        burst_runtime_commands = (
+            " --htp_performance_mode 2 --log_level 4"  # kHtpBurst, kLogLevelVerbose
+        )
+        self.lower_module_and_test_output(
+            module,
+            sample_input,
+            extra_cmds=burst_runtime_commands,
+            output_callback=partial(output_callback, is_burst=True),
+            save_inference_speed=True,
+        )
+        burst_speed = 1000 / self.inference_speed  # inferences per second
+
+        power_saver_runtime_commands = " --htp_performance_mode 6 --log_level 4"  # kHtpHighPowerSaver, kLogLevelVerbose
+        self.lower_module_and_test_output(
+            module,
+            sample_input,
+            extra_cmds=power_saver_runtime_commands,
+            output_callback=partial(output_callback, is_burst=False),
+            save_inference_speed=True,
+        )
+        power_saver_speed = 1000 / self.inference_speed  # inferences per second
+
+        # Only need to ensure device burst is faster than high power saver.
+        if not self.enable_x86_64:
+            self.assertGreater(
+                burst_speed,
+                power_saver_speed,
+                f"Burst mode should be faster than high power saver mode. Burst: {burst_speed} inferences/second, High Power Saver: {power_saver_speed} inferences/second.",
+            )
+
+    def test_qnn_backend_runtime_option_log(self):
+        backend_options = generate_htp_compiler_spec(use_fp16=True)
+        TestQNN.compiler_specs = generate_qnn_executorch_compiler_spec(
+            soc_model=self.chipset_table[TestQNN.model],
+            backend_options=backend_options,
+        )
+        module = SimpleModel()  # noqa: F405
+        sample_input = (torch.ones(1, 32, 28, 28), torch.ones(1, 32, 28, 28))
+        runtime_commands = " --log_level 4"  # kLogLevelVerbose
+
+        def output_callback(log_msg):
+            msg = log_msg.stdout
+            # Check the log prefix; different QNN versions have slightly
+            # different message formats.
+            self.assertTrue(
+                any(
+                    sub in msg
+                    for sub in [
+                        "[Qnn ExecuTorch]: QnnDsp ",
+                        "[Qnn ExecuTorch]: ",
+                    ]
+                ),
+                "Expecting Verbose message in log",
+            )
+
+        self.lower_module_and_test_output(
+            module,
+            sample_input,
+            extra_cmds=runtime_commands,
+            output_callback=output_callback,
+        )
+
+    def test_qnn_backend_runtime_option_profile(self):
+        TestQNN.enable_profile = True
+        backend_options = generate_htp_compiler_spec(use_fp16=True)
+        TestQNN.compiler_specs = generate_qnn_executorch_compiler_spec(
+            soc_model=self.chipset_table[TestQNN.model],
+            backend_options=backend_options,
+            profile=False,  # Turned on via the runtime command below
+        )
+        module = SimpleModel()  # noqa: F405
+        sample_input = (torch.ones(1, 32, 28, 28), torch.ones(1, 32, 28, 28))
+        runtime_commands = " --profile_level 2"  # kProfileDetailed
+        # With the same model, expected_profile_events for this UT should
+        # match test_qnn_backend_profile_op.
+        self.lower_module_and_test_output(
+            module,
+            sample_input,
+            expected_partitions=1,
+            expected_profile_events=30,
+            extra_cmds=runtime_commands,
+        )
+
     def test_qnn_backend_shared_buffer(self):
         TestQNN.shared_buffer = True
         backend_options = generate_htp_compiler_spec(
@@ -3774,6 +3873,107 @@ def test_qnn_backend_profile_op(self):
             expected_profile_events=30,
         )
 
+    def test_qnn_backend_runtime_option_htp_performance(self):
+        backend_options = generate_htp_compiler_spec(use_fp16=False)
+        TestQNN.compiler_specs = generate_qnn_executorch_compiler_spec(
+            soc_model=self.chipset_table[TestQNN.model],
+            backend_options=backend_options,
+        )
+        module = SimpleModel()  # noqa: F405
+        sample_input = (torch.ones(1, 32, 28, 28), torch.ones(1, 32, 28, 28))
+        module = self.get_qdq_module(module, sample_input)
+
+        def output_callback(log_msg, is_burst):
+            msg = log_msg.stdout
+            # Refer to HtpDevice.cpp for the following values.
+            min_voltage = (
+                "coreVoltageCornerMin 160" if is_burst else "coreVoltageCornerMin 80"
+            )
+            self.assertTrue(min_voltage in msg, f"Expecting '{min_voltage}' in log")
+
+        burst_runtime_commands = (
+            " --htp_performance_mode 2 --log_level 4"  # kHtpBurst, kLogLevelVerbose
+        )
+        self.lower_module_and_test_output(
+            module,
+            sample_input,
+            extra_cmds=burst_runtime_commands,
+            output_callback=partial(output_callback, is_burst=True),
+            save_inference_speed=True,
+        )
+        burst_speed = 1000 / self.inference_speed  # inferences per second
+
+        power_saver_runtime_commands = " --htp_performance_mode 6 --log_level 4"  # kHtpHighPowerSaver, kLogLevelVerbose
+        self.lower_module_and_test_output(
+            module,
+            sample_input,
+            extra_cmds=power_saver_runtime_commands,
+            output_callback=partial(output_callback, is_burst=False),
+            save_inference_speed=True,
+        )
+        power_saver_speed = 1000 / self.inference_speed  # inferences per second
+
+        # Only need to ensure device burst is faster than high power saver.
+        if not self.enable_x86_64:
+            self.assertGreater(
+                burst_speed,
+                power_saver_speed,
+                f"Burst mode should be faster than high power saver mode. Burst: {burst_speed} inferences/second, High Power Saver: {power_saver_speed} inferences/second.",
+            )
+
+    def test_qnn_backend_runtime_option_log(self):
+        backend_options = generate_htp_compiler_spec(use_fp16=False)
+        TestQNN.compiler_specs = generate_qnn_executorch_compiler_spec(
+            soc_model=self.chipset_table[TestQNN.model],
+            backend_options=backend_options,
+        )
+        module = SimpleModel()  # noqa: F405
+        sample_input = (torch.ones(1, 32, 28, 28), torch.ones(1, 32, 28, 28))
+        module = self.get_qdq_module(module, sample_input)
+        runtime_commands = " --log_level 4"  # kLogLevelVerbose
--log_level 4" # kLogLevelVerbose + + def output_callback(log_msg): + msg = log_msg.stdout + # Check log prefix, different QNN version will have slightly different message format. + self.assertTrue( + any( + sub in msg + for sub in [ + "[Qnn ExecuTorch]: QnnDsp ", + "[Qnn ExecuTorch]: ", + ] + ), + "Expecting Verbose message in log", + ) + + self.lower_module_and_test_output( + module, + sample_input, + extra_cmds=runtime_commands, + output_callback=output_callback, + ) + + def test_qnn_backend_runtime_option_profile(self): + TestQNN.enable_profile = True + backend_options = generate_htp_compiler_spec(use_fp16=False) + TestQNN.compiler_specs = generate_qnn_executorch_compiler_spec( + soc_model=self.chipset_table[TestQNN.model], + backend_options=backend_options, + profile=False, # Turn on using runtime command + ) + module = SimpleModel() # noqa: F405 + sample_input = (torch.ones(1, 32, 28, 28), torch.ones(1, 32, 28, 28)) + module = self.get_qdq_module(module, sample_input) + runtime_commands = " --profile_level 2" # kProfileDetailed + # With same model, expected_profile events for this UT should match test_qnn_backend_profile_op + self.lower_module_and_test_output( + module, + sample_input, + expected_partitions=1, + expected_profile_events=30, + extra_cmds=runtime_commands, + ) + def test_qnn_backend_shared_buffer(self): TestQNN.shared_buffer = True backend_options = generate_htp_compiler_spec( diff --git a/backends/qualcomm/tests/utils.py b/backends/qualcomm/tests/utils.py index fd2d10e2b93..43c521130a2 100644 --- a/backends/qualcomm/tests/utils.py +++ b/backends/qualcomm/tests/utils.py @@ -198,6 +198,8 @@ class TestQNN(unittest.TestCase): pre_gen_pte: str = "" llama_artifacts: str = "" dump_intermediate_outputs: bool = False + inference_speed: float = 0.0 + inference_speed_output_path = "outputs/inference_speed.txt" def _assert_outputs_equal(self, model_output, ref_output): self.assertTrue(len(ref_output) == len(model_output)) @@ -264,6 +266,9 @@ def verify_output( # noqa: C901 output_encodings: Tuple = (), check_io_shape: bool = False, op_package_paths: List[str] = None, + extra_cmds: str = "", + output_callback: Optional[Callable[[str], None]] = None, + save_inference_speed: bool = False, ): with tempfile.TemporaryDirectory() as tmp_dir: ( @@ -287,7 +292,9 @@ def post_process(): torch_to_numpy_dtype_dict, ) - for i, f in enumerate(sorted(os.listdir(output_dir))): + for i, f in enumerate( + sorted(f for f in os.listdir(output_dir) if f.endswith(".raw")) + ): enc = output_encodings[i] if len(output_encodings) != 0 else None dtype = ( ref_outputs[i].numpy().dtype @@ -368,6 +375,13 @@ def validate_intermediate_tensor(): ] if expected_intermediate_events != -1: cmd.append("--dump_intermediate_outputs") + cmd += extra_cmds.split() + + if save_inference_speed: + cmd += [ + "--performance_output_path", + self.inference_speed_output_path, + ] if check_io_shape: shape_info = { @@ -387,16 +401,19 @@ def validate_intermediate_tensor(): cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, + text=True, env=env, cwd=tmp_dir, ) + if output_callback: + output_callback(proc) self.assertEqual( proc.returncode, 0, f"The process running qnn_executorch_runner return {proc.returncode}, " "STDOUT=\n" - f"{proc.stdout.decode('utf-8')}", + f"{proc.stdout}", ) # Verify the outputs @@ -409,6 +426,13 @@ def validate_intermediate_tensor(): if expected_intermediate_events != -1: validate_intermediate_tensor() + + if save_inference_speed: + with open( + f"{tmp_dir}/{self.inference_speed_output_path}", "r" + ) as 
+                    ) as f:
+                        self.inference_speed = float(f.read())
+
             else:
                 adb = SimpleADB(
                     qnn_sdk=os.getenv("QNN_SDK_ROOT"),
@@ -438,7 +462,12 @@ def validate_intermediate_tensor():
                     input_list=input_list,
                     files=op_package_paths,
                 )
-                adb.execute(method_index=method_index)
+                adb.extra_cmds += extra_cmds
+                if save_inference_speed:
+                    adb.extra_cmds += (
+                        f" --performance_output_path {self.inference_speed_output_path}"
+                    )
+                adb.execute(method_index=method_index, output_callback=output_callback)
                 adb.pull(output_path=tmp_dir, callback=post_process)
                 self._assert_outputs_equal(outputs, ref_outputs)
 
@@ -451,6 +480,11 @@ def validate_intermediate_tensor():
                         debug_output_path,
                         callback=validate_intermediate_tensor,
                     )
+                if save_inference_speed:
+                    with open(
+                        f"{tmp_dir}/{self.inference_speed_output_path}", "r"
+                    ) as f:
+                        self.inference_speed = float(f.read())
 
     def lower_module_and_test_output(
         self,
@@ -465,6 +499,9 @@ def lower_module_and_test_output(
         skip_node_op_set: set = None,
         skip_mutable_buffer: bool = False,
         dynamic_shapes: Dict = None,
+        extra_cmds: str = "",
+        output_callback: Optional[Callable[[subprocess.CompletedProcess], None]] = None,
+        save_inference_speed: bool = False,
     ):
         delegated_program = to_edge_transform_and_lower_to_qnn(
             module,
@@ -520,6 +557,9 @@ def lower_module_and_test_output(
             etrecord_path,
             expected_profile_events,
             expected_intermediate_events,
+            extra_cmds=extra_cmds,
+            output_callback=output_callback,
+            save_inference_speed=save_inference_speed,
         )
 
     def get_qdq_module(
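A small worked example of the unit conversion these tests rely on (illustrative numbers): the runner below writes the average latency per inference, in milliseconds, to `--performance_output_path`, and the tests invert it into throughput before comparing burst against high power saver:

```cpp
// inference_speed.txt holds the average latency per inference, in ms.
double avg_latency_ms = 4.0;                            // example value
double inferences_per_second = 1000.0 / avg_latency_ms; // = 250 inferences/s
```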
" + "Refer to QnnExecuTorchProfileLevel under qc_compiler_spec.fbs for more info."); + using executorch::aten::Tensor; using executorch::aten::TensorImpl; using executorch::etdump::ETDumpGen; using executorch::etdump::ETDumpResult; using executorch::extension::FileDataLoader; using executorch::extension::prepare_input_tensors; +using executorch::runtime::BackendOption; using executorch::runtime::Error; using executorch::runtime::EValue; using executorch::runtime::EventTracerDebugLogLevel; @@ -151,6 +178,40 @@ int main(int argc, char** argv) { return 1; } + // Set runtime options + executorch::runtime::BackendOptions<3> backend_options; + if (!gflags::GetCommandLineFlagInfoOrDie("log_level").is_default) { + ET_LOG(Info, "Setting runtime log level: %d", FLAGS_log_level); + ET_CHECK_MSG( + backend_options.set_option(QNN_RUNTIME_LOG_LEVEL, FLAGS_log_level) == + Error::Ok, + "Failed to set backend options: %s", + QNN_RUNTIME_LOG_LEVEL); + } + if (!gflags::GetCommandLineFlagInfoOrDie("htp_performance_mode").is_default) { + ET_LOG( + Info, + "Setting runtime performance mode: %d", + FLAGS_htp_performance_mode); + ET_CHECK_MSG( + backend_options.set_option( + QNN_RUNTIME_HTP_PERFORMANCE_MODE, FLAGS_htp_performance_mode) == + Error::Ok, + "Failed to set backend options: %s", + QNN_RUNTIME_HTP_PERFORMANCE_MODE); + } + if (!gflags::GetCommandLineFlagInfoOrDie("profile_level").is_default) { + ET_LOG(Info, "Setting runtime profile level: %d", FLAGS_profile_level); + ET_CHECK_MSG( + backend_options.set_option( + QNN_RUNTIME_PROFILE_LEVEL, FLAGS_profile_level) == Error::Ok, + "Failed to set backend options: %s", + QNN_RUNTIME_PROFILE_LEVEL); + } + ET_CHECK_MSG( + set_option(QNN_BACKEND, backend_options.view()) == Error::Ok, + "Failed to set runtime options."); + // Create a loader to get the data of the program file. There are other // DataLoaders that use mmap() or point to data that's already in memory, and // users can create their own DataLoaders to load from arbitrary sources. 
@@ -483,10 +544,20 @@
     }
     ET_LOG(
         Info,
-        "%d inference took %f ms, avg %f ms",
+        "Total %d inferences took %f ms, avg %f ms",
         inference_index,
         elapsed_time,
         elapsed_time / inference_index);
+
+    // Save the average inference time for CI.
+    std::ofstream outfile(FLAGS_performance_output_path.c_str());
+    if (outfile.is_open()) {
+      double avg_time = elapsed_time / inference_index;
+      outfile << avg_time;
+      outfile.close();
+    } else {
+      ET_CHECK_MSG(false, "Error saving the inference speed file");
+    }
   } else {
     // if no input is provided, fill the inputs with default values
     auto inputs = prepare_input_tensors(*method);
diff --git a/examples/qualcomm/utils.py b/examples/qualcomm/utils.py
index c12cb582961..11c21af8c2c 100755
--- a/examples/qualcomm/utils.py
+++ b/examples/qualcomm/utils.py
@@ -104,16 +104,22 @@ def __init__(
         self.expected_output_shape = expected_output_shape
         self.extra_cmds = ""
 
-    def _adb(self, cmd):
+    def _adb(self, cmd, output_callback: Optional[Callable[[subprocess.CompletedProcess], None]] = None):
         if not self.host_id:
             cmds = ["adb", "-s", self.device_id]
         else:
             cmds = ["adb", "-H", self.host_id, "-s", self.device_id]
         cmds.extend(cmd)
 
-        subprocess.run(
-            cmds, stdout=subprocess.DEVNULL if self.error_only else sys.stdout
-        )
+        if output_callback:
+            result = subprocess.run(
+                cmds, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
+            )
+            output_callback(result)
+        else:
+            subprocess.run(
+                cmds, stdout=subprocess.DEVNULL if self.error_only else sys.stdout
+            )
 
     def push(self, inputs=None, input_list=None, files=None, init_env=True):
         artifacts = []
@@ -173,7 +179,12 @@ def push(self, inputs=None, input_list=None, files=None, init_env=True):
         for file_name in files:
             self._adb(["push", file_name, self.workspace])
 
-    def execute(self, custom_runner_cmd=None, method_index=0):
+    def execute(
+        self,
+        custom_runner_cmd=None,
+        method_index=0,
+        output_callback: Optional[Callable[[subprocess.CompletedProcess], None]] = None,
+    ):
         self._adb(["shell", f"mkdir -p {self.output_folder}"])
         # run the delegation
         if custom_runner_cmd is None:
@@ -205,8 +216,9 @@ def execute(self, custom_runner_cmd=None, method_index=0):
             )
         else:
             qnn_executor_runner_cmds = custom_runner_cmd
-
-        self._adb(["shell", f"{qnn_executor_runner_cmds}"])
+        self._adb(
+            ["shell", f"{qnn_executor_runner_cmds}"], output_callback=output_callback
+        )
 
     def pull(self, output_path, callback=None):
         self._adb(["pull", "-a", self.output_folder, output_path])
diff --git a/runtime/backend/backend_init_context.h b/runtime/backend/backend_init_context.h
index 5a4b70e0dbc..777744e6239 100644
--- a/runtime/backend/backend_init_context.h
+++ b/runtime/backend/backend_init_context.h
@@ -11,6 +11,12 @@
 #include <...>
 #include <...>
 
+#ifdef __GNUC__
+// Disable -Wdeprecated-declarations, as some builds use 'Werror'.
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
+#endif
+
 namespace executorch {
 namespace ET_RUNTIME_NAMESPACE {
 /**