diff --git a/backends/qualcomm/CMakeLists.txt b/backends/qualcomm/CMakeLists.txt index 1f92b2d8cfd..3c66796594b 100644 --- a/backends/qualcomm/CMakeLists.txt +++ b/backends/qualcomm/CMakeLists.txt @@ -76,7 +76,6 @@ include_directories( set(_qnn_schema__srcs backends/qualcomm/serialization/qc_compiler_spec.fbs - backends/qualcomm/serialization/qc_binary_info.fbs ) set(_qnn_schema__include_dir "${CMAKE_BINARY_DIR}/schema/include") # Paths to headers generated from the .fbs files. @@ -116,6 +115,7 @@ add_library(qcir_utils STATIC) add_library(qnn_backend STATIC) add_library(qnn_backend_cache STATIC) add_library(qnn_context STATIC) +add_library(qnn_custom_protocol STATIC) add_library(qnn_device STATIC) add_library(qnn_executorch_backend SHARED) add_library(qnn_executorch_header INTERFACE) @@ -155,6 +155,7 @@ target_link_libraries(qnn_executorch_logging PRIVATE qnn_schema) target_link_libraries(qnn_profiler PRIVATE qnn_executorch_logging) target_link_libraries(qnn_logger PRIVATE qnn_implementation ${android_log}) target_link_libraries(qnn_backend PRIVATE qnn_implementation qnn_logger) +target_link_libraries(qnn_custom_protocol PRIVATE qcir_utils) target_link_libraries( qnn_device PRIVATE qnn_executorch_logging qnn_implementation qnn_logger ) @@ -177,7 +178,7 @@ target_link_libraries( qnn_factory PUBLIC qnn_header PRIVATE qnn_schema qnn_backend qnn_device qnn_context qnn_graph - qnn_mem_manager + qnn_mem_manager qnn_custom_protocol ) target_link_libraries( qnn_manager PRIVATE qnn_factory wrappers qnn_schema utils shared_buffer diff --git a/backends/qualcomm/aot/ir/qcir.fbs b/backends/qualcomm/aot/ir/qcir.fbs index 6c16a54e0db..dfd9bbc91e1 100755 --- a/backends/qualcomm/aot/ir/qcir.fbs +++ b/backends/qualcomm/aot/ir/qcir.fbs @@ -80,7 +80,8 @@ table Tensor { type: TensorType; dtype: DataType; qparam: QuantizeParam; - data: [ubyte]; + size: uint; + offset: ulong; } table Operator { @@ -88,9 +89,9 @@ table Operator { package_name: string; type_name: string; // keep only tensor indexes - inputs: [int]; - outputs: [int]; - params: [int]; + inputs: [uint]; + outputs: [uint]; + params: [uint]; } table Graph { diff --git a/backends/qualcomm/aot/ir/qcir_utils.cpp b/backends/qualcomm/aot/ir/qcir_utils.cpp index 8cf024ba006..48f069767bf 100755 --- a/backends/qualcomm/aot/ir/qcir_utils.cpp +++ b/backends/qualcomm/aot/ir/qcir_utils.cpp @@ -235,11 +235,8 @@ Qnn_QuantizeParams_t ToQuantizeParam(const tensor_type& tensor) { flatbuffers::Offset ToTensor( const Qnn_Tensor_t& tensor, + const uint64_t data_offset, flatbuffers::FlatBufferBuilder* builder) { - std::vector buffer( - static_cast(QNN_VER_PTR(tensor)->clientBuf.data), - static_cast(QNN_VER_PTR(tensor)->clientBuf.data) + - QNN_VER_PTR(tensor)->clientBuf.dataSize); std::vector shape( QNN_VER_PTR(tensor)->dimensions, QNN_VER_PTR(tensor)->dimensions + QNN_VER_PTR(tensor)->rank); @@ -251,10 +248,11 @@ flatbuffers::Offset ToTensor( ToTensorType(QNN_VER_PTR(tensor)->type), ToDataType(QNN_VER_PTR(tensor)->dataType), ToQuantizeParam(tensor, builder), - &buffer); + QNN_VER_PTR(tensor)->clientBuf.dataSize, + data_offset); } -Qnn_Tensor_t ToTensor(const tensor_type& tensor) { +Qnn_Tensor_t ToTensor(const tensor_type& tensor, const uint8_t* data_ptr) { auto is_io_tensor = [](Qnn_TensorType_t type) { return type < QNN_TENSOR_TYPE_STATIC; }; @@ -266,10 +264,10 @@ Qnn_Tensor_t ToTensor(const tensor_type& tensor) { QNN_VER_PTR(t)->quantizeParams = ToQuantizeParam(tensor); QNN_VER_PTR(t)->rank = tensor->shape()->size(); QNN_VER_PTR(t)->dimensions = 
const_cast(tensor->shape()->data()); - QNN_VER_PTR(t)->clientBuf.dataSize = tensor->data()->size(); + QNN_VER_PTR(t)->clientBuf.dataSize = tensor->size(); QNN_VER_PTR(t)->clientBuf.data = is_io_tensor(QNN_VER_PTR(t)->type) ? nullptr - : static_cast(const_cast(tensor->data()->Data())); + : static_cast(const_cast(data_ptr)); return t; } diff --git a/backends/qualcomm/aot/ir/qcir_utils.h b/backends/qualcomm/aot/ir/qcir_utils.h index 5d54eb30a69..085f09bf145 100755 --- a/backends/qualcomm/aot/ir/qcir_utils.h +++ b/backends/qualcomm/aot/ir/qcir_utils.h @@ -32,8 +32,9 @@ Qnn_QuantizeParams_t ToQuantizeParam(const tensor_type& tensor); flatbuffers::Offset ToTensor( const Qnn_Tensor_t& tensor, + const uint64_t data_offset, flatbuffers::FlatBufferBuilder* builder); -Qnn_Tensor_t ToTensor(const tensor_type& tensor); +Qnn_Tensor_t ToTensor(const tensor_type& tensor, const uint8_t* data_ptr); } // namespace qnn } // namespace backends diff --git a/backends/qualcomm/aot/python/PyQnnManagerAdaptor.h b/backends/qualcomm/aot/python/PyQnnManagerAdaptor.h index 55429f2b430..fc85f77f00f 100644 --- a/backends/qualcomm/aot/python/PyQnnManagerAdaptor.h +++ b/backends/qualcomm/aot/python/PyQnnManagerAdaptor.h @@ -8,11 +8,11 @@ #pragma once #include #include -#include #include #include #include #include +#include #include #include #include @@ -50,54 +50,92 @@ class PyQnnManager { qnn_executorch_options, qnn_executorch_context_binary_); } - // used for loading multiple graphs in qcir + // used during stage 2 of multi-graph mode explicit PyQnnManager(const py::bytes& buffer, const py::list& qcirs) : qnn_executorch_option_ptr_(buffer) { auto qnn_executorch_options = GetQnnExecuTorchOptions( qnn_executorch_option_ptr_.cast().data()); // merge multiple qcirs into one context with multiple graphs - std::vector> graphs; + + // We start retrieving tensor from offsets = 0. 
+ std::vector offsets(1, 0); + std::vector tensor_data; + std::vector tensor_ptr; + std::vector tensor_size; + uint64_t total_tensor_size = 0; for (size_t i = 0; i < qcirs.size(); ++i) { py::buffer_info info(py::buffer(qcirs[i].cast()).request()); - flatbuffers::Verifier verifier_binary_info( - static_cast(info.ptr), - info.size * info.itemsize); - if (!qnn_delegate::VerifyBinaryInfoBuffer(verifier_binary_info)) { - QNN_EXECUTORCH_LOG_ERROR("Fail to verify binary info"); + + uint8_t* qcir_custom_buffer_ptr = static_cast(info.ptr); + QnnQcirCustomProtocol qnn_qcir_custom_protocol; + auto [status, _, qcir_tensor_size, __, qcir_tensor_ptr] = + qnn_qcir_custom_protocol.DeserializeQcirCustomBuffer( + qcir_custom_buffer_ptr); + + if (status != Error::Ok) { + QNN_EXECUTORCH_LOG_ERROR("Fail to verify QnnQcirCustomProtocol"); return; } - auto binary_info = qnn_delegate::GetBinaryInfo(info.ptr); - flatbuffers::Verifier verifier_qcir( - binary_info->data()->data(), binary_info->data()->size()); - if (!qcir::VerifyContextBuffer(verifier_qcir)) { - QNN_EXECUTORCH_LOG_ERROR("Fail to verify qcir format"); + tensor_ptr.push_back(static_cast(qcir_tensor_ptr)); + tensor_size.push_back(qcir_tensor_size); + total_tensor_size += qcir_tensor_size; + offsets.push_back(offsets.back() + qcir_tensor_size); + } + + tensor_data.resize(total_tensor_size); + + // store multiple graphs tensor in a contiguous memory space + for (size_t i = 0; i < tensor_ptr.size(); ++i) { + std::memcpy( + tensor_data.data() + offsets[i], tensor_ptr[i], tensor_size[i]); + } + + std::vector> graphs; + for (size_t i = 0; i < qcirs.size(); ++i) { + py::buffer_info info(py::buffer(qcirs[i].cast()).request()); + + uint8_t* qcir_custom_buffer_ptr = static_cast(info.ptr); + QnnQcirCustomProtocol qnn_qcir_custom_protocol; + auto [status, qcir_fbs_size, _, qcir_fbs_ptr, __] = + qnn_qcir_custom_protocol.DeserializeQcirCustomBuffer( + qcir_custom_buffer_ptr); + + if (status != Error::Ok) { + QNN_EXECUTORCH_LOG_ERROR("Fail to verify QnnQcirCustomProtocol"); return; } - auto context = qcir::GetContext(binary_info->data()->data()); + + auto context = qcir::GetContext(qcir_fbs_ptr); for (const auto& graph : *context->graphs()) { std::vector> tensors; for (const auto tensor : *graph->tensors()) { // here we need to take a detour to merge multiple qcir flatbuffers // outer ToTensor // return: flatbuffers::Offset - // consume: QnnTensor, flatbuffers::FlatBufferBuilder* + // consume: QnnTensor, data_offset, flatbuffers::FlatBufferBuilder* // inner ToTensor // return: QnnTensor - // consume: flatbuffers::Vector<::flatbuffers::Offset> - tensors.emplace_back(ToTensor(ToTensor(tensor), &builder_)); + // consume: + // flatbuffers::Vector<::flatbuffers::Offset>, + // data_ptr + tensors.emplace_back(ToTensor( + ToTensor(tensor, nullptr), + offsets[i] + tensor->offset(), + &builder_)); } std::vector> nodes; for (const auto& node : *graph->nodes()) { - int32_t* inputs_ptr = const_cast(node->inputs()->data()); - int32_t* outputs_ptr = const_cast(node->outputs()->data()); - int32_t* params_ptr = const_cast(node->params()->data()); - std::vector inputs( + uint32_t* inputs_ptr = const_cast(node->inputs()->data()); + uint32_t* outputs_ptr = + const_cast(node->outputs()->data()); + uint32_t* params_ptr = const_cast(node->params()->data()); + std::vector inputs( inputs_ptr, inputs_ptr + node->inputs()->size()); - std::vector outputs( + std::vector outputs( outputs_ptr, outputs_ptr + node->outputs()->size()); - std::vector params( + std::vector params( params_ptr, 
params_ptr + node->params()->size()); nodes.emplace_back(qcir::CreateOperatorDirect( builder_, @@ -118,7 +156,9 @@ class PyQnnManager { QnnExecuTorchContextBinary qcir_bin( {builder_.GetBufferPointer(), builder_.GetSize()}); - qnn_executorch_context_binary_ = MakeBinaryInfo(qcir_bin); + // Init QnnQcirCustomProtocol binary + qnn_executorch_context_binary_ = + MakeQcirCustomBinaryInfo(qcir_bin, tensor_data); qnn_manager_ = std::make_shared( qnn_executorch_options, qnn_executorch_context_binary_); } @@ -132,7 +172,7 @@ class PyQnnManager { return qnn_manager_->IsNodeSupportedByBackend(op_wrappers); } - // this method is specific for compiling multi-graphs + // this method is specific for stage 2 of compiling multi-graphs py::array_t Compile() { if (qnn_manager_->CompileQcir() != Error::Ok) { QNN_EXECUTORCH_LOG_ERROR("Fail to compile qcir"); @@ -157,26 +197,37 @@ class PyQnnManager { if (qnn_manager_->IsOnlinePrepare() || qnn_manager_->IsMultipleGraphs()) { builder_.Reset(); - std::vector> tensors; + std::vector tensor_data; + std::vector offsets; std::unordered_map tensor_map; + std::vector> fb_tensors; + std::vector> fb_ops; auto set_tensor = [&](const std::shared_ptr& wrapper, - std::vector& index) { + std::vector& index) { auto it = tensor_map.find(wrapper.get()); if (it != tensor_map.end()) { index.push_back(it->second); } else { - int i = tensors.size(); - tensor_map[wrapper.get()] = i; - index.push_back(i); - tensors.emplace_back( - ToTensor(wrapper->CloneTensorStruct(), &builder_)); + tensor_map[wrapper.get()] = fb_tensors.size(); + index.push_back(fb_tensors.size()); + offsets.push_back(tensor_data.size()); + Qnn_Tensor_t qnn_tensor = wrapper->CloneTensorStruct(); + fb_tensors.emplace_back( + ToTensor(qnn_tensor, offsets.back(), &builder_)); + uint8_t* data_ptr = + static_cast(QNN_VER_PTR(qnn_tensor)->clientBuf.data); + if (data_ptr != nullptr) { + tensor_data.insert( + tensor_data.end(), + data_ptr, + data_ptr + QNN_VER_PTR(qnn_tensor)->clientBuf.dataSize); + } } }; - std::vector> operators; for (std::shared_ptr& op_wrapper : op_wrappers) { - std::vector inputs, outputs, params; + std::vector inputs, outputs, params; for (const auto& tensor_wrapper : op_wrapper->GetInputTensors()) { set_tensor(tensor_wrapper, inputs); @@ -207,13 +258,22 @@ class PyQnnManager { static_cast(&p.scalarParam.uint8Value); QNN_VER_PTR(t)->clientBuf.dataSize = GetDataTypeSize(QNN_VER_PTR(t)->dataType); - params.push_back(tensors.size()); - tensors.emplace_back(ToTensor(t, &builder_)); + + // collect tensor data + offsets.push_back(tensor_data.size()); + const uint8_t* data_ptr = + static_cast(QNN_VER_PTR(t)->clientBuf.data); + tensor_data.insert( + tensor_data.end(), + data_ptr, + data_ptr + QNN_VER_PTR(t)->clientBuf.dataSize); + params.push_back(fb_tensors.size()); + fb_tensors.emplace_back(ToTensor(t, offsets.back(), &builder_)); } } Qnn_OpConfig_t op_config = op_wrapper->GetOpConfig(); - operators.emplace_back(qcir::CreateOperatorDirect( + fb_ops.emplace_back(qcir::CreateOperatorDirect( builder_, QNN_VER_PTR(op_config)->name, QNN_VER_PTR(op_config)->packageName, @@ -222,14 +282,22 @@ class PyQnnManager { &outputs, ¶ms)); } - auto graph = qcir::CreateGraphDirect( - builder_, graph_name.c_str(), &operators, &tensors); - std::vector> graphs({graph}); - auto context = qcir::CreateContextDirect(builder_, &graphs); + + std::vector> fb_graphs( + {qcir::CreateGraphDirect( + builder_, graph_name.c_str(), &fb_ops, &fb_tensors)}); + auto context = qcir::CreateContextDirect(builder_, &fb_graphs); 
builder_.Finish(context); + QnnExecuTorchContextBinary qcir_binary( {builder_.GetBufferPointer(), builder_.GetSize()}); - binary_info = MakeBinaryInfo(qcir_binary); + + custom_qcir_protocol_buffer_ = + QnnQcirCustomProtocol(qcir_binary.nbytes, tensor_data.size()); + custom_qcir_protocol_buffer_.BuildQcirCustomBuffer( + qcir_binary, tensor_data); + std::tie(binary_info.buffer, binary_info.nbytes) = + custom_qcir_protocol_buffer_.GetCustomProtocolBuffer(); } else { if (qnn_manager_->Compile(graph_name, op_wrappers) != executorch::runtime::Error::Ok) { @@ -296,41 +364,40 @@ class PyQnnManager { return qnn_manager_->GetSpillFillBufferSize(); } + QnnExecuTorchContextBinary MakeQcirCustomBinaryInfo( + const QnnExecuTorchContextBinary& ctx_bin, + const std::vector& tensor_data) { + custom_qcir_protocol_buffer_ = + QnnQcirCustomProtocol(ctx_bin.nbytes, tensor_data.size()); + custom_qcir_protocol_buffer_.BuildQcirCustomBuffer(ctx_bin, tensor_data); + auto [ptr, size] = custom_qcir_protocol_buffer_.GetCustomProtocolBuffer(); + return {ptr, size}; + } + py::array_t MakeBinaryInfo(const py::bytes& ctx_bin) { py::buffer_info info(py::buffer(ctx_bin).request()); QnnExecuTorchContextBinary binary( {info.ptr, static_cast(info.size * info.itemsize)}); - auto binary_info = MakeBinaryInfo(binary); - auto result = py::array_t(binary_info.nbytes); + + auto qnn_context_custom_protocol = QnnContextCustomProtocol(binary.nbytes); + qnn_context_custom_protocol.BuildContextCustomBuffer(binary); + auto [custom_buffer_ptr, custom_buffer_size] = + qnn_context_custom_protocol.GetCustomProtocolBuffer(); + + auto result = py::array_t(custom_buffer_size); auto result_buffer = result.request(); - std::memcpy(result_buffer.ptr, binary_info.buffer, binary_info.nbytes); + std::memcpy(result_buffer.ptr, custom_buffer_ptr, custom_buffer_size); return result; } private: - QnnExecuTorchContextBinary MakeBinaryInfo( - const QnnExecuTorchContextBinary& ctx_bin) { - auto signature = []() { - return std::to_string( - std::chrono::high_resolution_clock::now().time_since_epoch().count()); - }; - const uint8_t* base = static_cast(ctx_bin.buffer); - std::vector data(base, base + ctx_bin.nbytes); - // add signature to binary for cache reuse in runtime - builder_.Reset(); - auto binary_info = qnn_delegate::CreateBinaryInfoDirect( - builder_, signature().c_str(), &data); - builder_.Finish(binary_info); - - return QnnExecuTorchContextBinary( - {builder_.GetBufferPointer(), builder_.GetSize()}); - } - // Store the bytes object instead of a raw pointer so that this module will // keep the bytes alive. 
const py::bytes qnn_executorch_option_ptr_; QnnExecuTorchContextBinary qnn_executorch_context_binary_; std::shared_ptr qnn_manager_; + QnnQcirCustomProtocol custom_qcir_protocol_buffer_; + QnnContextCustomProtocol custom_context_custom_buffer_; flatbuffers::FlatBufferBuilder builder_; }; } // namespace qnn diff --git a/backends/qualcomm/aot/python/PyQnnWrapperAdaptor.cpp b/backends/qualcomm/aot/python/PyQnnWrapperAdaptor.cpp index 5e794dde323..a9bf9be7bc4 100644 --- a/backends/qualcomm/aot/python/PyQnnWrapperAdaptor.cpp +++ b/backends/qualcomm/aot/python/PyQnnWrapperAdaptor.cpp @@ -110,18 +110,6 @@ std::shared_ptr CreateTensorWrapper( std::unique_ptr quantize_param_wrapper = CreateQuantizationParamWrapper(encoding, quant_info); - if (data.size() == 0) { - return CreateTensorWrapper( - tensor_name, - tensor_type, - data_type, - std::move(quantize_param_wrapper), - rank, - dims.data(), - 0, - nullptr, - copy_data); - } return CreateTensorWrapper( tensor_name, tensor_type, @@ -130,7 +118,7 @@ std::shared_ptr CreateTensorWrapper( rank, dims.data(), 0, - data.data(), + data.size() == 0 ? nullptr : data.data(), copy_data); } diff --git a/backends/qualcomm/aot/python/targets.bzl b/backends/qualcomm/aot/python/targets.bzl index 8eb8d095c30..e1f5a6a8fc5 100644 --- a/backends/qualcomm/aot/python/targets.bzl +++ b/backends/qualcomm/aot/python/targets.bzl @@ -31,7 +31,6 @@ def define_common_targets(): "//executorch/backends/qualcomm/aot/wrappers:wrappers", "//executorch/backends/qualcomm/runtime:logging", "//executorch/backends/qualcomm:schema", - "//executorch/backends/qualcomm:qc_binary_info_schema", "//executorch/backends/qualcomm/aot/ir:qcir_utils", "//executorch/backends/qualcomm/runtime:runtime", "fbsource//third-party/qualcomm/qnn/qnn-{0}:api".format(get_qnn_library_verision()), diff --git a/backends/qualcomm/runtime/QnnExecuTorch.h b/backends/qualcomm/runtime/QnnExecuTorch.h index 4f7102dd561..2ca0cd61cd5 100644 --- a/backends/qualcomm/runtime/QnnExecuTorch.h +++ b/backends/qualcomm/runtime/QnnExecuTorch.h @@ -19,6 +19,13 @@ #ifdef __cplusplus extern "C" { #endif // __cplusplus + +// This could be: +// 1. qnn_context_binary +// 2. QnnQcirCustomProtocol +// 3. QnnContextCustomProtocol +// To check if it is custom protocol, users can deserialize the binary using +// QnnCustomProtocol and check the status typedef struct { /// qnn_context_binary_blob void* buffer; diff --git a/backends/qualcomm/runtime/QnnExecuTorchBackend.cpp b/backends/qualcomm/runtime/QnnExecuTorchBackend.cpp index 60208afeec5..c3f67b5a576 100644 --- a/backends/qualcomm/runtime/QnnExecuTorchBackend.cpp +++ b/backends/qualcomm/runtime/QnnExecuTorchBackend.cpp @@ -10,6 +10,7 @@ #include #include #include +#include namespace executorch { namespace backends { @@ -36,8 +37,23 @@ Result QnnExecuTorchBackend::init( QnnExecuTorchContextBinary qnn_context_blob; const qnn_delegate::QnnExecuTorchOptions* qnn_executorch_options = nullptr; - qnn_context_blob.buffer = const_cast(processed->data()); - qnn_context_blob.nbytes = processed->size(); + auto [status, signature, ctx_size, ctx_bin] = + QnnContextCustomProtocol().DeserializeContextCustomBuffer( + const_cast(processed->data())); + if (status == Error::Ok) { + QNN_EXECUTORCH_LOG_INFO( + "Deserializing processed data using QnnContextCustomProtocol"); + // After this stage, qnn_context_blob.nbytes & qnn_context_blob.buffer will + // only store qnn_context_binary. 
+ qnn_context_blob.nbytes = ctx_size; + qnn_context_blob.buffer = ctx_bin; + } else { + // This buffer will be verified again in QnnBackendCache. + QNN_EXECUTORCH_LOG_INFO( + "Deserializing processed data using QnnQcirCustomProtocol"); + qnn_context_blob.buffer = const_cast(processed->data()); + qnn_context_blob.nbytes = processed->size(); + } // convert CompileSpec to qnn ExecuTorch option for (auto& compile_spec : compile_specs) { @@ -62,7 +78,7 @@ Result QnnExecuTorchBackend::init( // --- // check if current context binary has already been initialized // return cached one for reducing memory footprint - std::string signature = qnn_manager->GetBinarySignature(); + auto iter = delegate_map_.find(signature); if (iter != delegate_map_.end()) { QNN_EXECUTORCH_LOG_INFO( @@ -115,6 +131,13 @@ Error QnnExecuTorchBackend::execute( input_tensor_structs.reserve(input_tensors.size()); for (int i = 0; i < input_tensors.size(); ++i) { + // TODO: Enable this in future to avoid unmatch tensor size, e.g., QuantIO + // pass causing mismatch + // ET_CHECK_MSG( + // input_tensors[i]->GetBytes() == args[i]->toTensor().nbytes(), + // "Input index %d, number of bytes does not match between args and + // input_tensor, %d != %zu", i, input_tensors[i]->GetBytes(), + // args[i]->toTensor().nbytes()); if (qnn_manager->RegisterMem( args[i]->toTensor().mutable_data_ptr(), input_tensors[i]) != Error::Ok) { @@ -129,6 +152,15 @@ Error QnnExecuTorchBackend::execute( for (const auto& output_tensor : output_tensors) { // pos=0 limits the search to the prefix if (output_tensor->GetName().rfind("output_", 0) == 0) { + // TODO: Enable this in future to avoid unmatch tensor size, e.g., QuantIO + // pass causing mismatch + // ET_CHECK_MSG( + // output_tensor->GetBytes() == + // args[output_index]->toTensor().nbytes(), "Output index %d, number + // of bytes does not match between args and output_tensor, %d != %zu", + // output_index, + // output_tensor->GetBytes(), + // args[output_index]->toTensor().nbytes()); void* mutable_data_ptr = args[output_index]->toTensor().mutable_data_ptr(); if (qnn_manager->RegisterMem(mutable_data_ptr, output_tensor) != @@ -170,7 +202,7 @@ bool QnnExecuTorchBackend::is_available() const { } void QnnExecuTorchBackend::add_cached_delegate( - const std::string& signature, + const std::int64_t& signature, executorch::runtime::DelegateHandle* handle) const { std::lock_guard guard(mutex_); delegate_map_[signature] = handle; diff --git a/backends/qualcomm/runtime/QnnExecuTorchBackend.h b/backends/qualcomm/runtime/QnnExecuTorchBackend.h index 630067da48a..e83ec6b13b0 100644 --- a/backends/qualcomm/runtime/QnnExecuTorchBackend.h +++ b/backends/qualcomm/runtime/QnnExecuTorchBackend.h @@ -40,14 +40,14 @@ class QnnExecuTorchBackend final private: void add_cached_delegate( - const std::string& signature, + const std::int64_t& signature, executorch::runtime::DelegateHandle* handle) const; void erase_cached_delegate(executorch::runtime::DelegateHandle* handle) const; mutable std::mutex mutex_; - mutable std::unordered_map + mutable std::unordered_map delegate_map_; - mutable std::unordered_map + mutable std::unordered_map delegate_map_rev_; }; diff --git a/backends/qualcomm/runtime/QnnManager.cpp b/backends/qualcomm/runtime/QnnManager.cpp index a4d83585f28..f2650301a38 100644 --- a/backends/qualcomm/runtime/QnnManager.cpp +++ b/backends/qualcomm/runtime/QnnManager.cpp @@ -7,11 +7,11 @@ */ #include -#include #include #include #include #include +#include #include #include #include @@ -287,7 +287,9 @@ Error 
QnnManager::Init() { backend_params_ptr_ = QnnBackendFactory().Create( qnn_loaded_backend_, logger_.get(), qnn_context_blob_, options_); ET_CHECK_OR_RETURN_ERROR( - backend_params_ptr_ != nullptr, Internal, "Failed to load Qnn backend.") + backend_params_ptr_ != nullptr, + Internal, + "Failed to load Qnn backend."); ET_CHECK_OR_RETURN_ERROR( backend_params_ptr_->qnn_backend_cache_ptr_->Configure() == Error::Ok, Internal, @@ -312,6 +314,7 @@ Error QnnManager::Init() { Internal, "Fail to configure Qnn graph"); } + backend_params_ptr_->backend_init_state_ = BackendInitializeState::INITIALIZED; } @@ -488,29 +491,24 @@ Error QnnManager::GetContextBinary( } Error QnnManager::CompileQcir() { - flatbuffers::Verifier verifier_binary_info( - static_cast(qnn_context_blob_.buffer), - qnn_context_blob_.nbytes); - if (!qnn_delegate::VerifyBinaryInfoBuffer(verifier_binary_info)) { - QNN_EXECUTORCH_LOG_ERROR("Fail to verify binary info"); - return Error::Internal; - } + QnnQcirCustomProtocol qnn_qcir_custom_protocol; + auto [status, qcir_fbs_size, tensor_size, qcir_fbs_ptr, tensor_ptr] = + qnn_qcir_custom_protocol.DeserializeQcirCustomBuffer( + qnn_context_blob_.buffer); - auto binary_info = qnn_delegate::GetBinaryInfo(qnn_context_blob_.buffer); - flatbuffers::Verifier verifier_qcir( - binary_info->data()->data(), binary_info->data()->size()); - if (!qcir::VerifyContextBuffer(verifier_qcir)) { - QNN_EXECUTORCH_LOG_ERROR("Fail to verify qcir format"); + if (status != Error::Ok) { + QNN_EXECUTORCH_LOG_ERROR("Failed to verify QnnQcirCustomProtocol"); return Error::Internal; } - auto context = qcir::GetContext(binary_info->data()->data()); + auto context = qcir::GetContext(qcir_fbs_ptr); for (const auto& graph : *context->graphs()) { // qcir tensors to TensorWrapper std::vector> graph_inputs, graph_outputs, tensors; for (const auto& tensor : *graph->tensors()) { - tensors.emplace_back(CreateTensorWrapper(ToTensor(tensor))); + tensors.emplace_back(CreateTensorWrapper(ToTensor( + tensor, static_cast(tensor_ptr) + tensor->offset()))); if (tensor->type() == qcir::TensorType::WRITE) { graph_inputs.push_back(tensors.back()); } else if (tensor->type() == qcir::TensorType::READ) { @@ -544,6 +542,8 @@ Error QnnManager::CompileQcir() { const auto& tensor = graph->tensors()->Get(index); std::string name = tensor->name()->str(); Qnn_DataType_t dtype = ToDataType(tensor->dtype()); + const uint8_t* data_ptr = + static_cast(tensor_ptr) + tensor->offset(); if (tensor->shape()->size() != 0) { // add tensor param op->AddTensorParam( @@ -551,50 +551,39 @@ Error QnnManager::CompileQcir() { dtype, tensor->shape()->size(), tensor->shape()->data(), - tensor->data()->data()); + data_ptr); } else { // add scalar param switch (dtype) { case Qnn_DataType_t::QNN_DATATYPE_INT_32: op->AddScalarParam( - name, - dtype, - *reinterpret_cast(tensor->data()->Data())); + name, dtype, *reinterpret_cast(data_ptr)); break; case Qnn_DataType_t::QNN_DATATYPE_INT_16: op->AddScalarParam( - name, - dtype, - *reinterpret_cast(tensor->data()->Data())); + name, dtype, *reinterpret_cast(data_ptr)); break; case Qnn_DataType_t::QNN_DATATYPE_INT_8: - op->AddScalarParam( - name, dtype, static_cast(*tensor->data()->Data())); + op->AddScalarParam(name, dtype, static_cast(*data_ptr)); break; case Qnn_DataType_t::QNN_DATATYPE_UINT_32: op->AddScalarParam( - name, - dtype, - *reinterpret_cast(tensor->data()->Data())); + name, dtype, *reinterpret_cast(data_ptr)); break; case Qnn_DataType_t::QNN_DATATYPE_UINT_16: op->AddScalarParam( - name, - dtype, - 
*reinterpret_cast(tensor->data()->Data())); + name, dtype, *reinterpret_cast(data_ptr)); break; case Qnn_DataType_t::QNN_DATATYPE_UINT_8: - op->AddScalarParam(name, dtype, *tensor->data()->Data()); + op->AddScalarParam(name, dtype, *data_ptr); break; case Qnn_DataType_t::QNN_DATATYPE_FLOAT_32: case Qnn_DataType_t::QNN_DATATYPE_FLOAT_16: op->AddScalarParam( - name, - dtype, - *reinterpret_cast(tensor->data()->Data())); + name, dtype, *reinterpret_cast(data_ptr)); break; case Qnn_DataType_t::QNN_DATATYPE_BOOL_8: - op->AddScalarParam(name, dtype, *tensor->data()->Data()); + op->AddScalarParam(name, dtype, *data_ptr); break; default: QNN_EXECUTORCH_LOG_ERROR( @@ -603,15 +592,13 @@ Error QnnManager::CompileQcir() { } } } - op_wrappers.push_back(std::move(op)); + op_wrappers.emplace_back(std::move(op)); } - ET_CHECK_OR_RETURN_ERROR( Compile(graph->name()->str(), op_wrappers) == Error::Ok, Internal, "Fail to compile graph from qcir with graph_name: %s", graph->name()->str().c_str()); - ET_CHECK_OR_RETURN_ERROR( AllocateTensor(graph->name()->str(), graph_inputs, graph_outputs) == Error::Ok, @@ -672,7 +659,6 @@ Error QnnManager::Compile( return Error::Internal; } } - error = backend_params_ptr_->qnn_graph_ptr_->GraphFinalize(graph_name); if (error != QNN_SUCCESS) { QNN_EXECUTORCH_LOG_ERROR( @@ -684,15 +670,6 @@ Error QnnManager::Compile( return Error::Ok; } -std::string QnnManager::GetBinarySignature() { - flatbuffers::Verifier verifier( - static_cast(qnn_context_blob_.buffer), - qnn_context_blob_.nbytes); - return VerifyBinaryInfoBuffer(verifier) - ? GetBinaryInfo(qnn_context_blob_.buffer)->signature()->str() - : ""; -} - } // namespace qnn } // namespace backends } // namespace executorch diff --git a/backends/qualcomm/runtime/backends/CMakeLists.txt b/backends/qualcomm/runtime/backends/CMakeLists.txt index 2df806db52c..81536d26f78 100644 --- a/backends/qualcomm/runtime/backends/CMakeLists.txt +++ b/backends/qualcomm/runtime/backends/CMakeLists.txt @@ -116,6 +116,13 @@ target_sources( PRIVATE ${CMAKE_CURRENT_LIST_DIR}/QnnBackendFactory.cpp ) +# qnn_custom_protocol +target_sources( + qnn_custom_protocol + PUBLIC ${CMAKE_CURRENT_LIST_DIR}/QnnCustomProtocol.h + PRIVATE ${CMAKE_CURRENT_LIST_DIR}/QnnCustomProtocol.cpp +) + set(qnn_header_basenames QnnBackend.h QnnCommon.h diff --git a/backends/qualcomm/runtime/backends/QnnBackendCache.cpp b/backends/qualcomm/runtime/backends/QnnBackendCache.cpp index 43cb835cfff..3d5a432431c 100644 --- a/backends/qualcomm/runtime/backends/QnnBackendCache.cpp +++ b/backends/qualcomm/runtime/backends/QnnBackendCache.cpp @@ -7,9 +7,8 @@ */ #include -#include #include - +#include namespace executorch { namespace backends { namespace qnn { @@ -107,27 +106,28 @@ Error QnnBackendCache::Configure() { // DO DESERIALIZE state_ = DESERIALIZE; QNN_EXECUTORCH_LOG_INFO("Caching: Caching is in RESTORE MODE."); - flatbuffers::Verifier verifier_binary_info( - static_cast(qnn_context_blob_.buffer), - qnn_context_blob_.nbytes); - if (!qnn_delegate::VerifyBinaryInfoBuffer(verifier_binary_info)) { - QNN_EXECUTORCH_LOG_ERROR("Fail to verify binary info"); - return Error::Internal; + + auto [status, _, context_size, context_ptr] = + QnnContextCustomProtocol().DeserializeContextCustomBuffer( + qnn_context_blob_.buffer); + // For pre_gen_context.bin such as aihub + if (status == Error::Ok) { + qnn_context_blob_.buffer = context_ptr; + qnn_context_blob_.nbytes = context_size; } - auto binary_info = GetBinaryInfo(qnn_context_blob_.buffer); - Error status = GetQnnGraphInfoFromBinary( - 
const_cast(binary_info->data()->data()), - binary_info->data()->size()); + status = GetQnnGraphInfoFromBinary( + static_cast(qnn_context_blob_.buffer), + qnn_context_blob_.nbytes); if (status == Error::Internal) { - // check if context binary came from flatbuffer - flatbuffers::Verifier verifier( - binary_info->data()->data(), binary_info->data()->size()); - - if (qcir::VerifyContextBuffer(verifier)) { + auto [status, qcir_fbs_size, _, qcir_fbs_ptr, __] = + QnnQcirCustomProtocol().DeserializeQcirCustomBuffer( + qnn_context_blob_.buffer); + if (status == Error::Ok) { + // online prepare or first stage of multi graph state_ = ONLINE_PREPARE; - auto context = qcir::GetContext(binary_info->data()->data()); + auto context = qcir::GetContext(qcir_fbs_ptr); for (const auto& graph : *context->graphs()) { graph_names_.emplace_back(graph->name()->str()); } diff --git a/backends/qualcomm/runtime/backends/QnnContextCommon.cpp b/backends/qualcomm/runtime/backends/QnnContextCommon.cpp index 7db5164a1d5..7c66e5ad19a 100644 --- a/backends/qualcomm/runtime/backends/QnnContextCommon.cpp +++ b/backends/qualcomm/runtime/backends/QnnContextCommon.cpp @@ -7,7 +7,6 @@ */ #include - namespace executorch { namespace backends { namespace qnn { @@ -46,13 +45,13 @@ Error QnnContext::Configure() { if (cache_->GetCacheState() == QnnBackendCache::DESERIALIZE) { const QnnExecuTorchContextBinary& qnn_context_blob = cache_->GetQnnContextBlob(); - auto binary_info = GetBinaryInfo(qnn_context_blob.buffer); + error = qnn_interface.qnn_context_create_from_binary( backend_->GetHandle(), device_->GetHandle(), temp_context_config.empty() ? nullptr : temp_context_config.data(), - const_cast(binary_info->data()->data()), - binary_info->data()->size(), + static_cast(qnn_context_blob.buffer), + qnn_context_blob.nbytes, &handle_, /*profile=*/nullptr); if (error != QNN_SUCCESS) { @@ -94,9 +93,17 @@ Error QnnContext::GetContextBinary( Qnn_ErrorHandle_t error = qnn_interface.qnn_context_get_binary_size(handle_, &binary_size); if (error == QNN_SUCCESS) { - binary_buffer_.resize(binary_size); + // create our own protocol here + qnn_context_custom_protocol_ = QnnContextCustomProtocol(binary_size); + qnn_context_custom_protocol_.BuildContextCustomBuffer(); + auto [context_buffer_ptr, context_buffer_size] = + qnn_context_custom_protocol_.GetCustomProtocolBuffer(); error = qnn_interface.qnn_context_get_binary( - handle_, binary_buffer_.data(), binary_size, &bytes_written); + handle_, + static_cast(context_buffer_ptr) + + qnn_context_custom_protocol_.GetContextBinaryOffset(), + binary_size, + &bytes_written); if (error != QNN_SUCCESS) { QNN_EXECUTORCH_LOG_ERROR( "Can't get graph binary to be saved to " @@ -113,17 +120,8 @@ Error QnnContext::GetContextBinary( return Error::Internal; } - auto signature = []() { - return std::to_string(std::chrono::high_resolution_clock::now() - .time_since_epoch() - .count()); - }; - builder_.Reset(); - auto binary_info = qnn_delegate::CreateBinaryInfoDirect( - builder_, signature().c_str(), &binary_buffer_); - builder_.Finish(binary_info); - qnn_executorch_context_binary.buffer = builder_.GetBufferPointer(); - qnn_executorch_context_binary.nbytes = builder_.GetSize(); + qnn_executorch_context_binary.buffer = context_buffer_ptr; + qnn_executorch_context_binary.nbytes = context_buffer_size; } } else { QNN_EXECUTORCH_LOG_ERROR( diff --git a/backends/qualcomm/runtime/backends/QnnContextCommon.h b/backends/qualcomm/runtime/backends/QnnContextCommon.h index d93390a5379..62a0b953eec 100644 --- 
a/backends/qualcomm/runtime/backends/QnnContextCommon.h +++ b/backends/qualcomm/runtime/backends/QnnContextCommon.h @@ -7,10 +7,10 @@ */ #pragma once -#include #include #include #include +#include #include #include @@ -71,8 +71,7 @@ class QnnContext { QnnBackend* backend_; QnnDevice* device_; QnnBackendCache* cache_; - std::vector binary_buffer_; - flatbuffers::FlatBufferBuilder builder_; + QnnContextCustomProtocol qnn_context_custom_protocol_; }; } // namespace qnn } // namespace backends diff --git a/backends/qualcomm/runtime/backends/QnnCustomProtocol.cpp b/backends/qualcomm/runtime/backends/QnnCustomProtocol.cpp new file mode 100644 index 00000000000..6bf65f59286 --- /dev/null +++ b/backends/qualcomm/runtime/backends/QnnCustomProtocol.cpp @@ -0,0 +1,175 @@ +/* + * Copyright (c) Qualcomm Innovation Center, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +namespace executorch { +namespace backends { +namespace qnn { + +void QnnQcirCustomProtocol::BuildQcirCustomBuffer( + const QnnExecuTorchContextBinary& qcir_binary, + const std::vector& tensor_data) { + if (qnn_custom_buffer_.size() == 0) { + uint8_t magic_number_proto_size = sizeof(magic_number_); + uint8_t qcir_fbs_proto_size = sizeof(qcir_fbs_size_); + uint8_t tensor_proto_size = sizeof(tensor_size_); + + uint64_t buffer_size = magic_number_proto_size + qcir_fbs_proto_size + + tensor_proto_size + qcir_fbs_size_ + tensor_size_; + qnn_custom_buffer_.resize(buffer_size, 0); + + size_t pos = 0; + // magic number itself + std::memcpy( + qnn_custom_buffer_.data(), &magic_number_, magic_number_proto_size); + pos += magic_number_proto_size; + + // size of qcir_fbs, should be 4 bytes + std::memcpy( + qnn_custom_buffer_.data() + pos, &qcir_fbs_size_, qcir_fbs_proto_size); + pos += qcir_fbs_proto_size; + + // size of tensor, should be 8 bytes + std::memcpy( + qnn_custom_buffer_.data() + pos, &tensor_size_, tensor_proto_size); + pos += tensor_proto_size; + + // qcir.fbs buffer + uint8_t* qcir_ptr = static_cast(qcir_binary.buffer); + + std::memcpy(qnn_custom_buffer_.data() + pos, qcir_ptr, qcir_fbs_size_); + pos += qcir_fbs_size_; + + // tensor data + std::memcpy( + qnn_custom_buffer_.data() + pos, tensor_data.data(), tensor_size_); + } +} + +std::tuple +QnnQcirCustomProtocol::DeserializeQcirCustomBuffer(void* processed_data) { + Error status = Error::Ok; + uint8_t* ptr = static_cast(processed_data); + size_t magic_number_proto_size = sizeof(magic_number_); + uint8_t qcir_fbs_proto_size = sizeof(qcir_fbs_size_); + uint8_t tensor_proto_size = sizeof(tensor_size_); + + uint32_t magic_number; + std::memcpy(&magic_number, ptr, magic_number_proto_size); + ptr += magic_number_proto_size; + + if (magic_number != magic_number_) { + QNN_EXECUTORCH_LOG_INFO( + "QnnQcirCustomProtocol expected magic number: 0x%x but get: 0x%x", + magic_number_, + magic_number); + status = Error::Internal; + } + + // Retrieve size of qcir.fbs + uint32_t qcir_fbs_size; + std::memcpy(&qcir_fbs_size, ptr, qcir_fbs_proto_size); + ptr += qcir_fbs_proto_size; + + // Retrieve size of tensor + uint64_t tensor_size; + std::memcpy(&tensor_size, ptr, tensor_proto_size); + ptr += tensor_proto_size; + + // Retrieve qcir.fbs pointer + void* qcir_fbs_ptr = static_cast(ptr); + ptr += qcir_fbs_size; + + // Retrieve tensor + void* tensor_ptr = static_cast(ptr); + + return {status, qcir_fbs_size, tensor_size, qcir_fbs_ptr, tensor_ptr}; +} + +void 
QnnContextCustomProtocol::BuildContextCustomBuffer() { + if (qnn_custom_buffer_.size() == 0) { + signature_ = + std::chrono::high_resolution_clock::now().time_since_epoch().count(); + + uint8_t magic_number_proto_size = sizeof(magic_number_); + uint8_t binary_proto_size = sizeof(binary_size_); + uint8_t signature_proto_size = sizeof(signature_); + uint64_t buffer_size = magic_number_proto_size + signature_proto_size + + binary_proto_size + binary_size_; + qnn_custom_buffer_.resize(buffer_size, 0); + + size_t pos = 0; + + // magic number itself + std::memcpy( + qnn_custom_buffer_.data(), &magic_number_, magic_number_proto_size); + pos += magic_number_proto_size; + + // signature itself + std::memcpy( + qnn_custom_buffer_.data() + pos, &signature_, signature_proto_size); + pos += signature_proto_size; + + // size of context binary, should be 8 bytes + // Binary itself won't be stored here. Refer to QnnCustomProtocol.h for more + // info. + std::memcpy( + qnn_custom_buffer_.data() + pos, &binary_size_, binary_proto_size); + } +} + +void QnnContextCustomProtocol::BuildContextCustomBuffer( + const QnnExecuTorchContextBinary& context_binary) { + BuildContextCustomBuffer(); + uint64_t offset = GetContextBinaryOffset(); + std::memcpy( + qnn_custom_buffer_.data() + offset, + static_cast(context_binary.buffer), + context_binary.nbytes); +} + +std::tuple +QnnContextCustomProtocol::DeserializeContextCustomBuffer(void* processed_data) { + Error status = Error::Ok; + + uint8_t* ptr = static_cast(processed_data); + uint8_t magic_number_proto_size = sizeof(magic_number_); + uint8_t binary_proto_size = sizeof(binary_size_); + uint8_t signature_proto_size = sizeof(signature_); + + uint32_t magic_number; + std::memcpy(&magic_number, ptr, magic_number_proto_size); + ptr += magic_number_proto_size; + + if (magic_number != magic_number_) { + QNN_EXECUTORCH_LOG_INFO( + "QnnContextCustomProtocol expected magic number: 0x%x but get: 0x%x", + magic_number_, + magic_number); + status = Error::Internal; + } + + std::memcpy(&signature_, ptr, signature_proto_size); + ptr += signature_proto_size; + + uint64_t binary_size; + std::memcpy(&binary_size, ptr, binary_proto_size); + ptr += binary_proto_size; + + return {status, signature_, binary_size, static_cast(ptr)}; +} + +uint64_t QnnContextCustomProtocol::GetContextBinaryOffset() { + return sizeof(magic_number_) + sizeof(signature_) + sizeof(binary_size_); +} + +} // namespace qnn +} // namespace backends +} // namespace executorch diff --git a/backends/qualcomm/runtime/backends/QnnCustomProtocol.h b/backends/qualcomm/runtime/backends/QnnCustomProtocol.h new file mode 100644 index 00000000000..6ea556899f5 --- /dev/null +++ b/backends/qualcomm/runtime/backends/QnnCustomProtocol.h @@ -0,0 +1,137 @@ +/* + * Copyright (c) Qualcomm Innovation Center, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#pragma once + +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace executorch { +namespace backends { +namespace qnn { + +using executorch::runtime::Error; + +// We have 2 kinds of protocol here: custom_qcir_protocol, +// custom_context_protocol. We need this class due to limitation of 32bits +// flatbuffer. Since larger models can exceed the maximum size for 32bits +// flatbuffer, we need to define our own protocol and store some information +// outside of the flatbuffer. 
The magic number helps determine if we are getting +// the correct custom protocol buffer and differentiate custom_qcir_protocol +// from custom_context_protocol. +class QnnCustomProtocol { + public: + QnnCustomProtocol() {} + + // Get a pair that holds pointer pointing to the start of custom buffer and + // the size of the custom buffer. + std::pair GetCustomProtocolBuffer() { + return { + static_cast(qnn_custom_buffer_.data()), + qnn_custom_buffer_.size()}; + } + + protected: + std::vector qnn_custom_buffer_; +}; + +// For custom_qcir_protocol, we expect the following format: +// +// ------------------------------ +// | qcir magic number (4 bytes)| +// ------------------------------ +// | qcir.fbs size (4 bytes) | +// ------------------------------ +// | tensor size (8 bytes) | +// ------------------------------ +// | qcir.fbs (flatbuffer) | +// ------------------------------ +// | tensor.data | +// ------------------------------ +class QnnQcirCustomProtocol : public QnnCustomProtocol { + public: + // Constructor for Serialize + QnnQcirCustomProtocol(uint32_t qcir_fbs_size, uint64_t tensor_size) + : QnnCustomProtocol(), + qcir_fbs_size_(qcir_fbs_size), + tensor_size_(tensor_size) {} + + // Constructor for Deserialize + QnnQcirCustomProtocol() : QnnCustomProtocol() {} + + void BuildQcirCustomBuffer( + const QnnExecuTorchContextBinary& qcir_binary, + const std::vector& tensor_data); + // Return a tuple with 5 elements: + // 1) Error: Status of whether deserializing is successful. + // 2) uint32_t: Size of qcir fbs + // 3) uint64_t: Size of tensor + // 4) void*: Pointer pointing to the start of qcir fbs + // 5) void*: Pointer pointing to the start of tensor + std::tuple + DeserializeQcirCustomBuffer(void* processed_data); + + private: + static constexpr uint32_t magic_number_ = 0x1234ABCD; + uint32_t qcir_fbs_size_{0}; + uint64_t tensor_size_{0}; +}; + +// For custom context binary protocol, we expect the following format: +// +// --------------------------------- +// | magic number (4 bytes) | +// --------------------------------- +// | signature (8 bytes) | +// --------------------------------- +// | context_binary_size (8 bytes) | +// --------------------------------- +// | context_binary.data | +// --------------------------------- +class QnnContextCustomProtocol : public QnnCustomProtocol { + public: + // Constructor for Serialize + QnnContextCustomProtocol(uint64_t binary_size) + : QnnCustomProtocol(), binary_size_(binary_size) {} + + // Constructor for Deserialize + QnnContextCustomProtocol() : QnnCustomProtocol() {} + + // Please note that this function will only initialize the required memory + // space and fill in all meta data except for context_binary.data. Users will + // need to handle context_binary.data themselves. This is because QNN-provided + // functions, such as qnn_context_get_binary(), ask for a memory address + // to store data and will fill it in for us. + void BuildContextCustomBuffer(); + // Use this function if you already have context_binary ahead of time. + void BuildContextCustomBuffer(const QnnExecuTorchContextBinary& qcir_binary); + // Return a tuple with 4 elements: + // 1) Error: Status of whether deserializing is successful. 
+ // 2) int64_t: Graph signature + // 3) uint64_t: Size of the context binary + // 4) void*: Pointer pointing to the start of context_binary + std::tuple DeserializeContextCustomBuffer( + void* processed_data); + uint64_t GetContextBinaryOffset(); + + private: + static constexpr uint32_t magic_number_ = 0x5678ABCD; + int64_t signature_{0}; + uint64_t binary_size_{0}; +}; + +} // namespace qnn +} // namespace backends +} // namespace executorch diff --git a/backends/qualcomm/runtime/backends/QnnLogger.cpp b/backends/qualcomm/runtime/backends/QnnLogger.cpp index 412b1a2db2c..5b86894d874 100644 --- a/backends/qualcomm/runtime/backends/QnnLogger.cpp +++ b/backends/qualcomm/runtime/backends/QnnLogger.cpp @@ -10,7 +10,6 @@ #include #include -#include #include #include "QnnLog.h" diff --git a/backends/qualcomm/runtime/targets.bzl b/backends/qualcomm/runtime/targets.bzl index be4c56b587d..febe35b4e79 100644 --- a/backends/qualcomm/runtime/targets.bzl +++ b/backends/qualcomm/runtime/targets.bzl @@ -30,7 +30,6 @@ def define_common_targets(): exported_deps = [ "fbsource//third-party/toolchains:log", "//executorch/backends/qualcomm:schema", - "//executorch/backends/qualcomm:qc_binary_info_schema", "//executorch/runtime/core:core", ], ) @@ -69,7 +68,6 @@ def define_common_targets(): "fbsource//third-party/qualcomm/qnn/qnn-{0}:api".format(get_qnn_library_verision()), ":logging", "//executorch/backends/qualcomm:schema", - "//executorch/backends/qualcomm:qc_binary_info_schema", "//executorch/backends/qualcomm/aot/ir:qcir_utils", "//executorch/backends/qualcomm/aot/wrappers:wrappers", "//executorch/runtime/backend:interface", diff --git a/backends/qualcomm/serialization/qc_binary_info.fbs b/backends/qualcomm/serialization/qc_binary_info.fbs deleted file mode 100644 index 3f301055269..00000000000 --- a/backends/qualcomm/serialization/qc_binary_info.fbs +++ /dev/null @@ -1,20 +0,0 @@ -//============================================================================ -// -// Copyright (c) Qualcomm Innovation Center, Inc. -// All rights reserved -// -// This source code is licensed under the BSD-style license found in the -// LICENSE file in the root directory of this source tree. 
-// -//============================================================================ - -namespace qnn_delegate; - -table BinaryInfo { - // Signature of binary - signature: string; - // Data of processed binary - data: [ubyte]; -} - -root_type BinaryInfo; diff --git a/backends/qualcomm/targets.bzl b/backends/qualcomm/targets.bzl index 521152d2791..fbbfa0f1925 100644 --- a/backends/qualcomm/targets.bzl +++ b/backends/qualcomm/targets.bzl @@ -16,12 +16,6 @@ SCHEMA_GEN_RULE_NAME = "qc_compiler_spec_generated" SCHEMA_LIRRARY_NAME = SCHEMA_NAME -QC_BINARY_INFO_SCHEMA = "qc_binary_info" -QC_BINARY_INFO_INPUT_SCHEMA = "serialization/" + QC_BINARY_INFO_SCHEMA + ".fbs" -QC_BINARY_INFO_SCHEMA_GEN_RULE_NAME = QC_BINARY_INFO_SCHEMA + "_generated" -QC_BINARY_INFO_OUTPUT_SCHEMA_HEADER = QC_BINARY_INFO_SCHEMA_GEN_RULE_NAME + ".h" -QC_BINARY_INFO_SCHEMA_LIRRARY_NAME = QC_BINARY_INFO_SCHEMA - def generate_schema_header(rule_name, srcs, headers, default_header): """Generate header file given flatbuffer schema """ @@ -83,33 +77,6 @@ def define_common_targets(): platforms = [ANDROID], ) - generate_schema_header( - QC_BINARY_INFO_SCHEMA_GEN_RULE_NAME, - [QC_BINARY_INFO_INPUT_SCHEMA], - [QC_BINARY_INFO_OUTPUT_SCHEMA_HEADER], - QC_BINARY_INFO_OUTPUT_SCHEMA_HEADER, - ) - - runtime.cxx_library( - name = "qc_binary_info_schema", - srcs = [], - visibility = [ - # Lock this down as tightly as possible to ensure that flatbuffers - # are an implementation detail. Ideally this list would only include - # //executorch/runtime/executor/... - "//executorch/codegen/tools/...", - "//executorch/runtime/executor/...", - "//executorch/backends/qualcomm/...", - "//executorch/backends/qualcomm/runtime/...", - ], - exported_headers = { - QC_BINARY_INFO_OUTPUT_SCHEMA_HEADER: ":{}[{}]".format( QC_BINARY_INFO_SCHEMA_GEN_RULE_NAME, QC_BINARY_INFO_OUTPUT_SCHEMA_HEADER), - }, - exported_external_deps = ["flatbuffers-api"], - define_static_target = True, - platforms = [ANDROID], - ) - runtime.cxx_library( name = "qnn_executorch_backend", srcs = [], diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index f9550d64832..4b489ea5157 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -68,7 +68,8 @@ from executorch.examples.models.mobilenet_v2 import MV2Model from executorch.examples.models.mobilenet_v3 import MV3Model from executorch.examples.models.torchvision_vit.model import TorchVisionViTModel -from executorch.examples.models.wav2letter import Wav2LetterModel + +# from executorch.examples.models.wav2letter import Wav2LetterModel from executorch.exir import to_edge from executorch.exir.backend.backend_api import disable_validation from executorch.exir.passes import PassManager @@ -679,7 +680,8 @@ def test_qnn_backend_example_models(self): MV3Model(), MobileBertModelExample(), TorchVisionViTModel(), - Wav2LetterModel(), + # Encountered undefined symbol in mainline. Reopen once resolved. + # Wav2LetterModel(), ] expected_partitions = [ 1, @@ -1490,11 +1492,12 @@ def test_qnn_backend_example_models(self): QCOM_ANNOTATION: (), QCOM_QUANT_DTYPE: QuantDtype.use_8a8w, }, - { - QCOM_MODULE: Wav2LetterModel(), - QCOM_ANNOTATION: (), - QCOM_QUANT_DTYPE: QuantDtype.use_8a8w, - }, + # Encountered undefined symbol in mainline. Reopen once resolved. 
+ # { + # QCOM_MODULE: Wav2LetterModel(), + # QCOM_ANNOTATION: (), + # QCOM_QUANT_DTYPE: QuantDtype.use_8a8w, + # }, ] expected_partitions = [ 1, @@ -1507,7 +1510,7 @@ def test_qnn_backend_example_models(self): # For MobileBertModelExample # 1, 1, - 1, + # 1, For Wav2LetterModel ] # TODO: Due to trigger maximum recursion depth exceeded, need to check it. disable_validation() @@ -1653,11 +1656,7 @@ def test_qnn_backend_multi_graphs(self): for i, edge_prog in enumerate(edge_progs) ] prog_mgr = generate_multi_graph_program( - compiler_specs=compiler_specs[0], - processed_bytes=[ - prog.graph_module.lowered_module_0.processed_bytes - for prog in exported_programs - ], + compiler_specs=compiler_specs[0], exported_programs=exported_programs ) for index, module in enumerate(modules): self.verify_output( @@ -2123,10 +2122,7 @@ def test_qnn_backend_multi_graphs(self): ] prog_mgr = generate_multi_graph_program( compiler_specs=compiler_specs[0], - processed_bytes=[ - prog.graph_module.lowered_module_0.processed_bytes - for prog in exported_programs - ], + exported_programs=exported_programs, ) for index, module in enumerate(modules): self.verify_output( @@ -3366,6 +3362,7 @@ def test_ptq_mobilebert(self): for k, v in cpu.items(): self.assertLessEqual(abs(v[0] - htp[k][0]), 5) + @unittest.skip("encountered undefined symbol in mainline, reopen once resolved") def test_wav2letter(self): if not self.required_envs([self.pretrained_weight]): self.skipTest("missing required envs") diff --git a/backends/qualcomm/utils/utils.py b/backends/qualcomm/utils/utils.py index a73fe6944eb..2e0ee4f7c63 100644 --- a/backends/qualcomm/utils/utils.py +++ b/backends/qualcomm/utils/utils.py @@ -5,9 +5,10 @@ # LICENSE file in the root directory of this source tree. import operator +import re import warnings from collections import OrderedDict -from typing import Callable, Dict, FrozenSet, List, Tuple +from typing import Any, Callable, Dict, FrozenSet, List, Optional, Tuple import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManagerAdaptor @@ -647,7 +648,13 @@ def op_impl(inputs: List[torch.Tensor]): for v in outputs.values() ) - def build_graph(inputs, outputs): + def build_graph( + inputs, + outputs, + qnn_in_order: Optional[List[int]] = None, + executorch_in_order: Optional[List[int]] = None, + executorch_out_order: Optional[List[int]] = None, + ): # custom op declaration inputs_str = "Tensor[] inputs" func_proto = f"{op_name}({inputs_str}) -> Any" @@ -658,13 +665,39 @@ def build_graph(inputs, outputs): # model architecture mimicking context binary class Model(torch.nn.Module): - def forward(self, *inputs): - return getattr( + """ + The args of forward() can be thought of as what executorch is accepting as input. + The getattr inside the forward() can be thought of as qnn context binary. + When we first pass in the input, we need to use the executorch's(nn.module) input order. + After we get into forward(), we then need to convert input order to qnn's input order. + Same as return, when qnn returns the value, we need to reorder them back to executorh's output order. 
+ """ + + def __init__(self, qnn_in_order, executorch_out_order): + super().__init__() + self.qnn_in_order = qnn_in_order + self.executorch_out_order = executorch_out_order + + def forward(self, *inputs): # executorch + if self.qnn_in_order: + inputs = tuple(inputs[i] for i in self.qnn_in_order) + ret = getattr( getattr(torch.ops, OpContextLoader.namespace), op_name ).default(inputs) + return ( + [ret[idx] for idx in self.executorch_out_order] + if self.executorch_out_order + else ret + ) + + inputs = ( + tuple(tuple(inputs.values())[i] for i in executorch_in_order) + if executorch_in_order + else tuple(inputs.values()) + ) - model = Model() - prog = torch.export.export(model, tuple(inputs.values()), strict=True) + model = Model(qnn_in_order, executorch_out_order) + prog = torch.export.export(model, inputs, strict=True) # bookkeeping for variables' life cycle return { "custom_op": custom_op, @@ -707,6 +740,7 @@ def preprocess_binary(ctx_bin, compiler_specs): for k, v in type_map.items(): dtype_map.setdefault(v, k) + qnn_in_order, executorch_in_order, executorch_out_order = [], [], [] if custom_info is not None: # since some context binaries might fail to open on host # if they are compiled with special flags: @@ -714,6 +748,9 @@ def preprocess_binary(ctx_bin, compiler_specs): # use custom information here instead inputs = build_tensor(custom_info["graph_inputs"], dtype_map) outputs = build_tensor(custom_info["graph_outputs"], dtype_map) + qnn_in_order = custom_info["qnn_in_order"] + executorch_in_order = custom_info["executorch_in_order"] + executorch_out_order = custom_info["executorch_out_order"] graph_name = custom_info["graph_name"] else: # get context-binary io tensor info through qnn manager @@ -728,15 +765,21 @@ def preprocess_binary(ctx_bin, compiler_specs): inputs = build_tensor(qnn_mgr.GetGraphInputs(graph_name), dtype_map) outputs = build_tensor(qnn_mgr.GetGraphOutputs(graph_name), dtype_map) qnn_mgr.Destroy() - # generate graph specific for loading context - bundle_prog = build_graph(inputs, outputs) + bundle_prog = build_graph( + inputs, outputs, qnn_in_order, executorch_in_order, executorch_out_order + ) bundle_prog.update({"inputs": inputs, "outputs": outputs}) + + # TODO: to_edge() decorator alters the function call behavior, which + # requires "self" when calling. To work around this issue, + # temporarily remove the first parameter name. 
edge_prog_mgr = to_edge( - programs={graph_name: bundle_prog["exported_program"]}, + {graph_name: bundle_prog["exported_program"]}, # do not alter name for custom op compile_config=EdgeCompileConfig(_use_edge_ops=False), ) + # update meta with context binary for n in edge_prog_mgr._edge_programs[graph_name].graph.nodes: if n.op == "call_function" and OpContextLoader.namespace in str(n.target): @@ -757,11 +800,23 @@ def draw_graph(title, path, graph_module: torch.fx.GraphModule): def generate_multi_graph_program( compiler_specs: List[CompileSpec], - processed_bytes: List[bytes], + exported_programs: List[ExportedProgram] = None, backend_config: ExecutorchBackendConfig = None, + constant_methods: Optional[Dict[str, Any]] = None, ) -> ExecutorchProgramManager: + # compile multiple graphs in qcir into single context binary - graph_inputs, graph_outputs = {}, {} + ( + graph_inputs, + graph_outputs, + qnn_in_order, + executorch_in_order, + executorch_out_order, + ) = ({}, {}, {}, {}, {}) + + processed_bytes = [ + prog.graph_module.lowered_module_0.processed_bytes for prog in exported_programs + ] qnn_mgr = PyQnnManagerAdaptor.QnnManager( generate_qnn_executorch_option(compiler_specs), processed_bytes ) @@ -772,6 +827,41 @@ def generate_multi_graph_program( for graph_name in graph_names: graph_inputs[graph_name] = qnn_mgr.GetGraphInputs(graph_name) graph_outputs[graph_name] = qnn_mgr.GetGraphOutputs(graph_name) + + # We need to obtain the order of the IOs to correctly map QNN with nn.module + for i, graph_name in enumerate(graph_names): + # input + input_names = [ + node.name + for node in exported_programs[i].graph_module.graph.nodes + if node.op == "placeholder" + ] + qnn_input_names = [wrapper.GetName() for wrapper in graph_inputs[graph_name]] + input_order_list = [] + for input_name in input_names: + # e.g., input_0_tokens_0 + pattern = rf"^input_(\d+)_({input_name})_(\d+)$" + for j in range(len(qnn_input_names)): + if re.match(pattern, qnn_input_names[j]): + input_order_list.append(j) + break + assert ( + len(input_order_list) == len(input_names) == len(qnn_input_names) + ), "Order list length is different from names" + executorch_in_order[graph_name] = input_order_list + qnn_in_order[graph_name] = sorted( + range(len(input_order_list)), key=lambda k: input_order_list[k] + ) + + # output + get_item_list = [ + node + for node in exported_programs[i].graph_module.graph.nodes + if node.op == "output" + ][0].args[0] + output_order_list = [item.args[1] for item in get_item_list] + executorch_out_order[graph_name] = output_order_list + qnn_mgr.Destroy() # build custom ops with different graph signatures @@ -785,16 +875,20 @@ def generate_multi_graph_program( "graph_inputs": graph_inputs[graph_name], "graph_outputs": graph_outputs[graph_name], "graph_name": graph_name, + "qnn_in_order": qnn_in_order[graph_name], + "executorch_in_order": executorch_in_order[graph_name], + "executorch_out_order": executorch_out_order[graph_name], }, ) for graph_name in graph_names ] # leverage ExecutorchProgramManager for generating pte with multi-methods edge_prog_mgr = to_edge( - programs={ + { graph_name: bundle_prog["exported_program"] for graph_name, bundle_prog in zip(graph_names, bundle_progs) }, + constant_methods=constant_methods, # do not alter name for custom op compile_config=EdgeCompileConfig(_use_edge_ops=False), ) @@ -805,7 +899,8 @@ def generate_multi_graph_program( n.meta[OpContextLoader.meta_ctx_bin] = binary_info break - return edge_prog_mgr.to_backend(QnnPartitioner(compiler_specs)).to_executorch( 
+ edge_prog_mgr = edge_prog_mgr.to_backend(QnnPartitioner(compiler_specs)) + return edge_prog_mgr.to_executorch( config=backend_config or ExecutorchBackendConfig() ) diff --git a/examples/qualcomm/oss_scripts/llama2/llama.py b/examples/qualcomm/oss_scripts/llama2/llama.py index 55f84bbcaba..2a2968362ac 100755 --- a/examples/qualcomm/oss_scripts/llama2/llama.py +++ b/examples/qualcomm/oss_scripts/llama2/llama.py @@ -340,7 +340,7 @@ def _tag_kv_ios(self, gm: torch.fx.GraphModule, kv_type): ] * (self.llama_meta["get_max_seq_len"] - 1): a.meta[QCOM_QUANTIZED_IO] = kv_type - def quantize(self, quant_dtype, custom_annotations=()): + def quantize(self, quant_dtype, args, custom_annotations=()): self.quant_dtype = quant_dtype quantizer = make_quantizer( quant_dtype=quant_dtype, @@ -472,6 +472,7 @@ def compile(args): start_quantize_ts = time.time() single_llama.quantize( quant_dtype, + args=args, custom_annotations=( annotate_matmul_16a8w, annotate_linear_16a8w_in_affine_layer, diff --git a/examples/qualcomm/oss_scripts/llama3_2/README.md b/examples/qualcomm/oss_scripts/llama3_2/README.md new file mode 100644 index 00000000000..51de982b1b1 --- /dev/null +++ b/examples/qualcomm/oss_scripts/llama3_2/README.md @@ -0,0 +1,39 @@ +# Summary + +## Overview +This file provides instructions to run LLAMA3.2 1B and 3B (WIP) with different parameters via the Qualcomm HTP backend. In LLAMA3.2, we offer the following modes to execute the model: + +Prefill Mode: This is also known as batch prefill mode, where the model takes in a list of tokens as input and generates the next token along with the key-value (KV) cache for all tokens. This mode is efficient for generating the initial sequence of tokens (usually the user's prompt). + +KV Cache Mode: In KV Cache mode, the model takes in a single previous token and generates the next predicted token along with its KV cache. It is efficient for generating subsequent tokens after the initial prompt. + +Hybrid Mode: Hybrid mode leverages the strengths of both batch prefill and KV cache modes to optimize token generation speed. Initially, it uses prefill mode to efficiently generate the prompt's key-value (KV) cache. Then, the mode switches to KV cache mode, which excels at generating subsequent tokens. + +## Instructions +### Note +1. For hybrid mode, the export time will be longer and can take up to 2-4 hours to complete. +2. When exporting a hybrid mode model, please ensure the device has at least 80 GB of memory and swap space. + +### Step 1: Setup +1. Follow the [tutorial](https://pytorch.org/executorch/main/getting-started-setup) to set up ExecuTorch. +2. Follow the [tutorial](https://pytorch.org/executorch/stable/build-run-qualcomm-ai-engine-direct-backend.html) to build Qualcomm AI Engine Direct Backend. + +### Step 2: Prepare Model +1. Follow the [instructions](https://www.llama.com/) to download models. +At the end of this step, users should have the following files ready: consolidated.00.pth, params.json, and tokenizer.model. + +### Step3: Run default examples using hybrid mode. +Default example using hybrid mode. +```bash +python examples/qualcomm/oss_scripts/llama3_2/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --prompt "what is 1+1" --temperature 0 --model_size 1B --model_mode hybrid --prefill_seq_len 32 --kv_seq_len 128 +``` + +If you would like to compile the model only, we have provided the flag `--compile_only`. 
+```bash +python examples/qualcomm/oss_scripts/llama3_2/llama.py -b build-android -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --prompt "what is 1+1" --temperature 0 --model_size 1B --model_mode hybrid --prefill_seq_len 32 --kv_seq_len 128 --compile_only +``` + +On the other hand, if you already have a pre-compiled .pte model, you can perform inference by providing the flag `--pre_gen_pte` and specifying the folder that contains the .pte model. +```bash +python examples/qualcomm/oss_scripts/llama3_2/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --prompt "what is 1+1" --temperature 0 --model_size 1B --model_mode hybrid --prefill_seq_len 32 --kv_seq_len 128 --pre_gen_pte ${FOLDER_TO_PRE_GEN_PTE} +``` \ No newline at end of file diff --git a/examples/qualcomm/oss_scripts/llama3_2/llama.py b/examples/qualcomm/oss_scripts/llama3_2/llama.py index 72d4a905c06..13fb99a4202 100755 --- a/examples/qualcomm/oss_scripts/llama3_2/llama.py +++ b/examples/qualcomm/oss_scripts/llama3_2/llama.py @@ -7,6 +7,7 @@ # TODO: reenable pyre after fixing the issues # pyre-ignore-all-errors +import copy import getpass import json import logging @@ -23,7 +24,6 @@ from executorch.backends.qualcomm.quantizer.custom_annotation import ( annotate_matmul_16a8w, - custom_annotate_llama_last_conv_16a8w, ) from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype @@ -33,6 +33,7 @@ capture_program, convert_linear_to_conv2d, generate_htp_compiler_spec, + generate_multi_graph_program, generate_qnn_executorch_compiler_spec, get_soc_to_chipset_map, ) @@ -47,6 +48,7 @@ SimpleADB, ) from executorch.exir import EdgeCompileConfig, EdgeProgramManager +from executorch.exir.backend.backend_api import to_backend from executorch.exir.capture._config import ExecutorchBackendConfig from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass @@ -62,8 +64,6 @@ logging.basicConfig(level=logging.INFO, format=FORMAT) logging.getLogger().setLevel(logging.INFO) -pte_filename = "llama3_2_qnn" - def _kv_calibrate( example_inputs, @@ -106,7 +106,7 @@ def _kv_calibrate( print(f"calibration data:\n{sp_model.decode(token_list)}") -def _batch_prefill_calibrate( +def _prefill_calibrate( example_inputs, user_prompts, module: torch.fx.GraphModule, @@ -150,7 +150,7 @@ def calibrate( max_seq_len=512, ): if len(example_inputs) == 2: - _batch_prefill_calibrate( + _prefill_calibrate( example_inputs, user_prompts, module, @@ -170,12 +170,13 @@ def calibrate( class SingleLlama: - def __init__(self, llama_model) -> None: + def __init__(self, llama_model, pte_filename) -> None: super().__init__() self.llama_model = llama_model self.quant_dtype = None self.llama_meta = self.llama_model.get_metadata() self.has_quant_io = False + self.pte_filename = pte_filename if self.llama_meta["get_use_kv_cache"]: tokens, atten_mask, pos_ids, k_caches, v_caches = self.get_example_inputs( use_kv_cache=True @@ -209,7 +210,7 @@ def _tag_kv_ios(self, gm: torch.fx.GraphModule, kv_type, sharding_type): == self.llama_meta["get_head_dim"] ): a.meta[QCOM_QUANTIZED_IO] = kv_type - # single head, batch_prefill mode + # single head, prefill mode elif a.meta["val"].flatten().size()[0] == self.llama_meta[ "get_head_dim" ] * (self.llama_meta["get_max_seq_len"] - 1): @@ -240,13 +241,12 @@ def quantize(self, quant_dtype, args, 
custom_annotations=()): ).module() fx_graph_module = prepare_pt2e(fx_graph_module, quantizer) logging.info("Quantizing the model...") - calibrate( self.get_example_inputs(self.llama_meta["get_use_kv_cache"]), args.prompt, fx_graph_module, tokenizer_model_path=args.tokenizer_model, - max_seq_len=args.seq_len, + max_seq_len=self.llama_meta["get_max_seq_len"], ) self.llama_model = convert_pt2e(fx_graph_module) @@ -280,7 +280,7 @@ def lowering_modules( compiler_specs = generate_qnn_executorch_compiler_spec( soc_model=soc_model, backend_options=backend_options, - shared_buffer=True, + shared_buffer=False, ) skip_node_op_set = {"llama.fallback.default"} partitioner = QnnPartitioner( @@ -309,56 +309,68 @@ def lowering_modules( ) edge_prog_mgr = edge_prog_mgr.to_backend(partitioner) exec_prog_mgr = edge_prog_mgr.to_executorch(config=executorch_config) - with open(f"{work_space}/{pte_filename}.pte", "wb") as file: + with open(f"{work_space}/{self.pte_filename}.pte", "wb") as file: exec_prog_mgr.write_to_file(file) def get_example_inputs(self, use_kv_cache=True): return self.llama_model.get_example_inputs(use_kv_cache) -def compile(args): +def compile(args, pte_filename): os.makedirs(args.artifact, exist_ok=True) start_ts = time.time() - if args.model_mode == "kv": - use_kv_cache = output_new_cache_only = True - matmul_annotate_func = partial(annotate_matmul_16a8w, traverse_input1=True) - elif args.model_mode == "batch_prefill": - use_kv_cache = output_new_cache_only = False - matmul_annotate_func = partial(annotate_matmul_16a8w, traverse_input1=False) - elif args.model_mode == "hybrid": - raise NotImplementedError( - f"model_mode {args.model_mode} is not implemented yet." - ) - else: - raise RuntimeError(f"No such model_mode {args.model_mode}.") - with open(args.params) as f: - config = ModelArgs(**json.load(f)) + kv_config = ModelArgs(**json.load(f)) # TODO: support batch inputs if necessary - config.max_batch_size = 1 - config.max_seq_len = args.seq_len - config.use_kv_cache = use_kv_cache + kv_config.max_batch_size = 1 + kv_config.max_seq_len = args.kv_seq_len + kv_config.use_kv_cache = True + + prefill_config = copy.copy(kv_config) + prefill_config.max_seq_len = args.prefill_seq_len + prefill_config.use_kv_cache = False + state_dict = torch.load( args.checkpoint, weights_only=True, map_location="cpu", mmap=True ) - llama_instance = None + llama_instance_list = [] with torch.device("meta"): - llama_instance = LlamaModel(config, output_new_cache_only=output_new_cache_only) + if args.model_mode == "kv": + llama_instance_list.append( + LlamaModel(kv_config, output_new_cache_only=True) + ) + elif args.model_mode == "prefill": + llama_instance_list.append( + LlamaModel(prefill_config, output_new_cache_only=False) + ) + elif args.model_mode == "hybrid": + llama_instance_list.append( + LlamaModel(prefill_config, output_new_cache_only=False) + ) + llama_instance_list.append( + LlamaModel(kv_config, output_new_cache_only=True) + ) + else: + raise RuntimeError(f"No such model_mode {args.model_mode}.") + if "model" in state_dict: state_dict = state_dict["model"] - llama_instance.load_state_dict( - state_dict, - strict=False, - assign=True, - ) + + for llama_instance in llama_instance_list: + llama_instance.load_state_dict( + state_dict, + strict=False, + assign=True, + ) end_load_ts = time.time() logging.info(f"Time for loading checkpoint: {end_load_ts - start_ts}") - for layer in llama_instance.layers: - if getattr(layer.attention, "prepare_sha", None): - layer.attention.prepare_sha() + for llama_instance 
in llama_instance_list: + for layer in llama_instance.layers: + if getattr(layer.attention, "prepare_sha", None): + layer.attention.prepare_sha() use_fp16 = False if args.ptq != None: @@ -381,61 +393,137 @@ def compile(args): if args.dtype_override is not None: dtype_override = DType[args.dtype_override] - llama_instance = llama_instance.to(dtype_override.to_torch_dtype()) + for i in range(len(llama_instance_list)): + llama_instance_list[i] = llama_instance_list[i].to( + dtype_override.to_torch_dtype() + ) - llama_instance = convert_linear_to_conv2d(llama_instance) - single_llama = SingleLlama(llama_instance.eval()) + for i in range(len(llama_instance_list)): + llama_instance_list[i] = convert_linear_to_conv2d(llama_instance_list[i]) + llama_instance_list[i] = SingleLlama( + llama_instance_list[i].eval(), pte_filename + ) if args.ptq != None: start_quantize_ts = time.time() - single_llama.quantize( - quant_dtype=quant_dtype, - args=args, - custom_annotations=( - custom_annotate_llama_last_conv_16a8w, - matmul_annotate_func, - ), - ) + for llama_instance in llama_instance_list: + llama_instance.quantize( + quant_dtype=quant_dtype, + args=args, + custom_annotations=( + partial( + annotate_matmul_16a8w, + traverse_input1=llama_instance.llama_meta["get_use_kv_cache"], + ), + ), + ) end_quantize_ts = time.time() logging.info(f"Time for quantizing: {end_quantize_ts - start_quantize_ts}") start_lowering_ts = time.time() - single_llama.lowering_modules( - args.artifact, - kv_type=kv_type, - sharding_type=sharding_type, - use_fp16=use_fp16, - soc_model=get_soc_to_chipset_map()[args.model], - num_sharding=args.num_sharding, - ) + + if len(llama_instance_list) == 1: + llama_instance_list[0].lowering_modules( + args.artifact, + kv_type=kv_type, + sharding_type=sharding_type, + use_fp16=use_fp16, + soc_model=get_soc_to_chipset_map()[args.model], + num_sharding=args.num_sharding, + ) + else: + sample_inputs_list = [ + llama_instace.inputs for llama_instace in llama_instance_list + ] + edge_progs = [ + capture_program(llama_instance.llama_model, sample_input) + for llama_instance, sample_input in zip( + llama_instance_list, sample_inputs_list + ) + ] + + if args.num_sharding > 0: + for i in range(len(llama_instance_list)): + model_sharding.split_graph( + edge_progs[i].exported_program, + llama_instance_list[i].llama_meta["get_n_layers"], + shares=args.num_sharding, + ) + + for i in range(len(llama_instance_list)): + llama_instance_list[i]._tag_kv_ios( + edge_progs[i].exported_program.graph_module, + kv_type=kv_type, + sharding_type=sharding_type, + ) + backend_options = generate_htp_compiler_spec(use_fp16=use_fp16) + graph_names = ["prefill_forward", "kv_forward"] + compiler_specs = [ + generate_qnn_executorch_compiler_spec( + soc_model=get_soc_to_chipset_map()[args.model], + backend_options=backend_options, + shared_buffer=True, + multiple_graphs=True, + graph_name=graph_name, + ) + for graph_name in graph_names + ] + exported_programs = [ + to_backend(edge_prog.exported_program, QnnPartitioner(compiler_specs[i])) + for i, edge_prog in enumerate(edge_progs) + ] + + executorch_config = ExecutorchBackendConfig( + passes=[ + BuildQuantIo(), + ], + # For shared buffer, user must pass the memory address + # which is allocated by RPC memory to executor runner. + # Therefore, won't want to pre-allocate + # by memory manager in runtime. 
+ memory_planning_pass=MemoryPlanningPass( + alloc_graph_input=False, + alloc_graph_output=False, + ), + extract_delegate_segments=True, + ) + + prog_mgr = generate_multi_graph_program( + compiler_specs=compiler_specs[0], + exported_programs=exported_programs, + backend_config=executorch_config, + constant_methods=llama_instance_list[1].llama_meta, # kv method meta + ) + with open(f"{args.artifact}/{pte_filename}.pte", "wb") as file: + prog_mgr.write_to_file(file) + end_lowering_ts = time.time() logging.info(f"Time for compiling: {end_lowering_ts - start_lowering_ts}") -def inference(args, pre_gen_pte=""): +def inference(args, pte_filename, pre_gen_pte=""): workspace = f"/data/local/tmp/{getpass.getuser()}/executorch/single_llama" - if args.model_mode == "batch_prefill": + if args.model_mode == "prefill": eval_mode = 0 elif args.model_mode == "kv": eval_mode = 1 elif args.model_mode == "hybrid": eval_mode = 2 - raise NotImplementedError( - f"model_mode {args.model_mode} is not implemented yet." - ) else: raise RuntimeError(f"No such model_mode {args.model_mode}.") + seq_len = args.prefill_seq_len if args.model_mode == "prefill" else args.kv_seq_len runner_args = " ".join( [ f"--model_path {pte_filename}.pte", "--output_path outputs/outputs.txt", f"--tokenizer_path {os.path.basename(args.tokenizer_model)}", f'--prompt "{args.prompt}"', - f"--seq_len {args.seq_len}", + f"--seq_len {seq_len}", f"--eval_mode {eval_mode}", f"--temperature {args.temperature}", + f"--system_prompt '{args.system_prompt}'", ] ) runner_cmd = " ".join( @@ -544,10 +632,10 @@ def main(): ) parser.add_argument( - "--seq_len", - help="Ouput sequence length for llama.", - default=128, - type=int, + "--system_prompt", + help="Tells the model what kind of assistant it should be. For example, You are a helpful AI assistant for travel tips and recommendations. Default is None", + default="", + type=str, ) parser.add_argument( @@ -581,27 +669,53 @@ def main(): parser.add_argument( "--model_mode", - help="Export and inference batch_prefill mode, kv mode or hybrid(TBD) mode", + help="Export and inference prefill mode, kv mode or hybrid mode", default="kv", - choices=["batch_prefill", "kv", "hybrid"], + choices=["prefill", "kv", "hybrid"], type=str, ) + parser.add_argument( + "--prefill_seq_len", + help="Ouput sequence length for llama. Use this option for prefill or hybrid mode", + default=32, + type=int, + ) + + parser.add_argument( + "--kv_seq_len", + help="Ouput sequence length for llama. 
Use this option for kv or hybrid mode", + default=512, + type=int, + ) + args = parser.parse_args() if args.compile_only and args.pre_gen_pte: exit("Cannot set both compile_only and pre_gen_pte as true") + if args.model_mode == "kv": + pte_filename = "kv_llama3_2_qnn" + elif args.model_mode == "prefill": + pte_filename = "prefill_llama3_2_qnn" + elif args.model_mode == "hybrid": + assert ( + args.kv_seq_len >= args.prefill_seq_len + ), "Please ensure kv_seq_len is >= prefill_seq_len" + pte_filename = "hybrid_llama3_2_qnn" + else: + raise RuntimeError(f"No such model_mode {args.model_mode}.") + if args.pre_gen_pte: - inference(args, args.pre_gen_pte) + inference(args, pte_filename, args.pre_gen_pte) exit(f"Finish the running pre_gen_pte from {args.pre_gen_pte}") if args.compile_only: - compile(args) + compile(args, pte_filename) exit(f"Finish compile_only and save to {args.artifact}") try: - compile(args) - inference(args) + compile(args, pte_filename) + inference(args, pte_filename) except Exception as e: if args.ip and args.port != -1: with Client((args.ip, args.port)) as conn: diff --git a/examples/qualcomm/oss_scripts/llama3_2/qnn_llama3_2_runner.cpp b/examples/qualcomm/oss_scripts/llama3_2/qnn_llama3_2_runner.cpp index 554e3ba9329..d05def243ba 100644 --- a/examples/qualcomm/oss_scripts/llama3_2/qnn_llama3_2_runner.cpp +++ b/examples/qualcomm/oss_scripts/llama3_2/qnn_llama3_2_runner.cpp @@ -18,6 +18,7 @@ #include #include #include +#include DEFINE_string( model_path, @@ -46,7 +47,7 @@ DEFINE_int32( DEFINE_int32( eval_mode, 0, - "0: PromptProcessor(batch_prefill) / 1: TokenGenerator(kv) / 2: HybridMode (TBD)"); + "0: PromptProcessor(prefill) / 1: TokenGenerator(kv) / 2: HybridMode (prefill+kv)"); int main(int argc, char** argv) { gflags::ParseCommandLineFlags(&argc, &argv, true); @@ -57,14 +58,21 @@ int main(int argc, char** argv) { FLAGS_tokenizer_path.c_str(), FLAGS_temperature, FLAGS_eval_mode); - - // generate tokens & store inference output + std::vector buf; + buf.reserve(5 * FLAGS_seq_len); // assume each token is around 5 char std::ofstream fout(FLAGS_output_path.c_str()); + auto callback = [&](const std::string& piece) { + for (const char c : piece) { + buf.push_back(c); + } + }; + // generate tokens & store inference output runner.generate( - FLAGS_prompt, - FLAGS_system_prompt, FLAGS_seq_len, - [&](const std::string& piece) { fout << piece; }); + FLAGS_prompt.c_str(), + FLAGS_system_prompt.c_str(), + callback); + fout.write(buf.data(), buf.size()); fout.close(); return 0; } diff --git a/examples/qualcomm/oss_scripts/llama3_2/runner/io_memory.cpp b/examples/qualcomm/oss_scripts/llama3_2/runner/io_memory.cpp index 9b37d056cf5..ccf386309c9 100644 --- a/examples/qualcomm/oss_scripts/llama3_2/runner/io_memory.cpp +++ b/examples/qualcomm/oss_scripts/llama3_2/runner/io_memory.cpp @@ -6,11 +6,9 @@ * LICENSE file in the root directory of this source tree. 
*/ -#include -#include - #include #include +#include using executorch::aten::Tensor; using executorch::aten::TensorImpl; @@ -23,10 +21,7 @@ using executorch::runtime::TensorInfo; namespace example { Memory::Memory(std::vector>& modules) - : data_ptr_(nullptr, [](void*) {}), - input_tensors_(modules.size()), - output_tensors_(modules.size()), - modules_(modules) {} + : data_ptr_(nullptr, [](void*) {}), modules_(modules) {} Memory::~Memory() {} @@ -34,19 +29,23 @@ void* Memory::get_mutable_ptr() { return data_ptr_.get(); } -std::vector Memory::get_input_tensors(int shard_index) { +std::vector Memory::get_input_tensors( + int shard_index, + const std::string& method_name) { std::vector ret; ret.reserve(input_tensors_.size()); - for (TensorImpl* impl : input_tensors_[shard_index]) { + for (TensorImpl* impl : input_tensors_[method_name][shard_index]) { ret.emplace_back(Tensor(impl)); } return ret; } -std::vector Memory::get_output_tensors(int shard_index) { +std::vector Memory::get_output_tensors( + int shard_index, + const std::string& method_name) { std::vector ret; - ret.reserve(output_tensors_.size()); - for (TensorImpl* impl : output_tensors_[shard_index]) { + ret.reserve(output_tensors_[method_name][shard_index].size()); + for (TensorImpl* impl : output_tensors_[method_name][shard_index]) { ret.emplace_back(Tensor(impl)); } return ret; @@ -54,26 +53,116 @@ std::vector Memory::get_output_tensors(int shard_index) { HybridMemory::HybridMemory( std::vector>& modules, - int32_t max_seq_len, + int32_t prefill_cache_len, + int32_t kv_cache_len, int32_t vocab_size, int32_t num_layers, int32_t head_dim, - int32_t num_heads) + int32_t num_heads, + EvalMode eval_mode, + const std::string& prefill_forward_name, + const std::string& kv_forward_name) : Memory(modules), shard_layers_({num_layers}), - max_seq_len_(max_seq_len), + kv_cache_len_(kv_cache_len), + prefill_cache_len_(prefill_cache_len), vocab_size_(vocab_size), num_layers_(num_layers), head_dim_(head_dim), - num_heads_(num_heads) { + num_heads_(num_heads), + eval_mode_(eval_mode), + prefill_forward_name_(prefill_forward_name), + kv_forward_name_(kv_forward_name) { + if (!prefill_forward_name_.empty()) { + input_tensors_[prefill_forward_name_] = + std::vector>(modules.size()); + output_tensors_[prefill_forward_name_] = + std::vector>(modules.size()); + k_cache_in_[prefill_forward_name_] = + std::vector>(); + v_cache_in_[prefill_forward_name_] = + std::vector>(); + k_cache_out_[prefill_forward_name_] = + std::vector>(); + v_cache_out_[prefill_forward_name_] = + std::vector>(); + } + if (!kv_forward_name_.empty()) { + input_tensors_[kv_forward_name_] = + std::vector>(modules.size()); + output_tensors_[kv_forward_name_] = + std::vector>(modules.size()); + k_cache_in_[kv_forward_name_] = + std::vector>(); + v_cache_in_[kv_forward_name_] = + std::vector>(); + k_cache_out_[kv_forward_name_] = + std::vector>(); + v_cache_out_[kv_forward_name_] = + std::vector>(); + } + data_ptr_ = std::unique_ptr( new IO, [](void* ptr) { delete static_cast(ptr); }); } -void HybridMemory::prepare_kv_io( - const std::vector>& methods_meta) { +void HybridMemory::init_io() { IO* ptr = static_cast(data_ptr_.get()); std::memset(ptr, 0, sizeof(IO)); + + int32_t max_cache_len = std::max(kv_cache_len_, prefill_cache_len_); + int32_t k_in_size = (head_dim_ + 1) * max_cache_len; + int32_t v_cache_size = (num_heads_ + 1) * max_cache_len * head_dim_; + int32_t k_cache_out_size = num_heads_ * head_dim_; + if (eval_mode_ == EvalMode::kHybrid || eval_mode_ == EvalMode::kPrefill) { + 
k_cache_out_size *= prefill_cache_len_; + } + + // Init kv vector shape, general enough to be shared across all 3 modes. + ptr->k_cache_out.reserve(num_layers_); + ptr->v_cache.reserve(num_layers_); + for (int layer = 0; layer < num_layers_; layer++) { + ptr->k_cache_out.emplace_back(std::vector(k_cache_out_size)); + ptr->v_cache.emplace_back(std::vector(v_cache_size)); + } + + auto init_prefill = [&]() { + ptr->prefill_input_toks.resize(prefill_cache_len_); + ptr->prefill_atten_mask.resize(prefill_cache_len_ * prefill_cache_len_); + ptr->prefill_logits.resize(prefill_cache_len_ * vocab_size_); + }; + + auto init_kv = [&]() { + ptr->kv_logits.resize(vocab_size_); + ptr->kv_attention_mask.resize((kv_cache_len_ + 1), -255); + ptr->k_cache.reserve(num_layers_); + for (int layer = 0; layer < num_layers_; layer++) { + ptr->k_cache.emplace_back(); + ptr->k_cache[layer].reserve(num_heads_); + for (int head = 0; head < num_heads_; head++) { + ptr->k_cache[layer].emplace_back(std::vector(k_in_size)); + } + } + }; + + switch (eval_mode_) { + case EvalMode::kPrefill: + init_prefill(); + break; + case EvalMode::kKVCached: + init_kv(); + break; + case EvalMode::kHybrid: + init_prefill(); + init_kv(); + break; + default: + break; + } +} + +void HybridMemory::prepare_kv_io( + const std::vector>& methods_meta) { for (int i = 0; i < modules_.size(); ++i) { ET_CHECK_MSG( methods_meta[i].ok(), @@ -81,23 +170,8 @@ void HybridMemory::prepare_kv_io( static_cast(methods_meta[i].error())); } - // Init IO vector shape - // atten_mask - ptr->logits.resize(vocab_size_); - ptr->attention_mask.resize( - max_seq_len_, -255); // attention mask shape should be [1, ctx_length] - // kv - int32_t k_in_size = (head_dim_ + 1) * (max_seq_len_ - 1); - int32_t k_out_size = num_heads_ * head_dim_; - int32_t v_cache_size = (num_heads_ + 1) * (max_seq_len_ - 1) * head_dim_; - for (int layer = 0; layer < num_layers_; layer++) { - ptr->k_cache.emplace_back(); - for (int head = 0; head < num_heads_; head++) { - ptr->k_cache[layer].emplace_back(std::vector(k_in_size)); - } - ptr->k_cache_out.emplace_back(std::vector(k_out_size)); - ptr->v_cache.emplace_back(std::vector(v_cache_size)); - } + ET_CHECK_MSG(!(kv_forward_name_.empty()), "kv forward name is empty"); + IO* ptr = static_cast(data_ptr_.get()); // [I]: input_tokens Result input_tok = methods_meta[0]->input_tensor_meta(0); @@ -107,7 +181,7 @@ void HybridMemory::prepare_kv_io( const_cast(input_tok->sizes().data()), &ptr->input_tok, const_cast(input_tok->dim_order().data())); - input_tensors_[0].push_back(input_tok_.get()); + input_tensors_[kv_forward_name_][0].push_back(input_tok_.get()); // [I]: atten_mask Result atten_mask = methods_meta[0]->input_tensor_meta(1); @@ -115,9 +189,9 @@ void HybridMemory::prepare_kv_io( atten_mask->scalar_type(), atten_mask->sizes().size(), const_cast(atten_mask->sizes().data()), - ptr->attention_mask.data(), + ptr->kv_attention_mask.data(), const_cast(atten_mask->dim_order().data())); - input_tensors_[0].push_back(attention_mask_.get()); + input_tensors_[kv_forward_name_][0].push_back(attention_mask_.get()); // [I]: input_pos Result input_pos = methods_meta[0]->input_tensor_meta(2); @@ -127,13 +201,11 @@ void HybridMemory::prepare_kv_io( const_cast(input_pos->sizes().data()), &ptr->input_pos, const_cast(input_pos->dim_order().data())); - input_tensors_[0].push_back(input_pos_.get()); + input_tensors_[kv_forward_name_][0].push_back(input_pos_.get()); // [I] kv_cache int index = 3; // bypass input_tokens, input_pos, atten_mask - for (int offset = 0, 
- shard_index = 0, - v_stride = (max_seq_len_ - 1) * head_dim_; + for (int offset = 0, shard_index = 0, v_stride = kv_cache_len_ * head_dim_; shard_index < modules_.size(); offset += shard_layers_[shard_index], shard_index++) { for (int cache_group = 0; cache_group < 2; ++cache_group) { @@ -142,7 +214,8 @@ void HybridMemory::prepare_kv_io( Result kv_cache = methods_meta[shard_index]->input_tensor_meta(index); std::vector>& cache = - (cache_group == 0 ? k_cache_in_ : v_cache_in_); + (cache_group == 0 ? k_cache_in_[kv_forward_name_] + : v_cache_in_[kv_forward_name_]); void* cache_ptr = (cache_group == 0) ? static_cast(ptr->k_cache[layer + offset][head].data()) : static_cast( @@ -155,7 +228,8 @@ void HybridMemory::prepare_kv_io( cache_ptr, const_cast( kv_cache->dim_order().data()))); - input_tensors_[shard_index].push_back(cache.back().get()); + input_tensors_[kv_forward_name_][shard_index].push_back( + cache.back().get()); } } } @@ -165,13 +239,14 @@ void HybridMemory::prepare_kv_io( int logit_index = 0; Result logits = methods_meta[modules_.size() - 1]->output_tensor_meta(logit_index); - logits_ = std::make_unique( + kv_logits_ = std::make_unique( logits->scalar_type(), logits->sizes().size(), const_cast(logits->sizes().data()), - ptr->logits.data(), + ptr->kv_logits.data(), const_cast(logits->dim_order().data())); - output_tensors_[modules_.size() - 1].push_back(logits_.get()); + output_tensors_[kv_forward_name_][modules_.size() - 1].push_back( + kv_logits_.get()); // [O] kv_cache index = 1; @@ -179,9 +254,7 @@ void HybridMemory::prepare_kv_io( // For k, we store it in k_cache_out and update to k_cache later. // For v, we append the output to the end of v_cache, // which serves as both input and output. - for (int offset = 0, - shard_index = 0, - v_stride = (max_seq_len_ - 1) * head_dim_; + for (int offset = 0, shard_index = 0, v_stride = kv_cache_len_ * head_dim_; shard_index < modules_.size(); offset += shard_layers_[shard_index], shard_index++) { for (int cache_group = 0; cache_group < 2; ++cache_group) { @@ -190,7 +263,8 @@ void HybridMemory::prepare_kv_io( Result kv_cache = methods_meta[shard_index]->output_tensor_meta(index); std::vector>& cache = - (cache_group == 0 ? k_cache_out_ : v_cache_out_); + (cache_group == 0 ? k_cache_out_[kv_forward_name_] + : v_cache_out_[kv_forward_name_]); void* cache_ptr = (cache_group == 0) ? static_cast( ptr->k_cache_out[layer + offset].data() + @@ -205,7 +279,8 @@ void HybridMemory::prepare_kv_io( cache_ptr, const_cast( kv_cache->dim_order().data()))); - output_tensors_[shard_index].push_back(cache.back().get()); + output_tensors_[kv_forward_name_][shard_index].push_back( + cache.back().get()); } } } @@ -214,8 +289,6 @@ void HybridMemory::prepare_kv_io( void HybridMemory::prepare_prefill_io( const std::vector>& methods_meta) { - IO* ptr = static_cast(data_ptr_.get()); - std::memset(ptr, 0, sizeof(IO)); for (int i = 0; i < modules_.size(); ++i) { ET_CHECK_MSG( methods_meta[i].ok(), @@ -223,23 +296,10 @@ void HybridMemory::prepare_prefill_io( static_cast(methods_meta[i].error())); } - // Parse some IO info from method meta - // cache_len should be max_seq_len - 1 - int cache_len = methods_meta[0]->input_tensor_meta(0)->sizes()[1]; - - // TODO: Combine vector init with KV mode once Hybrid mode is enabled - // as it shares some common data structure. 
- // Init IO vector shape - ptr->prefill_input_toks.resize(cache_len); - ptr->prefill_atten_mask.resize(cache_len * cache_len); - ptr->prefill_logits.resize(cache_len * vocab_size_); - // Init kv vector shape - int32_t k_cache_out_size = num_heads_ * head_dim_ * cache_len; - int32_t v_cache_size = (num_heads_ + 1) * cache_len * head_dim_; - for (int layer = 0; layer < num_layers_; layer++) { - ptr->k_cache_out.emplace_back(std::vector(k_cache_out_size)); - ptr->v_cache.emplace_back(std::vector(v_cache_size)); - } + ET_CHECK_MSG( + !(prefill_forward_name_.empty()), "prefill forward name is empty"); + + IO* ptr = static_cast(data_ptr_.get()); // [I]: pre_input_tokens Result prefill_input_toks = methods_meta[0]->input_tensor_meta(0); @@ -250,43 +310,54 @@ void HybridMemory::prepare_prefill_io( ptr->prefill_input_toks.data(), const_cast( prefill_input_toks->dim_order().data())); - input_tensors_[0].push_back(prefill_input_toks_.get()); + input_tensors_[prefill_forward_name_][0].push_back(prefill_input_toks_.get()); // [I]: prefill_attn_mask - for (int i = 0; i < cache_len; ++i) { - for (int j = 0; j < cache_len; ++j) { + for (int i = 0; i < prefill_cache_len_; ++i) { + for (int j = 0; j < prefill_cache_len_; ++j) { if (i < j) { - ptr->prefill_atten_mask[i * cache_len + j] = -255; + ptr->prefill_atten_mask[i * prefill_cache_len_ + j] = -255; } else { - ptr->prefill_atten_mask[i * cache_len + j] = 0; + ptr->prefill_atten_mask[i * prefill_cache_len_ + j] = 0; } } } - - Result prefill_attn_mask = methods_meta[0]->input_tensor_meta(1); + Result prefill_atten_mask = methods_meta[0]->input_tensor_meta(1); prefill_attn_mask_ = std::make_unique( - prefill_attn_mask->scalar_type(), - prefill_attn_mask->sizes().size(), - const_cast(prefill_attn_mask->sizes().data()), + prefill_atten_mask->scalar_type(), + prefill_atten_mask->sizes().size(), + const_cast(prefill_atten_mask->sizes().data()), ptr->prefill_atten_mask.data(), const_cast( - prefill_attn_mask->dim_order().data())); - input_tensors_[0].push_back(prefill_attn_mask_.get()); - + prefill_atten_mask->dim_order().data())); + input_tensors_[prefill_forward_name_][0].push_back(prefill_attn_mask_.get()); // [O]: logits int logit_index = 0; Result logits = methods_meta[modules_.size() - 1]->output_tensor_meta(logit_index); - logits_ = std::make_unique( + prefill_logits_ = std::make_unique( logits->scalar_type(), logits->sizes().size(), const_cast(logits->sizes().data()), ptr->prefill_logits.data(), const_cast(logits->dim_order().data())); - output_tensors_[modules_.size() - 1].push_back(logits_.get()); + output_tensors_[prefill_forward_name_][modules_.size() - 1].push_back( + prefill_logits_.get()); + // [O] kv_cache int index = 1; - for (int offset = 0, shard_index = 0, cache_stride = cache_len * head_dim_; - shard_index < modules_.size(); + // prefill_k_stride should be equal to prefill_v_stride in prefill mode. + // In hybrid mode, we use kv mode cache len for v stride since we want to + // update prefill's result onto kv modes input. 
+ int32_t prefill_k_stride = prefill_cache_len_ * head_dim_; + int32_t prefill_v_stride = + std::max(prefill_cache_len_, kv_cache_len_) * head_dim_; + + if (eval_mode_ == EvalMode::kPrefill) { + ET_CHECK_MSG( + prefill_k_stride == prefill_v_stride, + "prefill_k_stride should be equal to prefill_v_stride"); + } + for (int offset = 0, shard_index = 0; shard_index < modules_.size(); offset += shard_layers_[shard_index], shard_index++) { for (int cache_group = 0; cache_group < 2; ++cache_group) { for (int layer = 0; layer < shard_layers_[shard_index]; ++layer) { @@ -294,13 +365,15 @@ void HybridMemory::prepare_prefill_io( Result kv_cache = methods_meta[shard_index]->output_tensor_meta(index); std::vector>& cache = - (cache_group == 0 ? k_cache_out_ : v_cache_out_); + (cache_group == 0 ? k_cache_out_[prefill_forward_name_] + : v_cache_out_[prefill_forward_name_]); void* cache_ptr = (cache_group == 0) ? static_cast( ptr->k_cache_out[layer + offset].data() + - head * cache_stride) + head * prefill_k_stride) : static_cast( - ptr->v_cache[layer + offset].data() + head * cache_stride); + ptr->v_cache[layer + offset].data() + + (head + 1) * prefill_v_stride); cache.emplace_back(std::make_unique( kv_cache->scalar_type(), kv_cache->sizes().size(), @@ -308,37 +381,99 @@ void HybridMemory::prepare_prefill_io( cache_ptr, const_cast( kv_cache->dim_order().data()))); - output_tensors_[shard_index].push_back(cache.back().get()); + output_tensors_[prefill_forward_name_][shard_index].push_back( + cache.back().get()); } } } } } -void HybridMemory::update_io( +void HybridMemory::update_prefill_to_kv_io( + int64_t cur_token, + int64_t pos, + std::vector>& output_tensors) { + ET_CHECK_MSG(kv_cache_len_ != 0, "k_cache_len_ should not equal to 0"); + ET_CHECK_MSG( + prefill_cache_len_ != 0, "prefill_cache_len_ should not equal to 0"); + IO* ptr = static_cast(data_ptr_.get()); + + ptr->input_tok = static_cast(cur_token); + ptr->input_pos = static_cast(pos); + // If prompt len is 30, prefill will handle to pos = 30. + // At this point, pos should be 31. 
+ for (int i = 0; i < pos + 1; i++) { + ptr->kv_attention_mask[kv_cache_len_ - i] = 0; + } + + // update v_cache + std::vector>& v_cache_in = + v_cache_in_[kv_forward_name_]; + std::vector>& v_cache_out = + v_cache_out_[kv_forward_name_]; + for (int i = 0, v_cache_stride = head_dim_ * pos; i < v_cache_in.size(); + i++) { + v_cache_in[i]->set_data( + v_cache_in[i]->mutable_data() + v_cache_stride); + v_cache_out[i]->set_data( + v_cache_out[i]->mutable_data() + v_cache_stride); + } + for (int shard = 0; shard < output_tensors.size(); shard++) { + for (int index = 0; index < output_tensors[shard].size(); index++) { + ET_CHECK_MSG( + modules_[shard]->set_output( + kv_forward_name_, output_tensors[shard][index], index) == + Error::Ok, + "Failed to set output tensor for module %d's %d'th output " + "while updating kv_cache output tensors", + shard, + index); + } + } + + std::vector>& k_cache_in = + k_cache_in_[kv_forward_name_]; + std::vector>& k_cache_out = + k_cache_out_[prefill_forward_name_]; + for (int i = 0; i < k_cache_in.size(); ++i) { + uint8_t* ptr_in = k_cache_in[i]->mutable_data(); + const uint8_t* ptr_out = k_cache_out[i]->data(); + for (size_t j = 0, offset = kv_cache_len_; j < head_dim_; + ++j, offset += kv_cache_len_) { + for (int k = 0, k_stride = j * prefill_cache_len_; k < pos; k++) { + ptr_in[offset + k] = ptr_out[k_stride + k]; + } + } + k_cache_in[i]->set_data(ptr_in + pos); + } +} + +void HybridMemory::update_kv_io( int64_t cur_token, int64_t pos, std::vector>& output_tensors) { IO* ptr = static_cast(data_ptr_.get()); - int seq_len = (max_seq_len_ - 1); // update input_tok ptr->input_tok = static_cast(cur_token); // update position_ids ptr->input_pos = static_cast(pos); // update causal mask for next token - ptr->attention_mask[seq_len - pos] = 0; + ptr->kv_attention_mask[kv_cache_len_ - pos] = 0; // update v_cache - for (int i = 0; i < v_cache_in_.size(); i++) { - v_cache_in_[i]->set_data( - v_cache_in_[i]->mutable_data() + head_dim_); - v_cache_out_[i]->set_data( - v_cache_out_[i]->mutable_data() + head_dim_); + auto& v_cache_in = v_cache_in_[kv_forward_name_]; + auto& v_cache_out = v_cache_out_[kv_forward_name_]; + for (int i = 0; i < v_cache_in.size(); i++) { + v_cache_in[i]->set_data(v_cache_in[i]->mutable_data() + head_dim_); + v_cache_out[i]->set_data( + v_cache_out[i]->mutable_data() + head_dim_); } + for (int shard = 0; shard < output_tensors.size(); shard++) { for (int index = 0; index < output_tensors[shard].size(); index++) { ET_CHECK_MSG( - modules_[shard]->set_output(output_tensors[shard][index], index) == + modules_[shard]->set_output( + kv_forward_name_, output_tensors[shard][index], index) == Error::Ok, "failed to set output tensor for module %d's %d'th output " "while updating kv_cache output tensors", @@ -347,15 +482,17 @@ void HybridMemory::update_io( } } + auto& k_cache_in = k_cache_in_[kv_forward_name_]; + auto& k_cache_out = k_cache_out_[kv_forward_name_]; // update k_cache by single thread, this part is cpu cache sensitive - for (int i = 0; i < k_cache_in_.size(); ++i) { - uint8_t* ptr_in = k_cache_in_[i]->mutable_data(); - const uint8_t* ptr_out = k_cache_out_[i]->data(); - for (size_t j = 0, offset = seq_len; j < head_dim_; - ++j, offset += seq_len) { + for (int i = 0; i < k_cache_in.size(); ++i) { + uint8_t* ptr_in = k_cache_in[i]->mutable_data(); + const uint8_t* ptr_out = k_cache_out[i]->data(); + for (size_t j = 0, offset = kv_cache_len_; j < head_dim_; + ++j, offset += kv_cache_len_) { ptr_in[offset] = ptr_out[j]; } - 
k_cache_in_[i]->set_data(ptr_in + 1); + k_cache_in[i]->set_data(ptr_in + 1); } } diff --git a/examples/qualcomm/oss_scripts/llama3_2/runner/io_memory.h b/examples/qualcomm/oss_scripts/llama3_2/runner/io_memory.h index 31ed351ef4b..ca3a8848871 100644 --- a/examples/qualcomm/oss_scripts/llama3_2/runner/io_memory.h +++ b/examples/qualcomm/oss_scripts/llama3_2/runner/io_memory.h @@ -9,11 +9,6 @@ #pragma once #include -#include -#include -#include -#include -#include #include #include @@ -21,10 +16,17 @@ namespace example { +enum EvalMode { + kPrefill = 0, + kKVCached, + kHybrid, + kUnsupported, +}; class Memory { public: Memory(std::vector>& modules); virtual ~Memory(); + virtual void init_io() = 0; virtual void prepare_prefill_io( const std::vector< executorch::runtime::Result>& @@ -33,18 +35,32 @@ class Memory { const std::vector< executorch::runtime::Result>& methods_meta) = 0; - virtual void update_io( + virtual void update_prefill_to_kv_io( + int64_t cur_token, + int64_t pos, + std::vector>& output_tensors) = 0; + virtual void update_kv_io( int64_t cur_token, int64_t pos, std::vector>& output_tensors) = 0; void* get_mutable_ptr(); - std::vector get_input_tensors(int shard_index); - std::vector get_output_tensors(int shard_index); + std::vector get_input_tensors( + int shard_index, + const std::string& method_name); + std::vector get_output_tensors( + int shard_index, + const std::string& method_name); protected: std::unique_ptr data_ptr_; - std::vector> input_tensors_; - std::vector> output_tensors_; + std::unordered_map< + std::string, + std::vector>> + input_tensors_; + std::unordered_map< + std::string, + std::vector>> + output_tensors_; std::vector> modules_; }; @@ -52,11 +68,17 @@ class HybridMemory : public Memory { public: HybridMemory( std::vector>& modules, - int32_t max_seq_len, + int32_t prefill_cache_len, + int32_t kv_cache_len, int32_t vocab_size, int32_t num_layers, int32_t head_dim, - int32_t num_heads); + int32_t num_heads, + EvalMode eval_mode, + const std::string& prefill_forward_name, + const std::string& kv_forward_name); + + void init_io() override; void prepare_prefill_io( const std::vector< executorch::runtime::Result>& @@ -65,7 +87,12 @@ class HybridMemory : public Memory { const std::vector< executorch::runtime::Result>& methods_meta) override; - void update_io( + void update_prefill_to_kv_io( + int64_t cur_token, + int64_t pos, + std::vector>& output_tensors) + override; + void update_kv_io( int64_t cur_token, int64_t pos, std::vector>& output_tensors) @@ -73,11 +100,11 @@ class HybridMemory : public Memory { struct IO { int32_t input_tok; int32_t input_pos; - std::vector attention_mask; std::vector>> k_cache; std::vector> v_cache; std::vector> k_cache_out; - std::vector logits; + std::vector kv_attention_mask; + std::vector kv_logits; std::vector prefill_input_toks; std::vector prefill_atten_mask; std::vector prefill_logits; @@ -90,17 +117,34 @@ class HybridMemory : public Memory { std::unique_ptr attention_mask_; std::unique_ptr prefill_input_toks_; std::unique_ptr prefill_attn_mask_; - std::vector> k_cache_in_; - std::vector> v_cache_in_; - std::vector> k_cache_out_; - std::vector> v_cache_out_; - std::unique_ptr logits_; + std::unique_ptr prefill_logits_; + std::unordered_map< + std::string, + std::vector>> + k_cache_in_; + std::unordered_map< + std::string, + std::vector>> + v_cache_in_; + std::unordered_map< + std::string, + std::vector>> + k_cache_out_; + std::unordered_map< + std::string, + std::vector>> + v_cache_out_; + std::unique_ptr kv_logits_; 
std::vector shard_layers_; - int32_t max_seq_len_; + int32_t kv_cache_len_{0}; + int32_t prefill_cache_len_{0}; int32_t vocab_size_; int32_t num_layers_; int32_t head_dim_; int32_t num_heads_; + EvalMode eval_mode_; + std::string prefill_forward_name_; + std::string kv_forward_name_; }; } // namespace example diff --git a/examples/qualcomm/oss_scripts/llama3_2/runner/runner.cpp b/examples/qualcomm/oss_scripts/llama3_2/runner/runner.cpp index 80da5b98873..e87240dfdfe 100644 --- a/examples/qualcomm/oss_scripts/llama3_2/runner/runner.cpp +++ b/examples/qualcomm/oss_scripts/llama3_2/runner/runner.cpp @@ -13,12 +13,11 @@ #include #include #include +#include #include #include #include -#include #include -#include #include using executorch::aten::Tensor; @@ -47,38 +46,14 @@ Runner::Runner( n_eos_(1), tokenizer_path_(tokenizer_path), temperature_(temperature), - eval_mode_(eval_mode) { + eval_mode_(static_cast(eval_mode)) { for (size_t i = 0; i < models_path.size(); ++i) { modules_.push_back(std::make_shared( models_path[i], Module::LoadMode::MmapUseMlockIgnoreErrors)); ET_LOG(Info, "creating module: model_path=%s", models_path[i].c_str()); } ET_LOG(Info, "creating runner: tokenizer_path=%s", tokenizer_path_.c_str()); - - int64_t max_seq_len = getMetadataHelper("get_max_seq_len", -1); - int64_t vocab_size = getMetadataHelper("get_vocab_size", -1); - int64_t num_layers = getMetadataHelper("get_n_layers", -1); - int64_t head_dim = getMetadataHelper("get_head_dim", -1); - int64_t num_heads = getMetadataHelper("get_n_kv_heads", -1); - ET_CHECK_MSG(max_seq_len != -1, "Could not retrieve max seq len"); - ET_CHECK_MSG(vocab_size != -1, "Could not retrieve vocab size"); - ET_CHECK_MSG(num_layers != -1, "Could not retrieve num layers"); - ET_CHECK_MSG(head_dim != -1, "Could not retrieve head dimension"); - ET_CHECK_MSG(num_heads != -1, "Could not retrieve num heads"); - - max_seq_len_ = max_seq_len; - vocab_size_ = vocab_size; - tokenizer_ = example::get_tiktoken_for_llama(); - Error err = tokenizer_->load(tokenizer_path_); - ET_CHECK_MSG( - err == Error::Ok, "failed to load tokenizer %s", tokenizer_path_.c_str()); - eos_id_.insert(tokenizer_->encode("<|eot_id|>", 0, 0).get()[0]); - bos_id_ = tokenizer_->bos_tok(); - eos_id_.insert(tokenizer_->eos_tok()); - - io_mem_ = std::make_unique( - modules_, max_seq_len_, vocab_size_, num_layers, head_dim, num_heads); - ET_LOG(Info, "creating io_memory"); + ET_LOG(Info, "eval mode=%d", eval_mode); } bool Runner::is_loaded() const { @@ -93,10 +68,112 @@ Error Runner::load() { if (is_loaded()) { return Error::Ok; } + + switch (eval_mode_) { + case EvalMode::kPrefill: + prefill_forward_name_ = "forward"; + method_names_.emplace_back(prefill_forward_name_); + break; + case EvalMode::kKVCached: + kv_forward_name_ = "forward"; + method_names_.emplace_back(kv_forward_name_); + break; + case EvalMode::kHybrid: + prefill_forward_name_ = "prefill_forward"; + kv_forward_name_ = "kv_forward"; + method_names_.emplace_back(prefill_forward_name_); + method_names_.emplace_back(kv_forward_name_); + break; + case EvalMode::kUnsupported: + ET_CHECK_MSG(false, "Unsupported llama version"); + break; + } + for (std::shared_ptr& module : modules_) { - ET_CHECK_OK_OR_RETURN_ERROR(module->load_method("forward")); + if (!prefill_forward_name_.empty()) { + ET_CHECK_OK_OR_RETURN_ERROR(module->load_method(prefill_forward_name_)); + } + if (!kv_forward_name_.empty()) { + ET_CHECK_OK_OR_RETURN_ERROR(module->load_method(kv_forward_name_)); + } } + if (!prefill_forward_name_.empty()) { + // Use 
input tokens length to retrieve prefill cache len + // Cache len equals to prefill model seq_len - 1 + prefill_cache_len_ = get_methods_meta(prefill_forward_name_)[0] + ->input_tensor_meta(0) + ->sizes()[1]; + } + if (!kv_forward_name_.empty()) { + // Use k cache length to retirieve kv cache len + // Cache len equals to kv model seq_len - 1 + kv_cache_len_ = + get_methods_meta(kv_forward_name_)[0]->input_tensor_meta(3)->sizes()[2]; + } + + // retrieve any method meta, can be either prefill or kv + // Try avoid getMetadataHelper as it is time consuming. + auto method_meta = get_methods_meta(method_names_[0])[0].get(); + int64_t num_layers = getMetadataHelper("get_n_layers", -1); + int64_t head_dim = method_meta.output_tensor_meta(1)->sizes()[1]; // k_cache + int64_t num_heads = (method_meta.num_outputs() - 1) / (num_layers * 2); + vocab_size_ = method_meta.output_tensor_meta(0)->sizes()[2]; // logit_tensor + ET_CHECK_MSG(num_layers != -1, "Could not retrieve num layers"); + + io_mem_ = std::make_unique( + modules_, + prefill_cache_len_, + kv_cache_len_, + vocab_size_, + num_layers, + head_dim, + num_heads, + eval_mode_, + prefill_forward_name_, + kv_forward_name_); + ET_LOG(Info, "creating io_memory"); + + // prepare io + io_mem_->init_io(); + switch (eval_mode_) { + case EvalMode::kPrefill: + io_mem_->prepare_prefill_io(get_methods_meta(prefill_forward_name_)); + break; + case EvalMode::kKVCached: + io_mem_->prepare_kv_io(get_methods_meta(kv_forward_name_)); + break; + case EvalMode::kHybrid: + io_mem_->prepare_prefill_io(get_methods_meta(prefill_forward_name_)); + io_mem_->prepare_kv_io(get_methods_meta(kv_forward_name_)); + break; + case EvalMode::kUnsupported: + ET_CHECK_MSG(false, "unsupported mode"); + break; + } + + // llama3 tokenizer + tokenizer_ = example::get_tiktoken_for_llama(); + Error err = tokenizer_->load(tokenizer_path_); + if (err == Error::InvalidArgument) { + ET_LOG( + Info, + "Failed to load %s as a Tiktoken artifact, trying BPE tokenizer", + tokenizer_path_.c_str()); + tokenizer_.reset(); + // llama2 tokenizer + tokenizer_ = std::make_unique(); + err = tokenizer_->load(tokenizer_path_); + ET_CHECK_MSG( + err == Error::Ok, + "failed to load tokenizer %s", + tokenizer_path_.c_str()); + } else { + eos_id_.insert(tokenizer_->encode("<|eot_id|>", 0, 0).get()[0]); + } + bos_id_ = tokenizer_->bos_tok(); + eos_id_.insert(tokenizer_->eos_tok()); + // create sampler sampler_ = std::make_unique( vocab_size_, @@ -104,13 +181,6 @@ Error Runner::load() { kTopp, static_cast(std::time(nullptr))); - // prepare io - auto methods_meta = get_methods_meta(); - if (eval_mode_ == EvalMode::kBatchPrefill) { - io_mem_->prepare_prefill_io(methods_meta); - } else { - io_mem_->prepare_kv_io(methods_meta); - } return Error::Ok; } @@ -132,7 +202,6 @@ T Runner::getMetadataHelper(std::string method_name, T default_val) { method_name.c_str(), (long long)default_val); } - ET_LOG(Info, "%s: %lld", method_name.c_str(), (long long)res); return res; } @@ -145,124 +214,149 @@ int32_t Runner::logitsToToken(const Tensor& logits_tensor) { return sampler_->sample(logits_last); } -void Runner::run_model_step(std::vector>& inputs) { +void Runner::run_model_step( + const std::string& method_name, + std::vector>& inputs) { for (size_t i = 0, num_modules = modules_.size(); i < num_modules; ++i) { - Result> outputs_res = modules_[i]->forward(inputs[i]); + Result> outputs_res = + modules_[i]->execute(method_name, inputs[i]); ET_CHECK_MSG( outputs_res.error() == Error::Ok, "shard %zu inference failed", i); } } Error 
Runner::generate( + int32_t seq_len, const std::string& prompt, const std::string& system_prompt, - int32_t seq_len, std::function token_callback, std::function stats_callback) { - ET_CHECK_MSG(!prompt.empty(), "prompt cannot be null"); - - std::vector> input_tensors, output_tensors; - std::vector> inputs; + std::unordered_map>> + input_tensors, output_tensors; + std::unordered_map>> inputs; if (!is_loaded()) { stats_.model_load_start_ms = time_in_ms(); ET_CHECK_OK_OR_RETURN_ERROR(load()); - for (int i = 0; i < modules_.size(); ++i) { - input_tensors.emplace_back(io_mem_->get_input_tensors(i)); - output_tensors.emplace_back(io_mem_->get_output_tensors(i)); - for (size_t j = 0; j < output_tensors[i].size(); ++j) { - ET_CHECK_MSG( - modules_[i]->set_output(output_tensors[i][j], j) == Error::Ok, - "failed to set output tensor for module %d's %zu'th output", - i, - j); + for (auto method_name : method_names_) { + for (int i = 0; i < modules_.size(); ++i) { + input_tensors[method_name].emplace_back( + io_mem_->get_input_tensors(i, method_name)); + output_tensors[method_name].emplace_back( + io_mem_->get_output_tensors(i, method_name)); + for (size_t j = 0; j < output_tensors[method_name][i].size(); ++j) { + ET_CHECK_MSG( + modules_[i]->set_output( + method_name, output_tensors[method_name][i][j], j) == + Error::Ok, + "failed to set output tensor for module %d's %zu'th output", + i, + j); + } + inputs[method_name].emplace_back(std::vector( + begin(input_tensors[method_name][i]), + end(input_tensors[method_name][i]))); } - inputs.emplace_back( - std::vector(begin(input_tensors[i]), end(input_tensors[i]))); } - stats_.model_load_end_ms = time_in_ms(); } - std::string post_process_prompt; + stats_.model_load_end_ms = time_in_ms(); + stats_.inference_start_ms = time_in_ms(); + + ET_CHECK_MSG(!prompt.empty(), "prompt cannot be null"); if (!system_prompt.empty()) { - post_process_prompt.append( - "<|start_header_id|>system<|end_header_id|>\n\n"); - post_process_prompt.append(system_prompt); - post_process_prompt.append("<|eot_id|>\n"); + prompt_.append("<|start_header_id|>system<|end_header_id|>\n\n"); + prompt_.append(system_prompt); + prompt_.append("<|eot_id|>\n"); } - post_process_prompt.append("<|start_header_id|>user<|end_header_id|>\n\n"); - post_process_prompt.append(prompt); - post_process_prompt.append( - "<|eot_id|><|start_header_id|>assistant<|end_header_id|>"); - token_callback("<|begin_of_text|>"); + prompt_.append("<|start_header_id|>user<|end_header_id|>\n\n"); + prompt_.append(prompt); + prompt_.append("<|eot_id|><|start_header_id|>assistant<|end_header_id|>"); - stats_.inference_start_ms = time_in_ms(); + if (token_callback) { + token_callback("<|begin_of_text|>"); + } - seq_len = (seq_len > 0 && seq_len <= max_seq_len_) ? seq_len : max_seq_len_; + int max_seq_len = std::max(prefill_cache_len_, kv_cache_len_) + 1; + seq_len = (seq_len > 0 && seq_len <= max_seq_len) ? 
seq_len : max_seq_len; Result> encode_res = - tokenizer_->encode(post_process_prompt, n_bos_, 0); + tokenizer_->encode(prompt_, n_bos_, 0); ET_CHECK_OK_OR_RETURN_ERROR( - encode_res.error(), - "failed to encode prompt %s", - post_process_prompt.c_str()); + encode_res.error(), "failed to encode prompt %s", prompt_.c_str()); std::vector prompt_tokens = encode_res.get(); int num_prompt_tokens = prompt_tokens.size(); - ET_CHECK_MSG(num_prompt_tokens < max_seq_len_, "max seq length exceeded"); + ET_CHECK_MSG(num_prompt_tokens < max_seq_len, "max seq length exceeded"); ET_CHECK_MSG( num_prompt_tokens < seq_len, "sequence length exceeded - please increase the seq_len value"); + if (eval_mode_ == EvalMode::kHybrid) { + int prefill_seq_len = get_methods_meta(prefill_forward_name_)[0] + ->input_tensor_meta(0) + ->sizes()[1] + + 1; + ET_CHECK_MSG( + num_prompt_tokens < prefill_seq_len, + "For hybrid mode, please ensure prompt length(%d) is less than prefill's seq_len(%d)", + num_prompt_tokens, + prefill_seq_len); + } int64_t pos = 0, prev_token, cur_token = prompt_tokens[0]; HybridMemory::IO* ptr = static_cast(io_mem_->get_mutable_ptr()); - if (eval_mode_ == EvalMode::kBatchPrefill) { + auto prefill_execute = [&](const std::string& method_name) { for (int i = 0; i < num_prompt_tokens; i++) { ptr->prefill_input_toks[i] = static_cast(prompt_tokens[i]); auto piece_res = tokenizer_->decode(prompt_tokens[i], prompt_tokens[i]); token_callback(piece_res.get()); } // inference - run_model_step(inputs); - Tensor& logits_tensor = output_tensors.back()[0]; + run_model_step(method_name, inputs[method_name]); + Tensor& logits_tensor = output_tensors[method_name].back()[0]; // offset to the meaningful logit we want. float* logits = logits_tensor.mutable_data_ptr() + (num_prompt_tokens - 1) * vocab_size_; prev_token = prompt_tokens[num_prompt_tokens - 1]; + long sample_start_time_ms = time_in_ms(); cur_token = sampler_->sample(logits); + stats_.aggregate_sampling_time_ms += time_in_ms() - sample_start_time_ms; stats_.first_token_ms = time_in_ms(); stats_.prompt_eval_end_ms = time_in_ms(); - long sample_start_time_ms = time_in_ms(); - stats_.aggregate_sampling_time_ms += time_in_ms() - sample_start_time_ms; auto piece_res = tokenizer_->decode(prev_token, cur_token); ET_CHECK(piece_res.ok()); if (token_callback) { token_callback(piece_res.get().c_str()); } pos += num_prompt_tokens; - } else { + }; + + auto kv_execute = [&](const std::string& method_name) { ptr->input_tok = static_cast(cur_token); - ptr->attention_mask[max_seq_len_ - 1] = 0; + ptr->kv_attention_mask[kv_cache_len_] = 0; while (pos < seq_len - 1) { // inference - run_model_step(inputs); - Tensor& logits_tensor = output_tensors.back()[0]; - - if (pos == num_prompt_tokens) { - stats_.first_token_ms = time_in_ms(); - } else if (pos == num_prompt_tokens - 1) { - stats_.prompt_eval_end_ms = time_in_ms(); + run_model_step(method_name, inputs[method_name]); + Tensor& logits_tensor = output_tensors[method_name].back()[0]; + + // hybrid mode will check these stats_ at prefill(prefill) + if (eval_mode_ == EvalMode::kKVCached) { + if (pos == num_prompt_tokens) { + stats_.first_token_ms = time_in_ms(); + } else if (pos == num_prompt_tokens - 1) { + stats_.prompt_eval_end_ms = time_in_ms(); + } } - long sample_start_time_ms = time_in_ms(); prev_token = cur_token; + long sample_start_time_ms = time_in_ms(); cur_token = logitsToToken(logits_tensor); stats_.aggregate_sampling_time_ms += time_in_ms() - sample_start_time_ms; if (pos < num_prompt_tokens - 1) { cur_token = 
prompt_tokens[pos + 1]; } - io_mem_->update_io(cur_token, ++pos, output_tensors); + io_mem_->update_kv_io(cur_token, ++pos, output_tensors[method_name]); auto piece_res = tokenizer_->decode(prev_token, cur_token); ET_CHECK(piece_res.ok()); @@ -275,8 +369,25 @@ Error Runner::generate( break; } } + }; + + switch (eval_mode_) { + case EvalMode::kPrefill: + prefill_execute(prefill_forward_name_); + break; + case EvalMode::kKVCached: + kv_execute(kv_forward_name_); + break; + case EvalMode::kHybrid: + prefill_execute(prefill_forward_name_); + io_mem_->update_prefill_to_kv_io( + cur_token, pos, output_tensors[kv_forward_name_]); + kv_execute(kv_forward_name_); + break; + default: + ET_CHECK_MSG(false, "Unsupported eval mode"); + break; } - stats_.inference_end_ms = time_in_ms(); if (pos == seq_len) { ET_LOG(Info, "\nSequence length (%i tokens) reached!", seq_len); @@ -348,7 +459,7 @@ void printReport(const Runner::Stats& stats) { ET_LOG( Info, "\tSampling time over %" PRIu64 " tokens:\t%f (seconds)", - stats.num_prompt_tokens + stats.num_generated_tokens, + stats.num_generated_tokens, (double)stats.aggregate_sampling_time_ms / stats.SCALING_FACTOR_UNITS_PER_SECOND); } @@ -370,11 +481,12 @@ std::string statsToJsonString(const Runner::Stats& stats) { } } // namespace -std::vector> Runner::get_methods_meta() { +std::vector> Runner::get_methods_meta( + std::string& method_name) { std::vector> methods_meta; methods_meta.reserve(modules_.size()); for (std::shared_ptr& module : modules_) { - methods_meta.emplace_back(module->method_meta("forward")); + methods_meta.emplace_back(module->method_meta(method_name)); } return methods_meta; } diff --git a/examples/qualcomm/oss_scripts/llama3_2/runner/runner.h b/examples/qualcomm/oss_scripts/llama3_2/runner/runner.h index b720697be5f..79b8370982b 100644 --- a/examples/qualcomm/oss_scripts/llama3_2/runner/runner.h +++ b/examples/qualcomm/oss_scripts/llama3_2/runner/runner.h @@ -61,29 +61,28 @@ class Runner { bool is_loaded() const; executorch::runtime::Error load(); executorch::runtime::Error generate( + int32_t seq_len, const std::string& prompt, const std::string& system_prompt, - int32_t seq_len, std::function token_callback = {}, std::function stats_callback = {}); void stop(); std::vector> - get_methods_meta(); + get_methods_meta(std::string& method_name); private: - enum EvalMode { - kBatchPrefill = 0, - kKVCached, - kUnsupported, - }; template T getMetadataHelper(std::string method_name, T default_val); template int32_t logitsToToken(const executorch::aten::Tensor& logits_tensor); void run_model_step( + const std::string& method_name, std::vector>& inputs); + std::string prompt_; + // metadata - int32_t max_seq_len_; + int32_t prefill_cache_len_{0}; + int32_t kv_cache_len_{0}; int32_t vocab_size_; int32_t bos_id_; std::unordered_set eos_id_; @@ -96,7 +95,10 @@ class Runner { std::unique_ptr sampler_; Stats stats_; std::unique_ptr io_mem_; - int32_t eval_mode_; + EvalMode eval_mode_; + std::string prefill_forward_name_; + std::string kv_forward_name_; + std::vector method_names_; }; } // namespace example
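
For reference, a minimal usage sketch of the updated runner API introduced by this patch: `generate()` now takes `seq_len` as its first argument, followed by the user prompt, the new `system_prompt`, and a token callback, and `eval_mode == 2` selects the hybrid (prefill + kv) path. This is a hedged illustration, not part of the patch: the include path, the exact `Runner` constructor signature, and the callback's `std::function` type are assumptions inferred from the surrounding hunks (the flattened diff elides them), so treat them as placeholders.

```cpp
// Sketch only: header path and constructor argument order are assumed from this diff,
// not taken from a complete runner.h.
#include <iostream>
#include <string>
#include <vector>

#include <executorch/examples/qualcomm/oss_scripts/llama3_2/runner/runner.h>  // assumed path

int main() {
  // A .pte exported in hybrid mode exposes both "prefill_forward" and "kv_forward" methods.
  std::vector<std::string> models = {"hybrid_llama3_2_qnn.pte"};

  // Assumed constructor order: model paths, tokenizer path, temperature, eval_mode.
  // eval_mode 2 == hybrid (prefill + kv), matching the --eval_mode flag semantics above.
  example::Runner runner(models, "tokenizer.model", /*temperature=*/0.0f, /*eval_mode=*/2);

  std::string generated;
  // New argument order from this patch: seq_len first, then prompt and system prompt.
  // The callback receives each decoded piece as it is produced.
  auto err = runner.generate(
      /*seq_len=*/128,
      "what is 1+1",
      /*system_prompt=*/"",
      [&](const std::string& piece) { generated += piece; });
  (void)err;

  std::cout << generated << std::endl;
  return 0;
}
```

In hybrid mode the runner internally runs the prefill graph over the prompt, copies the prefill KV results into the kv graph's IO via `update_prefill_to_kv_io()`, and then continues token generation with `kv_forward`, so the caller only needs the single `generate()` call shown here.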