From d1c2fecba8b35eeec37d1259c2c865813f516ca0 Mon Sep 17 00:00:00 2001
From: DannyYuyang-quic
Date: Wed, 21 May 2025 12:13:59 +0800
Subject: [PATCH] Qualcomm AI Engine Direct - GA model enablement (T5)

Summary:
- e2e script / test case for GA T5 model
- perf: 16a8w avg encoding time: 4.09ms/inf, avg decoding time: 6ms/inf (SM8750)
- acc: F1 score ~= 76% on SQuAD
- add QA dataset for Seq2SeqLM benchmarking
---
 backends/qualcomm/tests/test_qnn_delegate.py | 34 +
 backends/qualcomm/tests/utils.py | 1 +
 examples/qualcomm/CMakeLists.txt | 3 +
 .../qualcomm/oss_scripts/t5/CMakeLists.txt | 45 ++
 .../qualcomm/oss_scripts/t5/qnn_t5_runner.cpp | 137 ++++
 .../oss_scripts/t5/runner/decoder.cpp | 58 ++
 .../qualcomm/oss_scripts/t5/runner/decoder.h | 52 ++
 .../oss_scripts/t5/runner/encoder.cpp | 53 ++
 .../qualcomm/oss_scripts/t5/runner/encoder.h | 35 +
 .../qualcomm/oss_scripts/t5/runner/runner.cpp | 234 +++++++
 .../qualcomm/oss_scripts/t5/runner/runner.h | 74 ++
 examples/qualcomm/oss_scripts/t5/t5.py | 361 ++++++++++
 examples/qualcomm/oss_scripts/t5/t5_model.py | 632 ++++++++++++++++++
 examples/qualcomm/utils.py | 210 +++++-
 14 files changed, 1921 insertions(+), 8 deletions(-)
 create mode 100644 examples/qualcomm/oss_scripts/t5/CMakeLists.txt
 create mode 100644 examples/qualcomm/oss_scripts/t5/qnn_t5_runner.cpp
 create mode 100644 examples/qualcomm/oss_scripts/t5/runner/decoder.cpp
 create mode 100644 examples/qualcomm/oss_scripts/t5/runner/decoder.h
 create mode 100644 examples/qualcomm/oss_scripts/t5/runner/encoder.cpp
 create mode 100644 examples/qualcomm/oss_scripts/t5/runner/encoder.h
 create mode 100644 examples/qualcomm/oss_scripts/t5/runner/runner.cpp
 create mode 100644 examples/qualcomm/oss_scripts/t5/runner/runner.h
 create mode 100644 examples/qualcomm/oss_scripts/t5/t5.py
 create mode 100644 examples/qualcomm/oss_scripts/t5/t5_model.py

diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py
index 22e0d1cc219..d4b7343daae 100644
--- a/backends/qualcomm/tests/test_qnn_delegate.py
+++ b/backends/qualcomm/tests/test_qnn_delegate.py
@@ -4989,6 +4989,40 @@ def test_swin_transformer(self):
         self.assertGreaterEqual(msg["top_1"], 60)
         self.assertGreaterEqual(msg["top_5"], 80)

+    def test_t5(self):
+        if not self.required_envs([self.qa_dataset]):
+            self.skipTest("missing required envs")
+        cmds = [
+            "python",
+            f"{self.executorch_root}/examples/qualcomm/oss_scripts/t5/t5.py",
+            "--dataset",
+            self.qa_dataset,
+            "--artifact",
+            self.artifact_dir,
+            "--build_folder",
+            self.build_folder,
+            "--device",
+            self.device,
+            "--model",
+            self.model,
+            "--ip",
+            self.ip,
+            "--port",
+            str(self.port),
+        ]
+        if self.host:
+            cmds.extend(["--host", self.host])
+
+        p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
+        with Listener((self.ip, self.port)) as listener:
+            conn = listener.accept()
+            p.communicate()
+            msg = json.loads(conn.recv())
+            if "Error" in msg:
+                self.fail(msg["Error"])
+            else:
+                self.assertGreaterEqual(msg["f1"], 0.7)
+
     def test_whisper(self):
         if not self.required_envs():
             self.skipTest("missing required envs")
diff --git a/backends/qualcomm/tests/utils.py b/backends/qualcomm/tests/utils.py
index 2e923b92250..fd2d10e2b93 100644
--- a/backends/qualcomm/tests/utils.py
+++ b/backends/qualcomm/tests/utils.py
@@ -183,6 +183,7 @@ class TestQNN(unittest.TestCase):
     executorch_root: str = ""
     artifact_dir: str = ""
     image_dataset: str = ""
+    qa_dataset: str = ""
     sentence_dataset: str = ""
     pretrained_weight: str = ""
     enable_profile: bool = False
diff --git 
a/examples/qualcomm/CMakeLists.txt b/examples/qualcomm/CMakeLists.txt index 69fa9a0b0d4..67aa9bb4b05 100644 --- a/examples/qualcomm/CMakeLists.txt +++ b/examples/qualcomm/CMakeLists.txt @@ -90,6 +90,9 @@ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/oss_scripts/llama) # build qnn_mimi_decoder_runner add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/oss_scripts/moshi) +# build qnn_t5_runner for t5 +add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/oss_scripts/t5) + # build qnn_whisper_runner for whisper add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/oss_scripts/whisper) diff --git a/examples/qualcomm/oss_scripts/t5/CMakeLists.txt b/examples/qualcomm/oss_scripts/t5/CMakeLists.txt new file mode 100644 index 00000000000..4ee42b69449 --- /dev/null +++ b/examples/qualcomm/oss_scripts/t5/CMakeLists.txt @@ -0,0 +1,45 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +# preprocess qnn runner src files for t5 +set(_qnn_t5_runner__srcs + ${CMAKE_CURRENT_LIST_DIR}/qnn_t5_runner.cpp + ${CMAKE_CURRENT_LIST_DIR}/runner/decoder.cpp + ${CMAKE_CURRENT_LIST_DIR}/runner/decoder.h + ${CMAKE_CURRENT_LIST_DIR}/runner/encoder.cpp + ${CMAKE_CURRENT_LIST_DIR}/runner/encoder.h + ${CMAKE_CURRENT_LIST_DIR}/runner/runner.cpp + ${CMAKE_CURRENT_LIST_DIR}/runner/runner.h + ${EXECUTORCH_ROOT}/extension/llm/sampler/sampler.cpp +) + +# build qnn t5 runner +add_executable(qnn_t5_runner ${_qnn_t5_runner__srcs}) +target_include_directories( + qnn_t5_runner PUBLIC ${_common_include_directories} + ${EXECUTORCH_ROOT}/extension/llm/tokenizers/include +) + + +target_link_libraries( + qnn_t5_runner + qnn_executorch_backend + executorch_core + extension_data_loader + extension_flat_tensor + extension_module + extension_tensor + gflags + tokenizers +) + +target_compile_options( + qnn_t5_runner PUBLIC ${_common_compile_options} +) +set_target_properties( + qnn_t5_runner PROPERTIES LINK_FLAGS "-Wl,-rpath='$ORIGIN'" +) diff --git a/examples/qualcomm/oss_scripts/t5/qnn_t5_runner.cpp b/examples/qualcomm/oss_scripts/t5/qnn_t5_runner.cpp new file mode 100644 index 00000000000..d588da8dc1a --- /dev/null +++ b/examples/qualcomm/oss_scripts/t5/qnn_t5_runner.cpp @@ -0,0 +1,137 @@ +/* + * Copyright (c) Qualcomm Innovation Center, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/** + * @file + * + * This tool can run t5 with Qualcomm AI Engine Direct. + * + */ + +#include +#include +#include +#include +#include +#include + +DEFINE_string( + model_path, + "t5_qnn.pte", + "t5 model serialized in flatbuffer format."); + +DEFINE_string( + tokenizer_model_path, + "tokenizer.model", + "The tokenizer is saved from T5Tokenize.save_pretrained for tokenizer."); +DEFINE_string( + input_list_path, + "input_list.txt", + "Input list storing file name of encoded results."); +DEFINE_int32( + seq_len, + 128, + "Maximum sequence length for the generated output. Defaults to use the model's `max_cache_size` attribute. 
Will be truncated to maximal cache size if larger than `max_cache_size`."); + +DEFINE_string( + output_folder_path, + "outputs", + "Executorch inference data output path."); + +std::vector>> parse_input_list_file( + const std::string& input_list_path) { + std::vector>> bufs; + std::ifstream input_list(input_list_path); + + auto split = [](std::string s, std::string delimiter) { + size_t pos_start = 0, pos_end, delim_len = delimiter.length(); + std::string token; + std::vector res; + + while ((pos_end = s.find(delimiter, pos_start)) != std::string::npos) { + token = s.substr(pos_start, pos_end - pos_start); + pos_start = pos_end + delim_len; + res.push_back(token); + } + res.push_back(s.substr(pos_start)); + return res; + }; + + if (!input_list.is_open()) { + ET_LOG(Error, "Unable to open file"); + return bufs; + } + + std::string file_path; + while (std::getline(input_list, file_path)) { + auto input_files = split(file_path, " "); + int num_inputs = input_files.size(); + if (num_inputs == 0) { + break; + } + + bufs.emplace_back(); + bufs.back().resize(num_inputs); + for (int input_index = 0; input_index < num_inputs; ++input_index) { + std::ifstream fin(input_files[input_index], std::ios::binary); + if (!fin.is_open()) { + ET_LOG( + Error, "Could not open file %s", input_files[input_index].c_str()); + continue; + } + + fin.seekg(0, std::ios::end); + size_t file_size = fin.tellg(); + fin.seekg(0, std::ios::beg); + + size_t num_tokens = file_size / sizeof(int64_t); + bufs.back()[input_index].resize(num_tokens); + + if (!fin.read( + reinterpret_cast(bufs.back()[input_index].data()), + file_size)) { + ET_LOG( + Error, "Could not read file %s", input_files[input_index].c_str()); + continue; + } + + fin.close(); + } + } + + input_list.close(); + return bufs; +} + +int main(int argc, char** argv) { + gflags::ParseCommandLineFlags(&argc, &argv, true); + + std::vector>> multi_turns_input_buffers = + parse_input_list_file(FLAGS_input_list_path); + + for (int iter = 0; iter < multi_turns_input_buffers.size(); ++iter) { + std::vector bufs; + bufs.reserve(5 * FLAGS_seq_len); // assume each token is around 5 char + auto callback = [&](const std::string& piece) { + for (const char c : piece) { + bufs.push_back(c); + } + }; + + example::Runner runner(FLAGS_model_path, FLAGS_tokenizer_model_path); + // generate tokens + runner.generate(FLAGS_seq_len, multi_turns_input_buffers[iter], callback); + auto output_file_name = + FLAGS_output_folder_path + "/output_" + std::to_string(iter) + ".txt"; + std::ofstream fout(output_file_name); + fout.write(bufs.data(), bufs.size()); + fout.close(); + } + + return 0; +} diff --git a/examples/qualcomm/oss_scripts/t5/runner/decoder.cpp b/examples/qualcomm/oss_scripts/t5/runner/decoder.cpp new file mode 100644 index 00000000000..2de2b72ba40 --- /dev/null +++ b/examples/qualcomm/oss_scripts/t5/runner/decoder.cpp @@ -0,0 +1,58 @@ +/* + * Copyright (c) Qualcomm Innovation Center, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include + +using executorch::aten::Tensor; +using executorch::extension::Module; +using executorch::extension::TensorPtr; +using executorch::runtime::Error; +using executorch::runtime::Result; + +namespace example { +T5Decoder::T5Decoder(const std::string& model_path) { + module_ = std::make_unique( + model_path, Module::LoadMode::MmapUseMlockIgnoreErrors); + ET_LOG(Info, "creating decoder module: model_path=%s", model_path.c_str()); +} + +bool T5Decoder::is_method_loaded() const { + return module_->is_method_loaded(kDecoderForwardName); +} + +Error T5Decoder::load() { + if (is_method_loaded()) { + return Error::Ok; + } + return module_->load_method(kDecoderForwardName); +} +Result T5Decoder::step( + TensorPtr& input_ids, + TensorPtr& attention_mask, + TensorPtr& encoder_hidden_states, + TensorPtr& encoder_attention_mask, + TensorPtr& cache_position) { + auto outputs_res = module_->execute( + kDecoderForwardName, + {input_ids, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + cache_position}); + ET_CHECK_OK_OR_RETURN_ERROR(outputs_res.error()); + ET_CHECK_MSG( + outputs_res.get().size() == 1, + "More then one output returned from executing decoder."); + ET_CHECK_MSG( + outputs_res.get()[0].isTensor(), + "Non Tensor Output returned from executing decoder"); + + // Return the logits tensor + return outputs_res.get()[0].toTensor(); +} +} // namespace example diff --git a/examples/qualcomm/oss_scripts/t5/runner/decoder.h b/examples/qualcomm/oss_scripts/t5/runner/decoder.h new file mode 100644 index 00000000000..4042c057b57 --- /dev/null +++ b/examples/qualcomm/oss_scripts/t5/runner/decoder.h @@ -0,0 +1,52 @@ +/* + * Copyright (c) Qualcomm Innovation Center, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace example { + +class T5Decoder { + public: + explicit T5Decoder(const std::string& model_path); + + bool is_method_loaded() const; + executorch::runtime::Error load(); + executorch::runtime::Result step( + executorch::extension::TensorPtr& input_ids, + executorch::extension::TensorPtr& attention_mask, + executorch::extension::TensorPtr& encoder_hidden_states, + executorch::extension::TensorPtr& encoder_attention_mask, + executorch::extension::TensorPtr& cache_position); + executorch::runtime::Result> method_names() { + return module_->method_names(); + } + executorch::runtime::Result get( + const std::string& method_name) { + return module_->get(method_name); + } + + executorch::runtime::Result> execute( + const std::string& method_name) { + return module_->execute(method_name); + } + + private: + std::unique_ptr module_; + static constexpr const char* kDecoderForwardName = "decoder"; +}; + +} // namespace example diff --git a/examples/qualcomm/oss_scripts/t5/runner/encoder.cpp b/examples/qualcomm/oss_scripts/t5/runner/encoder.cpp new file mode 100644 index 00000000000..487edec1d9d --- /dev/null +++ b/examples/qualcomm/oss_scripts/t5/runner/encoder.cpp @@ -0,0 +1,53 @@ +/* + * Copyright (c) Qualcomm Innovation Center, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include + +using executorch::aten::Tensor; +using executorch::extension::Module; +using executorch::extension::TensorPtr; +using executorch::runtime::Error; +using executorch::runtime::Result; +namespace example { +T5Encoder::T5Encoder(const std::string& model_path) { + module_ = std::make_unique( + model_path, Module::LoadMode::MmapUseMlockIgnoreErrors); + ET_LOG(Info, "creating encoder module: model_path=%s", model_path.c_str()); +} + +bool T5Encoder::is_method_loaded() const { + return module_->is_method_loaded(kEncoderForwardName); +} + +Error T5Encoder::load() { + if (is_method_loaded()) { + return Error::Ok; + } + return module_->load_method(kEncoderForwardName); +} + +Result T5Encoder::encode( + TensorPtr& input_ids, + executorch::extension::TensorPtr& prompt_attn_mask) { + auto outputs_res = + module_->execute(kEncoderForwardName, {input_ids, prompt_attn_mask}); + ET_CHECK_OK_OR_RETURN_ERROR(outputs_res.error()); + + const auto& outputs = outputs_res.get(); + + ET_CHECK_MSG( + outputs.size() == 1, + "More then one output returned from executing encoder."); + ET_CHECK_MSG( + outputs[0].isTensor(), + "Non Tensor Output returned from executing encoder"); + + // Return the hidden state tensor + return outputs[0].toTensor(); +} +} // namespace example diff --git a/examples/qualcomm/oss_scripts/t5/runner/encoder.h b/examples/qualcomm/oss_scripts/t5/runner/encoder.h new file mode 100644 index 00000000000..2b9731dddc8 --- /dev/null +++ b/examples/qualcomm/oss_scripts/t5/runner/encoder.h @@ -0,0 +1,35 @@ +/* + * Copyright (c) Qualcomm Innovation Center, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once +#include +#include +#include +#include +#include +#include +#include + +namespace example { + +class T5Encoder { + public: + explicit T5Encoder(const std::string& model_path); + + bool is_method_loaded() const; + executorch::runtime::Error load(); + executorch::runtime::Result encode( + executorch::extension::TensorPtr& input_ids, + executorch::extension::TensorPtr& prompt_attn_mask); + + private: + std::unique_ptr module_; + inline static const std::string kEncoderForwardName = "encoder"; +}; + +} // namespace example diff --git a/examples/qualcomm/oss_scripts/t5/runner/runner.cpp b/examples/qualcomm/oss_scripts/t5/runner/runner.cpp new file mode 100644 index 00000000000..ffccfb447c3 --- /dev/null +++ b/examples/qualcomm/oss_scripts/t5/runner/runner.cpp @@ -0,0 +1,234 @@ +/* + * Copyright (c) Qualcomm Innovation Center, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include +#include +#include +#include +#include +#include +using executorch::aten::ScalarType; +using executorch::aten::Tensor; +using executorch::extension::from_blob; +using executorch::extension::make_tensor_ptr; +using executorch::extension::llm::Sampler; +using executorch::extension::llm::time_in_ms; +using executorch::llm::kTopp; +using executorch::runtime::Error; +using executorch::runtime::Result; + +namespace example { +namespace { +static constexpr auto kEosId = "get_eos_id"; +static constexpr auto kMaxContextLen = "get_max_context_len"; +static constexpr auto kMaxHiddenSeqLen = "max_hidden_seq_length"; +} // namespace +Runner::Runner( + const std::string& model_path, + const std::string& tokenizer_model_path) + : tokenizer_model_path_(tokenizer_model_path) { + encoder_ = std::make_unique(model_path); + decoder_ = std::make_unique(model_path); + tokenizer_ = std::make_unique(); +} + +bool Runner::is_loaded() const { + return encoder_->is_method_loaded() && decoder_->is_method_loaded() && + tokenizer_->is_loaded() && sampler_; +} + +Error Runner::load() { + if (is_loaded()) { + return Error::Ok; + } + ET_CHECK_OK_OR_RETURN_ERROR(encoder_->load()); + ET_CHECK_OK_OR_RETURN_ERROR(decoder_->load()); + if (tokenizer_->load(tokenizer_model_path_) != tokenizers::Error::Ok) { + ET_LOG( + Error, + "Failed to load tokenizer with %s", + tokenizer_model_path_.c_str()); + return Error::Internal; + } + eos_ids_ = std::make_unique>( + std::unordered_set{tokenizer_->eos_tok()}); + + // create sampler + sampler_ = std::make_unique( + tokenizer_->vocab_size(), + 0, + kTopp, + static_cast(std::time(nullptr))); + + // Initialize metadata with default values + metadata_ = { + {kMaxContextLen, 128}, + {kMaxHiddenSeqLen, 384}, + }; + + // Read metadata from the model + auto method_names_result = decoder_->method_names(); + if (method_names_result.error() != Error::Ok) { + ET_LOG(Error, "Failed reading method names"); + return Error::Internal; + } + const auto method_names = method_names_result.get(); + + for (auto& [method_name, value] : metadata_) { + if (method_names.count(method_name)) { + auto get_result = decoder_->get(method_name); + + auto result = get_result.get(); + value = + get_result.get().toScalar().to(); + } else { + ET_LOG( + Info, + "Method %s not found, using the default value %" PRId64, + method_name.c_str(), + value); + } + ET_LOG(Info, "Metadata: %s = %" PRId64, method_name.c_str(), value); + } + + // Get EOS IDs if available + if (method_names.count(kEosId)) { + eos_ids_->clear(); + auto execute_result = decoder_->execute(kEosId); + if (execute_result.error() != Error::Ok) { + ET_LOG(Error, "Failed to execute %s", kEosId); + return Error::Internal; + } + for (const auto& eos_id : execute_result.get()) { + auto value = eos_id.toScalar().to(); + eos_ids_->emplace(value); + ET_LOG(Info, "eos_id = %" PRId64, value); + } + } + + return Error::Ok; +} + +uint64_t Runner::logits_to_token( + const executorch::aten::Tensor& logits_tensor) { + return sampler_->sample(logits_tensor.data_ptr()); +} + +Error Runner::generate( + int32_t seq_len, + std::vector>& inputs, + std::function token_callback) { + if (!is_loaded()) { + stats_.model_load_start_ms = time_in_ms(); + ET_CHECK_OK_OR_RETURN_ERROR(load()); + stats_.model_load_end_ms = time_in_ms(); + } + ET_CHECK_MSG(inputs.size() == 3, "The input size of t5 should be three."); + + ET_LOG(Info, "Start Encoding"); + stats_.encoder_inference_start_ms = time_in_ms(); + auto hidden_seq_len = static_cast(metadata_.at(kMaxHiddenSeqLen)); + 
executorch::extension::TensorPtr prompt_tokens = + from_blob(inputs[0].data(), {1, hidden_seq_len}, ScalarType::Long); + executorch::extension::TensorPtr prompt_attn_mask = + from_blob(inputs[1].data(), {1, hidden_seq_len}, ScalarType::Long); + + auto encoder_output = encoder_->encode(prompt_tokens, prompt_attn_mask); + + ET_CHECK_OK_OR_RETURN_ERROR(encoder_output.error()); + auto encoder_hidden_states_tensor_ptr = make_tensor_ptr(encoder_output.get()); + stats_.encoder_inference_end_ms = time_in_ms(); + auto max_seq_len = metadata_.at(kMaxContextLen); + + seq_len = (seq_len > 0 && seq_len <= max_seq_len) ? seq_len : max_seq_len; + + int64_t pos = 0; + num_generated_token_ = 0; + + // use decoder_input_id as first token + ET_CHECK_MSG(!inputs[2].empty(), "decoder_input_ids is empty."); + uint64_t prev_token = inputs[2][0], cur_token = prev_token; + + ET_LOG(Info, "Start Decoding"); + std::vector output_token_ids; + std::vector attention_mask_data(max_seq_len, -255.0); + stats_.decoder_inference_start_ms = time_in_ms(); + while (pos < seq_len) { + auto decoder_input_ids_tensor_ptr = + from_blob(&cur_token, {1, 1}, ScalarType::Long); + attention_mask_data[pos] = 0; + auto attention_mask_tensor_ptr = from_blob( + attention_mask_data.data(), + {1, 1, 1, static_cast(max_seq_len)}, + ScalarType::Float); + auto pos_tensor_ptr = from_blob(&pos, {1}, ScalarType::Long); + Result logits = decoder_->step( + decoder_input_ids_tensor_ptr, + attention_mask_tensor_ptr, + encoder_hidden_states_tensor_ptr, + prompt_attn_mask, + pos_tensor_ptr); + + prev_token = cur_token; + cur_token = logits_to_token(logits.get()); + ++pos; + output_token_ids.push_back(cur_token); + + if (token_callback) { + token_callback( + ET_UNWRAP_TOKENIZER(tokenizer_->decode(prev_token, cur_token))); + } + if (eos_ids_->count(cur_token) > 0) { + ET_LOG(Info, "\nReached to the end of generation"); + break; + } + } + stats_.decoder_inference_end_ms = time_in_ms(); + if (pos == seq_len) { + ET_LOG(Info, "\nSequence length (%i tokens) reached!", seq_len); + } + num_generated_token_ = pos; + print_performance(); + return Error::Ok; +} + +Error Runner::print_performance() { + ET_LOG(Info, "\tTotal Generated token:\t\t\t\t%ld", num_generated_token_); + + ET_LOG( + Info, + "\tModel Load Time:\t\t\t\t%f (seconds)", + ((double)(stats_.model_load_end_ms - stats_.model_load_start_ms) / + stats_.SCALING_FACTOR_UNITS_PER_SECOND)); + + ET_LOG( + Info, + "\tEncoding Time:\t\t\t\t\t%f (seconds)", + ((double)(stats_.encoder_inference_end_ms - + stats_.encoder_inference_start_ms) / + stats_.SCALING_FACTOR_UNITS_PER_SECOND)); + + ET_LOG( + Info, + "\tDecoding Time:\t\t\t%f (seconds)", + ((double)(stats_.decoder_inference_end_ms - + stats_.decoder_inference_start_ms) / + stats_.SCALING_FACTOR_UNITS_PER_SECOND)); + + ET_LOG( + Info, + "\tAverage Decoding Time:\t\t\t%f (seconds)", + ((double)((stats_.decoder_inference_end_ms - + stats_.decoder_inference_start_ms) / + num_generated_token_) / + (stats_.SCALING_FACTOR_UNITS_PER_SECOND))); + + return Error::Ok; +} + +} // namespace example diff --git a/examples/qualcomm/oss_scripts/t5/runner/runner.h b/examples/qualcomm/oss_scripts/t5/runner/runner.h new file mode 100644 index 00000000000..9c8d77b50e8 --- /dev/null +++ b/examples/qualcomm/oss_scripts/t5/runner/runner.h @@ -0,0 +1,74 @@ +/* + * Copyright (c) Qualcomm Innovation Center, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +// A simple t5 runner that includes preprocessing and post processing +// logic. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace example { + +class Runner { + public: + explicit Runner( + const std::string& model_path, + const std::string& tokenizer_model_path); + + struct Stats { + // Scaling factor for timestamps - in this case, we use ms. + const long SCALING_FACTOR_UNITS_PER_SECOND = 1000; + // Time stamps for the different stages of the execution + // model_load_start_ms: Model loading time + long model_load_start_ms; + long model_load_end_ms; + + // encoder inference time + long encoder_inference_start_ms = 0; + long encoder_inference_end_ms = 0; + + // decoder inference time + long decoder_inference_start_ms = 0; + long decoder_inference_end_ms = 0; + }; + + bool is_loaded() const; + executorch::runtime::Error load(); + executorch::runtime::Error generate( + int32_t seq_len, + std::vector>& inputs, + std::function token_callback = {}); + + private: + executorch::runtime::Error print_performance(); + uint64_t logits_to_token(const executorch::aten::Tensor& logits_tensor); + // model + std::unique_ptr encoder_; + std::unique_ptr decoder_; + std::unique_ptr tokenizer_; + std::unique_ptr sampler_; + std::string tokenizer_model_path_; + + std::unordered_map metadata_; + std::unique_ptr> eos_ids_; + + int64_t num_generated_token_ = 0; + Stats stats_; +}; + +} // namespace example diff --git a/examples/qualcomm/oss_scripts/t5/t5.py b/examples/qualcomm/oss_scripts/t5/t5.py new file mode 100644 index 00000000000..1b8ea1b1665 --- /dev/null +++ b/examples/qualcomm/oss_scripts/t5/t5.py @@ -0,0 +1,361 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import getpass +import json +import os +import subprocess +from multiprocessing.connection import Client + +import torch + +from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype +from executorch.backends.qualcomm.serialization.qc_schema import QcomChipset +from executorch.backends.qualcomm.utils.utils import ( + generate_htp_compiler_spec, + generate_qnn_executorch_compiler_spec, + to_edge_transform_and_lower_to_qnn, +) +from executorch.devtools.backend_debug import print_delegation_info +from executorch.examples.qualcomm.oss_scripts.t5.t5_model import ( + CustomT5Stack, + Seq2SeqLMDecoderExportableModuleWithStaticCache, + Seq2SeqLMEncoderExportableModule, + Seq2SeqLMExportableModulePipeline, +) +from executorch.examples.qualcomm.utils import ( + evaluate_squad, + get_seq2seq_dataset_from_squad_csv, + make_quantizer, + replace_module_with_custom_class, + setup_common_args_and_variables, + SimpleADB, +) +from executorch.exir.capture._config import ExecutorchBackendConfig +from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e +from transformers import AutoModelForSeq2SeqLM, AutoTokenizer +from transformers.models.t5.modeling_t5 import T5Stack + +PTE_FILE_NAME = "t5_qnn" +ENCODER = "encoder" +DECODER = "decoder" + + +class T5: + def __init__( + self, + model: AutoModelForSeq2SeqLM, + tokenizer: AutoTokenizer, + batch_size=1, + max_hidden_seq_length=4096, + max_cache_length=1024, + ): + self.encoder = ( + Seq2SeqLMEncoderExportableModule( + model.get_encoder(), max_hidden_seq_length=max_hidden_seq_length + ) + .to("cpu") + .eval() + ) + self.decoder = ( + Seq2SeqLMDecoderExportableModuleWithStaticCache( + model, + max_hidden_seq_length=max_hidden_seq_length, + max_static_cache_length=max_cache_length, + batch_size=batch_size, + ) + .to("cpu") + .eval() + ) + + # Source transformation + for model in [self.encoder, self.decoder]: + replace_module_with_custom_class( + model, + target_class=T5Stack, + custom_class=CustomT5Stack, + extra_custom_kwargs={ + "max_hidden_seq_length": max_hidden_seq_length, + "max_cache_length": max_cache_length, + }, + ) + + # Runner pipeline + self.pipe = Seq2SeqLMExportableModulePipeline( + tokenizer, + model.config, + max_hidden_seq_length=max_hidden_seq_length, + max_seq_len=max_cache_length, + ) + + self.exported_encoder = None + self.exported_decoder = None + self.quant_dtype = None + + def quantize(self, inputs, quant_dtype, targets=None, metrics=None): + assert quant_dtype is not None, "quant_dtype must be specified" + self.quant_dtype = quant_dtype + + with torch.no_grad(): + + # Export Modules + self.exported_encoder = torch.export.export( + self.encoder, self.encoder.get_example_inputs(), strict=True + ).module() + self.exported_decoder = torch.export.export( + self.decoder, self.decoder.get_example_inputs(), strict=True + ).module() + + # Quantization + print(f"Applying quantization with dtype: {quant_dtype}...") + quantizer = make_quantizer( + per_channel_linear=True, + quant_dtype=quant_dtype, + ) + + self.exported_encoder = prepare_pt2e(self.exported_encoder, quantizer) + self.exported_decoder = prepare_pt2e(self.exported_decoder, quantizer) + + # Calibration + self.pipe(self.exported_encoder, self.exported_decoder, inputs) + + self.exported_encoder = convert_pt2e(self.exported_encoder) + self.exported_decoder = convert_pt2e(self.exported_decoder) + + if targets is not None and metrics is not None: + print(f"Metrics provided for 
validation: {metrics.__name__}") + self.pipe.validate( + self.exported_encoder, + self.exported_decoder, + inputs, + targets, + metrics, + ) + else: + print("No targets or metrics provided. Skipping validation step.") + + def lowering_modules( + self, + workspace, + use_fp16=False, + soc_model=QcomChipset.SM8650, + skip_node_id_set=None, + skip_node_op_set=None, + verbose=True, + ): + graph_names = [ENCODER, DECODER] + + if not self.exported_encoder or not self.exported_decoder: + modules = [ + self.encoder, + self.decoder, + ] + else: + modules = [ + self.exported_encoder, + self.exported_decoder, + ] + + backend_options = generate_htp_compiler_spec(use_fp16=use_fp16) + compile_spec = generate_qnn_executorch_compiler_spec( + soc_model=soc_model, + backend_options=backend_options, + ) + edge_prog_mgr = to_edge_transform_and_lower_to_qnn( + dict(zip(graph_names, modules)), + dict( + zip( + graph_names, + [ + self.encoder.get_example_inputs(), + self.decoder.get_example_inputs(), + ], + ) + ), + compile_spec, + constant_methods=self.decoder.get_metadata(), + skip_node_id_set=skip_node_id_set, + skip_node_op_set=skip_node_op_set, + skip_mutable_buffer=False, + ) + + executorch_config = ExecutorchBackendConfig( + memory_planning_pass=MemoryPlanningPass( + alloc_graph_input=True, + alloc_graph_output=True, + ), + extract_delegate_segments=True, + ) + + if verbose: + for graph_name in graph_names: + print_delegation_info( + edge_prog_mgr.exported_program(graph_name).graph_module + ) + + exec_prog_mgr = edge_prog_mgr.to_executorch(config=executorch_config) + with open(f"{workspace}/{PTE_FILE_NAME}.pte", "wb") as file: + exec_prog_mgr.write_to_file(file) + + +def main(args): + + # ensure the working directory exist. + os.makedirs(args.artifact, exist_ok=True) + + if not args.compile_only and args.device is None: + raise RuntimeError( + "device serial is required if not compile only. " + "Please specify a device serial by -s/--device argument." + ) + + data_size = 100 + max_hidden_seq_length = 384 + max_cache_length = 512 + + tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small") + model = AutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-small").eval() + inputs, targets, input_list = get_seq2seq_dataset_from_squad_csv( + args.dataset, + tokenizer, + data_size, + max_hidden_seq_length=max_hidden_seq_length, + shuffle=False, + ) + + if not args.pre_gen_pte: + t5 = T5( + model, + tokenizer, + max_hidden_seq_length=max_hidden_seq_length, + max_cache_length=max_cache_length, + ) + quant_dtype = QuantDtype.use_16a8w + t5.quantize(inputs, quant_dtype) + t5.lowering_modules( + args.artifact, + soc_model=getattr(QcomChipset, args.model), + use_fp16=True if quant_dtype is None else False, + ) + + if args.compile_only: + return + + pte_path = ( + f"{args.pre_gen_pte}/{PTE_FILE_NAME}" + if args.pre_gen_pte + else f"{args.artifact}/{PTE_FILE_NAME}" + ) + ".pte" + _, _, spiece_model, _, _ = tokenizer.save_pretrained(args.artifact) + + workspace = f"/data/local/tmp/{getpass.getuser()}/executorch/{PTE_FILE_NAME}" + + outputs = [] + + def post_process(): + for i in range(len(inputs)): + with open(f"{args.artifact}/outputs/output_{i}.txt", "r") as f: + outputs.append(f.read()) + + runner_args = " ".join( + [ + f"--tokenizer_model_path {os.path.basename(spiece_model)}", + f"--model_path {PTE_FILE_NAME}.pte", + f"--seq_len {max_cache_length}", + "--output_folder_path outputs", + ] + ) + if args.enable_x86_64: + # x86 emulator is intended for CI and not performance. 
+ qnn_sdk = os.getenv("QNN_SDK_ROOT") + target = "x86_64-linux-clang" + runner_cmd = " ".join( + [ + f"export LD_LIBRARY_PATH={qnn_sdk}/lib/{target}/:{args.build_folder}/lib &&", + f"./{args.build_folder}/examples/qualcomm/oss_scripts/t5/qnn_t5_runner", + runner_args, + ] + ) + subprocess.run( + runner_cmd, + shell=True, + executable="/bin/bash", + capture_output=True, + ) + post_process() + else: + runner_cmd = " ".join( + [ + f"cd {workspace} &&", + "./qnn_t5_runner", + runner_args, + ] + ) + adb = SimpleADB( + qnn_sdk=os.getenv("QNN_SDK_ROOT"), + build_path=f"{args.build_folder}", + pte_path=pte_path, + workspace=workspace, + device_id=args.device, + host_id=args.host, + soc_model=args.model, + runner="examples/qualcomm/oss_scripts/t5/qnn_t5_runner", + ) + adb.push( + inputs=inputs, + input_list=input_list, + files=[spiece_model], + ) + adb.execute(custom_runner_cmd=runner_cmd) + adb.pull(output_path=args.artifact, callback=post_process) + + result = Seq2SeqLMExportableModulePipeline.evaluate_with_ground_truth( + tokenizer, outputs, targets, evaluate_squad + ) + + if args.ip and args.port != -1: + with Client((args.ip, args.port)) as conn: + conn.send(json.dumps({"f1": result["f1"]})) + else: + print(f"F1 score: {result['f1']}") + + +if __name__ == "__main__": + parser = setup_common_args_and_variables() + parser.add_argument( + "-a", + "--artifact", + help="path for storing generated artifacts and output by this example. Default ./t5", + default="./t5", + type=str, + ) + parser.add_argument( + "--pre_gen_pte", + help="Run the pre-generated t5 in the given directory.", + type=str, + ) + parser.add_argument( + "-d", + "--dataset", + help=( + "path to the validation text. " + "e.g. --dataset SQuAD-v1.1.csv " + "for https://www.kaggle.com/datasets/akashdesarda/squad-v11?select=SQuAD-v1.1.csv" + ), + type=str, + required=True, + ) + + args = parser.parse_args() + try: + main(args) + except Exception as e: + if args.ip and args.port != -1: + with Client((args.ip, args.port)) as conn: + conn.send(json.dumps({"Error": str(e)})) + else: + raise Exception(e) diff --git a/examples/qualcomm/oss_scripts/t5/t5_model.py b/examples/qualcomm/oss_scripts/t5/t5_model.py new file mode 100644 index 00000000000..0593feaa8b8 --- /dev/null +++ b/examples/qualcomm/oss_scripts/t5/t5_model.py @@ -0,0 +1,632 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +from typing import List, Optional + +import torch +from tqdm import tqdm +from transformers import AutoTokenizer, T5Config +from transformers.cache_utils import ( + Cache, + DynamicCache, + EncoderDecoderCache, + StaticCache, +) +from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions +from transformers.models.t5.modeling_t5 import T5Attention, T5Stack +from transformers.utils import is_torchdynamo_compiling, logging + +logger = logging.get_logger(__name__) + + +# Copy from transformers/models/t5/modeling_t5.py (transformers=4.47.1) +class CustomT5Stack(T5Stack): + def __init__( + self, + config, + embed_tokens=None, + max_hidden_seq_length=4096, + max_cache_length=1024, + ): + super().__init__(config, embed_tokens) + + # ====================Qualcomm Changed================================= + # Customized position bias computation: + # Since the calculation in `T5Attention._relative_position_bucket` is not QNN-friendly, + # we precompute the relative position buckets as constant tensors during initialization. + # For the encoder: use the precomputed `encoder_self_attn_position_bias`. + # For the decoder: use the precomputed `decoder_self_attn_position_bias`. + + self.max_hidden_seq_length = max_hidden_seq_length + self.max_cache_length = max_cache_length + + # Create relative position table for encoder + encoder_self_attn_relative_position_bucket = ( + T5Attention._relative_position_bucket( + torch.arange(max_hidden_seq_length)[None, :] + - torch.arange(max_hidden_seq_length)[:, None], + bidirectional=(not self.is_decoder), + num_buckets=config.relative_attention_num_buckets, + max_distance=config.relative_attention_max_distance, + ) + ) + self.register_buffer( + "encoder_self_attn_position_bias", + encoder_self_attn_relative_position_bucket, + ) + + # Create relative position table for decoder + self_attn_relative_position_bucket = T5Attention._relative_position_bucket( + torch.arange(max_cache_length)[None, :] + - torch.arange(max_cache_length)[:, None], + bidirectional=(not self.is_decoder), + num_buckets=config.relative_attention_num_buckets, + max_distance=config.relative_attention_max_distance, + ) + self.register_buffer( + "decoder_self_attn_position_bias", + self_attn_relative_position_bucket, + ) + # ======================================================================== + + def forward( # noqa: C901 + self, + input_ids=None, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + inputs_embeds=None, + head_mask=None, + cross_attn_head_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + cache_position=None, + ): + # Model parallel + if self.model_parallel: + torch.cuda.set_device(self.first_device) + self.embed_tokens = self.embed_tokens.to(self.first_device) + use_cache = use_cache if use_cache is not None else self.config.use_cache + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + if input_ids is not None and inputs_embeds is not None: + err_msg_prefix = "decoder_" if self.is_decoder else "" + raise ValueError( + f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time" + ) + elif input_ids is not 
None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + err_msg_prefix = "decoder_" if self.is_decoder else "" + raise ValueError( + f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds" + ) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + if inputs_embeds is None: + if self.embed_tokens is None: + raise ValueError( + "You have to initialize the model with valid token embeddings" + ) + inputs_embeds = self.embed_tokens(input_ids) + + batch_size, seq_length = input_shape + + if use_cache is True: + if not self.is_decoder: + raise ValueError( + f"`use_cache` can only be set to `True` if {self} is used as a decoder" + ) + + # initialize past_key_values + return_legacy_cache = False + return_self_attention_cache = False + if self.is_decoder and (use_cache or past_key_values is not None): + if isinstance(past_key_values, Cache) and not isinstance( + past_key_values, EncoderDecoderCache + ): + return_self_attention_cache = True + past_key_values = EncoderDecoderCache(past_key_values, DynamicCache()) + elif not isinstance(past_key_values, EncoderDecoderCache): + return_legacy_cache = True + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + elif past_key_values is None: + past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + elif not self.is_decoder: + # do not pass cache object down the line for encoder stack + # it messes indexing later in decoder-stack because cache object is modified in-place + past_key_values = None + + past_key_values_length = ( + past_key_values.get_seq_length() if past_key_values is not None else 0 + ) + if cache_position is None: + cache_position = torch.arange( + past_key_values_length, + past_key_values_length + seq_length, + device=inputs_embeds.device, + ) + + if attention_mask is None and not is_torchdynamo_compiling(): + # required mask seq length can be calculated via length of past cache + mask_seq_length = past_key_values_length + seq_length + attention_mask = torch.ones( + batch_size, mask_seq_length, device=inputs_embeds.device + ) + + if self.config.is_decoder: + causal_mask = self._update_causal_mask( + attention_mask, + inputs_embeds, + cache_position, + ( + past_key_values.self_attention_cache + if past_key_values is not None + else None + ), + output_attentions, + ) + elif attention_mask is not None: + causal_mask = attention_mask[:, None, None, :] + causal_mask = causal_mask.to(dtype=inputs_embeds.dtype) + causal_mask = (1.0 - causal_mask) * torch.finfo(inputs_embeds.dtype).min + else: + causal_mask = None + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = ( + encoder_hidden_states.size() + ) + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + 
encoder_attention_mask = torch.ones( + encoder_hidden_shape, device=inputs_embeds.device, dtype=torch.long + ) + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask + ) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + head_mask = self.get_head_mask(head_mask, self.config.num_layers) + cross_attn_head_mask = self.get_head_mask( + cross_attn_head_mask, self.config.num_layers + ) + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + all_cross_attentions = () if (output_attentions and self.is_decoder) else None + + # ====================Qualcomm Changed================================= + # The bias is indexed by cache_position to select the correct positions for the current step. + if self.is_decoder: + # For decoder, use the decoder's relative position bias table. + position_bias = ( + self.block[0] + .layer[0] + .SelfAttention.relative_attention_bias( + self.decoder_self_attn_position_bias[cache_position] + ) + .permute([2, 0, 1]) + .unsqueeze(0) + ) + else: + # For encoder, use the encoder's relative position bias table. + position_bias = ( + self.block[0] + .layer[0] + .SelfAttention.relative_attention_bias( + self.encoder_self_attn_position_bias[cache_position] + ) + .permute([2, 0, 1]) + .unsqueeze(0) + ) + position_bias = position_bias[:, :, -seq_length:, :] + if self.is_decoder: + position_bias = ( + position_bias + causal_mask[:, :, :, : self.max_cache_length] + ) + else: + position_bias = position_bias + causal_mask[:, :, :, :seq_length] + + # For cross-attention in decoder, precompute encoder-decoder position bias as zeros and add encoder attention mask. + encoder_decoder_position_bias = None + if self.is_decoder and encoder_hidden_states is not None: + encoder_decoder_position_bias = torch.zeros( + (1, self.config.num_heads, seq_length, self.max_hidden_seq_length), + dtype=encoder_extended_attention_mask.dtype, + ) + encoder_decoder_position_bias = ( + encoder_decoder_position_bias + + encoder_extended_attention_mask[:, :, :, : self.max_hidden_seq_length] + ) + # ======================================================================== + + hidden_states = self.dropout(inputs_embeds) + + for i, layer_module in enumerate(self.block): + layer_head_mask = head_mask[i] + cross_attn_layer_head_mask = cross_attn_head_mask[i] + # Model parallel + if self.model_parallel: + torch.cuda.set_device(hidden_states.device) + # Ensure that attention_mask is always on the same device as hidden_states + if causal_mask is not None: + causal_mask = causal_mask.to(hidden_states.device) + if position_bias is not None: + position_bias = position_bias.to(hidden_states.device) + if encoder_hidden_states is not None: + encoder_hidden_states = encoder_hidden_states.to( + hidden_states.device + ) + if encoder_extended_attention_mask is not None: + encoder_extended_attention_mask = ( + encoder_extended_attention_mask.to(hidden_states.device) + ) + if encoder_decoder_position_bias is not None: + encoder_decoder_position_bias = encoder_decoder_position_bias.to( + hidden_states.device + ) + if layer_head_mask is not None: + layer_head_mask = layer_head_mask.to(hidden_states.device) + if cross_attn_layer_head_mask is not None: + cross_attn_layer_head_mask = cross_attn_layer_head_mask.to( + hidden_states.device + ) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = 
self._gradient_checkpointing_func( + layer_module.forward, + hidden_states, + causal_mask, + position_bias, + encoder_hidden_states, + encoder_extended_attention_mask, + encoder_decoder_position_bias, + layer_head_mask, + cross_attn_layer_head_mask, + None, # past_key_value is always None with gradient checkpointing + use_cache, + output_attentions, + return_dict, + cache_position, + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask=causal_mask, + position_bias=position_bias, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + encoder_decoder_position_bias=encoder_decoder_position_bias, + layer_head_mask=layer_head_mask, + cross_attn_layer_head_mask=cross_attn_layer_head_mask, + past_key_value=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + return_dict=return_dict, + cache_position=cache_position, + ) + + # layer_outputs is a tuple with: + # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) + if use_cache is False: + layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] + + hidden_states, next_decoder_cache = layer_outputs[:2] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[3],) + if self.is_decoder: + all_cross_attentions = all_cross_attentions + (layer_outputs[5],) + + # Model Parallel: If it's the last layer for that device, put things on the next device + if self.model_parallel: + for k, v in self.device_map.items(): + if i == v[-1] and "cuda:" + str(k) != self.last_device: + hidden_states = hidden_states.to("cuda:" + str(k + 1)) + + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.dropout(hidden_states) + + # Add last layer + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if return_self_attention_cache: + next_cache = past_key_values.self_attention_cache + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + next_cache, + all_hidden_states, + all_attentions, + all_cross_attentions, + ] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_attentions, + cross_attentions=all_cross_attentions, + ) + + +class Seq2SeqLMEncoderExportableModule(torch.nn.Module): + def __init__(self, encoder_model, max_hidden_seq_length): + super().__init__() + self.encoder = encoder_model + self.max_hidden_seq_length = max_hidden_seq_length + + def get_example_inputs(self): + max_hidden_seq_length = self.max_hidden_seq_length + input_ids = torch.randint(0, max_hidden_seq_length, (1, max_hidden_seq_length)) + attn_mask = torch.randint(0, max_hidden_seq_length, (1, max_hidden_seq_length)) + return input_ids, attn_mask + + def forward(self, input_ids, attn_mask): + encoder_outputs = self.encoder( + input_ids, + attn_mask, + return_dict=True, + ) + + return encoder_outputs.last_hidden_state + + +class Seq2SeqLMDecoderExportableModuleWithStaticCache(torch.nn.Module): + def __init__( + self, + model, + max_hidden_seq_length, + max_static_cache_length, + batch_size, + ): + super().__init__() + + # Get the decoder component + self.decoder = model.get_decoder() + self.proj_out = model.lm_head + self.config = 
model.config + self.max_hidden_seq_length = max_hidden_seq_length + self.max_static_cache_length = max_static_cache_length + + # Initialize static cache + self.static_cache = StaticCache( + config=self.config, + max_batch_size=batch_size, + max_cache_len=max_static_cache_length, + device="cpu", + dtype=torch.float32, + ) + + # Register cache buffers to make them exportable + for i in range(len(self.static_cache.key_cache)): + self.register_buffer( + f"key_cache_{i}", self.static_cache.key_cache[i], persistent=False + ) + self.register_buffer( + f"value_cache_{i}", + self.static_cache.value_cache[i], + persistent=False, + ) + + def get_example_inputs(self): + max_hidden_seq_length = self.max_hidden_seq_length + hidden_size = self.config.d_model + decoder_input_ids = torch.tensor([[0]], dtype=torch.long) + min_dtype = torch.finfo(torch.float32).min + attn_mask = torch.full( + (1, 1, 1, self.max_static_cache_length), + fill_value=min_dtype, + dtype=torch.float32, + ) + attn_mask[..., 0] = 0 + encoder_hidden_states = torch.randn(1, self.max_hidden_seq_length, hidden_size) + encoder_attn_mask = torch.ones((1, max_hidden_seq_length), dtype=torch.long) + cache_position = torch.tensor([0], dtype=torch.long) + return ( + decoder_input_ids, + attn_mask, + encoder_hidden_states, + encoder_attn_mask, + cache_position, + ) + + def forward( + self, + decoder_input_ids, + attn_mask, + encoder_hidden_states, + encoder_attention_mask, + cache_position, + ): + # Get outputs from decoder + outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=attn_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=self.static_cache, + use_cache=True, + cache_position=cache_position, + ) + sequence_output = outputs[0] + if self.config.tie_word_embeddings: + sequence_output = sequence_output * (self.config.d_model**-0.5) + + # Apply linear projection (lm head) to obtain logits + logits = self.proj_out(sequence_output) + return logits + + def get_metadata(self): + return { + "get_eos_id": getattr(self.config, "eos_token_id", None), + "get_max_context_len": self.max_static_cache_length, + "max_hidden_seq_length": self.max_hidden_seq_length, + } + + +class Seq2SeqLMExportableModulePipeline(torch.nn.Module): + def __init__( + self, + tokenizer: AutoTokenizer, + config: T5Config, + max_hidden_seq_length=4096, + max_seq_len=1024, + ): + super().__init__() + self.tokenizer = tokenizer + self.config = config + self.max_seq_len = max_seq_len + + self.max_hidden_seq_length = max_hidden_seq_length + + def __call__( + self, + encoder, + decoder, + dataset, + ): + self.validate(encoder, decoder, dataset, None, None) + + def validate( + self, + encoder, + decoder, + dataset, + targets: Optional[List[torch.Tensor]] = None, + metrics: Optional[callable] = None, + ): + predicted_texts = [] + target_texts = [] + + with torch.no_grad(): + for i, data in tqdm(enumerate(dataset)): + + token_list = self.generate(encoder, decoder, data) + + if targets is None: + continue + + predicted_texts.append( + self.tokenizer.decode(token_list[0], skip_special_tokens=True) + ) + target_texts.append( + self.tokenizer.decode(targets[i], skip_special_tokens=True) + ) + print(f"Show {i}/{len(dataset)} result:") + print(f"\tPrediction: {predicted_texts[i]}") + print(f"\tTarget: {target_texts[i]}") + + if targets is None or metrics is None: + print("No targets or metrics provided for validation.") + else: + results = metrics(predicted_texts, target_texts) + print("F1 Score:", 
results["f1"]) + + def generate(self, encoder, decoder, data): + prompt_token_ids, encoder_attn_mask, decoder_input_ids = data + + min_dtype = torch.finfo(torch.float32).min + attn_mask = torch.full( + (1, 1, 1, self.max_seq_len), fill_value=min_dtype, dtype=torch.float32 + ) + attn_mask[..., 0] = 0 + + with torch.no_grad(): + # Run encoder + encoder_output = encoder(prompt_token_ids, encoder_attn_mask) + generated_ids = [0] + + # Generate tokens one by one + for i in range(self.max_seq_len - 1): + # Run decoder for next token prediction + logits = decoder( + decoder_input_ids, + attn_mask, + encoder_output, + encoder_attn_mask, + torch.tensor([i], dtype=torch.long), + ) + + # Get next token + next_token = torch.argmax(logits[:, -1, :], dim=-1).item() + generated_ids.append(next_token) + + # Update input for next iteration + decoder_input_ids = torch.tensor([[next_token]], dtype=torch.long) + + # Check if EOS token + if next_token == self.config.eos_token_id: + break + + # update attn_mask + attn_mask[..., i] = 0 + + return [generated_ids] + + @staticmethod + def evaluate_with_ground_truth( + tokenizer: AutoTokenizer, + predicts: str, + targets: Optional[List[torch.Tensor]], + metrics: Optional[callable], + ): + predicted_texts = [] + target_texts = [] + for i, (pred, tar) in tqdm(enumerate(zip(predicts, targets))): + + predicted_texts.append(pred) + target_texts.append(tokenizer.decode(tar, skip_special_tokens=True)) + print(f"Show {i}/{len(predicts)} result:") + print(f"\tPrediction: {pred}") + print(f"\tTarget: {target_texts[i]}") + results = metrics(predicted_texts, target_texts) + print("F1 Score:", results["f1"]) + + return results diff --git a/examples/qualcomm/utils.py b/examples/qualcomm/utils.py index e70510b0b70..1a2d9e4f26b 100755 --- a/examples/qualcomm/utils.py +++ b/examples/qualcomm/utils.py @@ -6,21 +6,24 @@ # TODO: reenable pyre after fixing the issues # pyre-ignore-all-errors - import argparse +import csv +import inspect import os +import random import shutil import subprocess import sys import tempfile from pathlib import Path -from typing import Callable, List, Optional, Tuple +from typing import Callable, Dict, List, Optional, Tuple import numpy as np import torch import torchao +import transformers from executorch.backends.qualcomm.quantizer.quantizer import ( ModuleQConfig, QnnQuantizer, @@ -284,6 +287,74 @@ def make_quantizer( return quantizer +def replace_module_with_custom_class( + model: torch.nn.Module, + target_class: torch.nn.Module, + custom_class: torch.nn.Module, + strict: bool = False, + extra_custom_kwargs: Optional[Dict] = None, +): + """ + Recursively replaces all instances of `target_class` in `model` with `custom_class`. + + Args: + model (torch.nn.Module): The root module to search within. + target_class (type): The class to be replaced. + custom_class (type): The class to replace with. + strict (bool): Whether to strictly enforce that the keys in `state_dict` match the model. + extra_custom_kwargs: Extra keyword arguments to override or extend the constructor args. + + Example: + >>> class MyDecoder(Decoder): + ... def __init__(self, ...) + ... super().__init__() + ... freqs_cos, freqs_sin = precompute_freqs_cis(...) + ... self.register_buffer("freqs_cos", freqs_cos) + ... self.register_buffer("freqs_sin", freqs_sin) + ... + ... def forward(self, x): + ... .... 
+ >>> model = Decoder() + >>> replace_module_with_custom_class(model, Decoder, MyDecoder) + """ + + def extract_init_args_from_instance(instance): + init_signature = inspect.signature(instance.__init__) + init_params = [ + param + for param in init_signature.parameters.values() + if param.name != "self" + ] + + extracted_args = {} + for param in init_params: + name = param.name + if hasattr(instance, name): + extracted_args[name] = getattr(instance, name) + elif param.default is not inspect.Parameter.empty: + extracted_args[name] = param.default + + return extracted_args + + if extra_custom_kwargs is None: + extra_custom_kwargs = {} + + for name, child in model.named_children(): + if isinstance(child, target_class): + state_dict = child.state_dict() + + original_args = extract_init_args_from_instance(child) + new_module = custom_class(**{**original_args, **extra_custom_kwargs}) + new_module.load_state_dict(state_dict, strict=strict) + new_module.eval() + + setattr(model, name, new_module) + else: + replace_module_with_custom_class( + child, target_class, custom_class, strict, extra_custom_kwargs + ) + + # TODO: refactor to support different backends def build_executorch_binary( model, # noqa: B006 @@ -452,6 +523,32 @@ def class_agnostic_mIoU(predictions, targets): return total_iou / len(predictions) +def evaluate_squad(predicted_texts: List[str], target_texts: List[str]): + import evaluate + + squad_metric = evaluate.load("squad") + + predictions = [] + references = [] + + for i, (pred, target) in enumerate(zip(predicted_texts, target_texts)): + predictions.append({"id": str(i), "prediction_text": pred.strip()}) + references.append( + { + "id": str(i), + "answers": { + "text": [target.strip()], + "answer_start": [0], # answer_start could be dummy + }, + } + ) + + results = squad_metric.compute(predictions=predictions, references=references) + results["f1"] /= 100 + results["exact_match"] /= 100 + return results + + def get_imagenet_dataset( dataset_path, data_size, image_shape, crop_size=None, shuffle=True ): @@ -489,14 +586,9 @@ def get_data_loader(): def get_masked_language_model_dataset(dataset_path, tokenizer, data_size, shuffle=True): - import random - - import transformers - - from torch.utils.data import Dataset def get_data_loader(): - class MaskedSentencesDataset(Dataset): + class MaskedSentencesDataset(torch.utils.data.Dataset): def __init__(self, dataset_path, tokenizer, data_size) -> None: self.data_size = data_size self.dataset = self._get_val_dataset(dataset_path, data_size, tokenizer) @@ -550,6 +642,108 @@ def __len__(self): return inputs, targets, input_list +def get_seq2seq_dataset_from_squad_csv( # noqa: C901 + dataset_path, + tokenizer, + data_size, + max_hidden_seq_length=384, + shuffle=True, +): + + def get_data_loader(max_hidden_seq_length): + class SquadSeq2SeqDataset(torch.utils.data.Dataset): + def __init__( + self, + dataset_path, + tokenizer, + data_size, + max_hidden_seq_length, + ): + self.max_hidden_seq_length = max_hidden_seq_length + self.tokenizer = tokenizer + self.samples = self._load_and_process(dataset_path, data_size) + + def _load_and_process(self, path, max_samples): + with open(path, "r", encoding="utf-8") as f: + reader = csv.DictReader(f) + rows = list(reader) + if shuffle: + random.shuffle(rows) + samples = [] + for row in rows: + question = row["question"].strip() + context = row["context"].strip() + answer = row["answer"].strip() + if not question or not context or not answer: + continue + input_text = f"question: {question} context: {context}" 
+ target_text = answer + samples.append((input_text, target_text)) + if len(samples) >= max_samples: + break + return samples + + def __len__(self): + return len(self.samples) + + def __getitem__(self, idx): + input_text, target_text = self.samples[idx] + model_input = tokenizer( + input_text, + truncation=True, + padding="max_length", + max_length=self.max_hidden_seq_length, + return_tensors="pt", + ) + + label = tokenizer( + target_text, + truncation=True, + padding="max_length", + max_length=64, + return_tensors="pt", + ) + return { + "input_ids": model_input["input_ids"].squeeze(0), + "attention_mask": model_input["attention_mask"].squeeze(0), + "decoder_input_ids": torch.tensor([0], dtype=torch.long), + "labels": label["input_ids"].squeeze(0), + } + + dataset = SquadSeq2SeqDataset( + dataset_path, tokenizer, data_size, max_hidden_seq_length + ) + collator = transformers.DataCollatorForSeq2Seq(tokenizer) + return torch.utils.data.DataLoader( + dataset, batch_size=1, shuffle=shuffle, collate_fn=collator + ) + + inputs, targets, input_list = [], [], "" + data_loader = get_data_loader(max_hidden_seq_length) + for idx, batch in enumerate(data_loader): + if len(inputs) >= data_size: + break + input_ids = batch["input_ids"] + attention_mask = batch["attention_mask"] + decoder_input_ids = batch["decoder_input_ids"] + labels = batch["labels"][0] + + if (labels != -100).sum().item() == 0: + continue + + inputs.append( + ( + input_ids.to(torch.long), + attention_mask.to(torch.long), + decoder_input_ids, + ) + ) + targets.append(labels) + input_list += f"input_{idx}_0.raw input_{idx}_1.raw input_{idx}_2.raw\n" + + return inputs, targets, input_list + + def setup_common_args_and_variables(): parser = argparse.ArgumentParser()
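For reference, a typical end-to-end invocation of the new example looks roughly like the following (flags mirror those exercised by test_t5 above; the build folder, device serial, and dataset path are placeholders, and the qnn_t5_runner target added to examples/qualcomm/CMakeLists.txt must be built beforehand):

  python examples/qualcomm/oss_scripts/t5/t5.py --dataset SQuAD-v1.1.csv --artifact ./t5 --build_folder <build_folder> --device <device_serial> --model SM8750

The script quantizes (16a8w) and lowers the encoder/decoder graphs into t5_qnn.pte under the --artifact directory, pushes the PTE, the exported SentencePiece tokenizer model, and the preprocessed SQuAD inputs to the device, runs qnn_t5_runner, and reports the SQuAD F1 score (printed locally, or sent back over --ip/--port when those are set).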