diff --git a/backends/qualcomm/_passes/lift_constant_scalar_operands.py b/backends/qualcomm/_passes/lift_constant_scalar_operands.py index 6704ca6e0dc..dc9592e415b 100644 --- a/backends/qualcomm/_passes/lift_constant_scalar_operands.py +++ b/backends/qualcomm/_passes/lift_constant_scalar_operands.py @@ -59,6 +59,7 @@ class TensorOpInfo: SKIP_LIFT_OPS = { aten.full_like.default, + aten.full.default, aten.arange.start_step, aten.arange.default, aten.scalar_tensor.default, diff --git a/backends/qualcomm/builders/op_index_put.py b/backends/qualcomm/builders/op_index_put.py index 9972066e165..c3c42ed483a 100644 --- a/backends/qualcomm/builders/op_index_put.py +++ b/backends/qualcomm/builders/op_index_put.py @@ -88,11 +88,15 @@ def define_node( # Need to reconstruct the index tensor. # E.g., based on ScatterND Op Def in QNN Docs. - # Given that - # shape of input: [1, 12, 1024, 64] - # indicies_node: [None, None, aten__to_copy_default_1] - # shape of aten__to_copy_default_1: [1] - # The shape of index tensor should be [1, 12, 1, 3] + # Torch: + # Given that + # shape of input: [1, 12, 1024, 64] + # indicies_node: [None, None, aten__to_copy_default_1] + # shape of aten__to_copy_default_1: [1] + # QNN: + # Index tensor: + # Shape: [1, 12, 1, 3] + # Value: [[[0,0,x]],[[0,1,x]],...,[[0,11,x]]] # The index tensor is treated as 4-dimensional tensor of 3-tuples, # where each 3-tuple is a partial-index into input # Reference code for QNN ScatterNd: diff --git a/backends/qualcomm/builders/op_linear.py b/backends/qualcomm/builders/op_linear.py index a73633ac229..d5ac153b8d1 100644 --- a/backends/qualcomm/builders/op_linear.py +++ b/backends/qualcomm/builders/op_linear.py @@ -4,7 +4,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import warnings from typing import Dict import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper @@ -70,13 +69,6 @@ def define_node( if len(node.args) >= 3: bias_node = self.get_node(node.args[2]) - # TODO remove this when qnn sdk support - if QCOM_SCALES in bias_node.meta.get(QCOM_QUANT_ATTRS, {}): - warnings.warn( - f"[QNN Delegate Op Builder]: Fallback linear bias, {bias_node}. 
per channel bias quantization is not support yet.", - stacklevel=1, - ) - bias_tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC bias_tensor = get_parameter(bias_node, self.edge_program) # if bias_node is getitem diff --git a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py index 8be05d46688..9068347cd9e 100644 --- a/backends/qualcomm/tests/models.py +++ b/backends/qualcomm/tests/models.py @@ -910,9 +910,10 @@ def forward(self, x): class IndexCopy(torch.nn.Module): - def __init__(self, skip_mutable_buffer=False): + def __init__(self, copy_dim=1, skip_mutable_buffer=False): super().__init__() self.skip_mutable_buffer = skip_mutable_buffer + self.copy_dim = copy_dim self.register_buffer( "k_cache", torch.zeros((1, 1024, 12, 64), dtype=torch.float32), @@ -921,7 +922,7 @@ def __init__(self, skip_mutable_buffer=False): def forward(self, input_pos, k_val): k_out = self.k_cache - k_out.index_copy_(1, input_pos, k_val) + k_out.index_copy_(self.copy_dim, input_pos, k_val) return k_out + 0 diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index 4a0edaf471d..22e0d1cc219 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -622,19 +622,59 @@ def test_qnn_backend_index(self): def test_qnn_backend_index_copy(self): test_comb = [ { - QCOM_MODULE: IndexCopy(skip_mutable_buffer=False), # noqa: F405 + QCOM_MODULE: IndexCopy( # noqa: F405 + copy_dim=1, skip_mutable_buffer=False + ), QCOM_SAMPLE_INPUTS: ( torch.tensor([2], dtype=torch.int64), torch.randn([1, 1, 12, 64]), ), }, { - QCOM_MODULE: IndexCopy(skip_mutable_buffer=True), # noqa: F405 + QCOM_MODULE: IndexCopy( # noqa: F405 + copy_dim=2, skip_mutable_buffer=False + ), + QCOM_SAMPLE_INPUTS: ( + torch.tensor([2], dtype=torch.int64), + torch.randn([1, 1024, 1, 64]), + ), + }, + { + QCOM_MODULE: IndexCopy( # noqa: F405 + copy_dim=2, skip_mutable_buffer=False + ), + QCOM_SAMPLE_INPUTS: ( + torch.tensor([2, 5], dtype=torch.int64), + torch.randn([1, 1024, 2, 64]), + ), + }, + { + QCOM_MODULE: IndexCopy( # noqa: F405 + copy_dim=1, skip_mutable_buffer=True + ), QCOM_SAMPLE_INPUTS: ( torch.tensor([2], dtype=torch.int64), torch.randn([1, 1, 12, 64]), ), }, + { + QCOM_MODULE: IndexCopy( # noqa: F405 + copy_dim=2, skip_mutable_buffer=True + ), + QCOM_SAMPLE_INPUTS: ( + torch.tensor([2], dtype=torch.int64), + torch.randn([1, 1024, 1, 64]), + ), + }, + { + QCOM_MODULE: IndexCopy( # noqa: F405 + copy_dim=2, skip_mutable_buffer=True + ), + QCOM_SAMPLE_INPUTS: ( + torch.tensor([2, 5], dtype=torch.int64), + torch.randn([1, 1024, 2, 64]), + ), + }, ] for i, test in enumerate(test_comb): with self.subTest(i=i): @@ -1907,19 +1947,59 @@ def test_qnn_backend_index(self): def test_qnn_backend_index_copy(self): test_comb = [ { - QCOM_MODULE: IndexCopy(skip_mutable_buffer=False), # noqa: F405 + QCOM_MODULE: IndexCopy( # noqa: F405 + copy_dim=1, skip_mutable_buffer=False + ), QCOM_SAMPLE_INPUTS: ( torch.tensor([2], dtype=torch.int64), torch.randn([1, 1, 12, 64]), ), }, { - QCOM_MODULE: IndexCopy(skip_mutable_buffer=True), # noqa: F405 + QCOM_MODULE: IndexCopy( # noqa: F405 + copy_dim=2, skip_mutable_buffer=False + ), + QCOM_SAMPLE_INPUTS: ( + torch.tensor([2], dtype=torch.int64), + torch.randn([1, 1024, 1, 64]), + ), + }, + { + QCOM_MODULE: IndexCopy( # noqa: F405 + copy_dim=2, skip_mutable_buffer=False + ), + QCOM_SAMPLE_INPUTS: ( + torch.tensor([2, 5], dtype=torch.int64), + torch.randn([1, 1024, 2, 64]), + ), + }, + { + 
QCOM_MODULE: IndexCopy( # noqa: F405 + copy_dim=1, skip_mutable_buffer=True + ), QCOM_SAMPLE_INPUTS: ( torch.tensor([2], dtype=torch.int64), torch.randn([1, 1, 12, 64]), ), }, + { + QCOM_MODULE: IndexCopy( # noqa: F405 + copy_dim=2, skip_mutable_buffer=True + ), + QCOM_SAMPLE_INPUTS: ( + torch.tensor([2], dtype=torch.int64), + torch.randn([1, 1024, 1, 64]), + ), + }, + { + QCOM_MODULE: IndexCopy( # noqa: F405 + copy_dim=2, skip_mutable_buffer=True + ), + QCOM_SAMPLE_INPUTS: ( + torch.tensor([2, 5], dtype=torch.int64), + torch.randn([1, 1024, 2, 64]), + ), + }, ] for i, test in enumerate(test_comb): with self.subTest(i=i): @@ -4909,6 +4989,39 @@ def test_swin_transformer(self): self.assertGreaterEqual(msg["top_1"], 60) self.assertGreaterEqual(msg["top_5"], 80) + def test_whisper(self): + if not self.required_envs(): + self.skipTest("missing required envs") + + cmds = [ + "python", + f"{self.executorch_root}/examples/qualcomm/oss_scripts/whisper/whisper.py", + "--artifact", + self.artifact_dir, + "--build_folder", + self.build_folder, + "--device", + self.device, + "--model", + self.model, + "--ip", + self.ip, + "--port", + str(self.port), + ] + if self.host: + cmds.extend(["--host", self.host]) + + p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL) + with Listener((self.ip, self.port)) as listener: + conn = listener.accept() + p.communicate() + msg = json.loads(conn.recv()) + if "Error" in msg: + self.fail(msg["Error"]) + else: + self.assertLessEqual(msg["wer"], 0.25) + class TestExampleQaihubScript(TestQNN): def test_utils_export(self): diff --git a/examples/qualcomm/CMakeLists.txt b/examples/qualcomm/CMakeLists.txt index 757c7518f0c..69fa9a0b0d4 100644 --- a/examples/qualcomm/CMakeLists.txt +++ b/examples/qualcomm/CMakeLists.txt @@ -90,6 +90,9 @@ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/oss_scripts/llama) # build qnn_mimi_decoder_runner add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/oss_scripts/moshi) +# build qnn_whisper_runner for whisper +add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/oss_scripts/whisper) + # build qaihub_llama2_7b_runner and qaihub_llama3_8b_runner add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/qaihub_scripts/llama) diff --git a/examples/qualcomm/oss_scripts/whisper/CMakeLists.txt b/examples/qualcomm/oss_scripts/whisper/CMakeLists.txt new file mode 100644 index 00000000000..d3c28b218e3 --- /dev/null +++ b/examples/qualcomm/oss_scripts/whisper/CMakeLists.txt @@ -0,0 +1,46 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ + +# preprocess qnn runner src files for whisper +set(_qnn_whisper_runner__srcs + ${CMAKE_CURRENT_LIST_DIR}/qnn_whisper_runner.cpp + ${CMAKE_CURRENT_LIST_DIR}/runner/decoder.cpp + ${CMAKE_CURRENT_LIST_DIR}/runner/decoder.h + ${CMAKE_CURRENT_LIST_DIR}/runner/encoder.cpp + ${CMAKE_CURRENT_LIST_DIR}/runner/encoder.h + ${CMAKE_CURRENT_LIST_DIR}/runner/runner.cpp + ${CMAKE_CURRENT_LIST_DIR}/runner/runner.h + ${EXECUTORCH_ROOT}/extension/llm/sampler/sampler.cpp +) + +# build qnn whisper runner +add_executable(qnn_whisper_runner ${_qnn_whisper_runner__srcs}) +target_include_directories( + qnn_whisper_runner PUBLIC ${_common_include_directories} + ${EXECUTORCH_ROOT}/extension/llm/tokenizers/include +) + + +target_link_libraries( + qnn_whisper_runner + qnn_executorch_backend + executorch_core + extension_data_loader + extension_flat_tensor + extension_module + extension_tensor + full_portable_ops_lib + gflags + tokenizers +) + +target_compile_options( + qnn_whisper_runner PUBLIC ${_common_compile_options} +) +set_target_properties( + qnn_whisper_runner PROPERTIES LINK_FLAGS "-Wl,-rpath='$ORIGIN'" +) diff --git a/examples/qualcomm/oss_scripts/whisper/qnn_whisper_runner.cpp b/examples/qualcomm/oss_scripts/whisper/qnn_whisper_runner.cpp new file mode 100644 index 00000000000..e61b2f444c0 --- /dev/null +++ b/examples/qualcomm/oss_scripts/whisper/qnn_whisper_runner.cpp @@ -0,0 +1,122 @@ +/* + * Copyright (c) Qualcomm Innovation Center, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/** + * @file + * + * This tool can run whisper with Qualcomm AI Engine Direct. + * + */ + +#include +#include +#include +#include +#include +#include + +DEFINE_string( + model_path, + "whisper_qnn_16a8w.pte", + "Whisper model serialized in flatbuffer format."); + +DEFINE_string( + tokenizer_json_path, + "tokenizer.json", + "The tokenizer is saved from WhisperTokenize.save_pretrained for tokenizer."); +DEFINE_string( + input_list_path, + "input_list.txt", + "Input list storing file name of encoded results."); +DEFINE_int32( + seq_len, + 128, + "Maximum sequence length for the generated output. Defaults to use the model's `max_cache_size` attribute. Will be truncated to maximal cache size if larger than `max_cache_size`."); + +DEFINE_string( + output_folder_path, + "outputs", + "Executorch inference data output path."); + +std::vector>> parse_input_list_file( + const std::string& input_list_path) { + // Parsing an input list file to obtain multiple inferences of multiple data. 
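+  // Expected file format: one inference per line, with space-separated paths
+  // to raw tensor files. For this runner, whisper.py writes a single file per
+  // line, e.g.
+  //   input_0_0.raw
+  //   input_1_0.raw
+  // where each file holds the raw bytes of the (1, 80, 3000) float input
+  // features produced by the Whisper feature extractor.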
+  std::vector<std::vector<std::vector<char>>> bufs;
+  std::ifstream input_list(input_list_path);
+  auto split = [](std::string s, std::string delimiter) {
+    size_t pos_start = 0, pos_end, delim_len = delimiter.length();
+    std::string token;
+    std::vector<std::string> res;
+
+    while ((pos_end = s.find(delimiter, pos_start)) != std::string::npos) {
+      token = s.substr(pos_start, pos_end - pos_start);
+      pos_start = pos_end + delim_len;
+      res.push_back(token);
+    }
+    res.push_back(s.substr(pos_start));
+    return res;
+  };
+
+  if (!input_list.is_open()) {
+    std::cerr << "Unable to open file" << std::endl;
+    return bufs;
+  }
+  std::string file_path;
+  while (std::getline(input_list, file_path)) {
+    auto input_files = split(file_path, " ");
+    int num_inputs = input_files.size();
+    if (num_inputs == 0) {
+      break;
+    }
+    bufs.emplace_back();
+    bufs.back().resize(num_inputs);
+    for (int input_index = 0; input_index < num_inputs; ++input_index) {
+      std::ifstream fin(input_files[input_index], std::ios::binary);
+      fin.seekg(0, fin.end);
+      size_t file_size = fin.tellg();
+      bufs.back()[input_index].resize(file_size);
+
+      fin.seekg(0, fin.beg);
+      if (!fin.read(bufs.back()[input_index].data(), file_size)) {
+        std::cerr << "Error: Could not read file." << std::endl;
+        return bufs;
+      }
+      fin.close();
+    }
+  }
+
+  input_list.close();
+  return bufs;
+}
+
+int main(int argc, char** argv) {
+  gflags::ParseCommandLineFlags(&argc, &argv, true);
+  // create whisper runner
+  example::Runner runner(FLAGS_model_path, FLAGS_tokenizer_json_path);
+
+  std::vector<std::vector<std::vector<char>>> multi_turns_input_buffers =
+      parse_input_list_file(FLAGS_input_list_path);
+  for (int iter = 0; iter < multi_turns_input_buffers.size(); ++iter) {
+    std::vector<char> bufs;
+    bufs.reserve(5 * FLAGS_seq_len); // assume each token is around 5 chars
+    auto callback = [&](const std::string& piece) {
+      for (const char c : piece) {
+        bufs.push_back(c);
+      }
+    };
+    // generate tokens
+    runner.transcribe(FLAGS_seq_len, multi_turns_input_buffers[iter], callback);
+    auto output_file_name =
+        FLAGS_output_folder_path + "/output_" + std::to_string(iter) + ".txt";
+    std::ofstream fout(output_file_name);
+    fout.write(bufs.data(), bufs.size());
+    fout.close();
+  }
+
+  return 0;
+}
diff --git a/examples/qualcomm/oss_scripts/whisper/runner/decoder.cpp b/examples/qualcomm/oss_scripts/whisper/runner/decoder.cpp
new file mode 100644
index 00000000000..8179ae99d03
--- /dev/null
+++ b/examples/qualcomm/oss_scripts/whisper/runner/decoder.cpp
@@ -0,0 +1,53 @@
+/*
+ * Copyright (c) Qualcomm Innovation Center, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/examples/qualcomm/oss_scripts/whisper/runner/decoder.h>
+
+using executorch::aten::Tensor;
+using executorch::extension::Module;
+using executorch::extension::TensorPtr;
+using executorch::runtime::Error;
+using executorch::runtime::Result;
+
+namespace example {
+WhisperDecoder::WhisperDecoder(const std::string& model_path) {
+  module_ = std::make_unique<Module>(
+      model_path, Module::LoadMode::MmapUseMlockIgnoreErrors);
+  ET_LOG(Info, "creating decoder module: model_path=%s", model_path.c_str());
+}
+
+bool WhisperDecoder::is_method_loaded() const {
+  return module_->is_method_loaded(kDecoderForwardName);
+}
+
+Error WhisperDecoder::load() {
+  if (is_method_loaded()) {
+    return Error::Ok;
+  }
+  return module_->load_method(kDecoderForwardName);
+}
+Result<Tensor> WhisperDecoder::step(
+    TensorPtr& input_ids,
+    TensorPtr& attention_mask,
+    TensorPtr& encoder_hidden_states,
+    TensorPtr& cache_position) {
+  auto outputs_res = module_->execute(
+      kDecoderForwardName,
+      {input_ids, attention_mask, encoder_hidden_states, cache_position});
+  ET_CHECK_OK_OR_RETURN_ERROR(outputs_res.error());
+  ET_CHECK_MSG(
+      outputs_res.get().size() == 1,
+      "More than one output returned from executing decoder.");
+  ET_CHECK_MSG(
+      outputs_res.get()[0].isTensor(),
+      "Non-tensor output returned from executing decoder");
+
+  // Return the logits tensor
+  return outputs_res.get()[0].toTensor();
+}
+} // namespace example
diff --git a/examples/qualcomm/oss_scripts/whisper/runner/decoder.h b/examples/qualcomm/oss_scripts/whisper/runner/decoder.h
new file mode 100644
index 00000000000..ba5e23c7039
--- /dev/null
+++ b/examples/qualcomm/oss_scripts/whisper/runner/decoder.h
@@ -0,0 +1,51 @@
+/*
+ * Copyright (c) Qualcomm Innovation Center, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+namespace example {
+
+class WhisperDecoder {
+ public:
+  explicit WhisperDecoder(const std::string& model_path);
+
+  bool is_method_loaded() const;
+  executorch::runtime::Error load();
+  executorch::runtime::Result<executorch::aten::Tensor> step(
+      executorch::extension::TensorPtr& input_ids,
+      executorch::extension::TensorPtr& attention_mask,
+      executorch::extension::TensorPtr& encoder_hidden_states,
+      executorch::extension::TensorPtr& cache_position);
+  executorch::runtime::Result<std::unordered_set<std::string>> method_names() {
+    return module_->method_names();
+  }
+  executorch::runtime::Result<executorch::runtime::EValue> get(
+      const std::string& method_name) {
+    return module_->get(method_name);
+  }
+
+  executorch::runtime::Result<std::vector<executorch::runtime::EValue>> execute(
+      const std::string& method_name) {
+    return module_->execute(method_name);
+  }
+
+ private:
+  std::unique_ptr<executorch::extension::Module> module_;
+  static constexpr const char* kDecoderForwardName = "decoder";
+};
+
+} // namespace example
diff --git a/examples/qualcomm/oss_scripts/whisper/runner/encoder.cpp b/examples/qualcomm/oss_scripts/whisper/runner/encoder.cpp
new file mode 100644
index 00000000000..778a54d73b0
--- /dev/null
+++ b/examples/qualcomm/oss_scripts/whisper/runner/encoder.cpp
@@ -0,0 +1,46 @@
+/*
+ * Copyright (c) Qualcomm Innovation Center, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/examples/qualcomm/oss_scripts/whisper/runner/encoder.h>
+
+using executorch::aten::Tensor;
+using executorch::extension::Module;
+using executorch::extension::TensorPtr;
+using executorch::runtime::Error;
+using executorch::runtime::Result;
+namespace example {
+WhisperEncoder::WhisperEncoder(const std::string& model_path) {
+  module_ = std::make_unique<Module>(
+      model_path, Module::LoadMode::MmapUseMlockIgnoreErrors);
+  ET_LOG(Info, "creating encoder module: model_path=%s", model_path.c_str());
+}
+
+bool WhisperEncoder::is_method_loaded() const {
+  return module_->is_method_loaded(kEncoderForwardName);
+}
+
+Error WhisperEncoder::load() {
+  if (is_method_loaded()) {
+    return Error::Ok;
+  }
+  return module_->load_method(kEncoderForwardName);
+}
+Result<Tensor> WhisperEncoder::encode(TensorPtr& input_feature) {
+  auto outputs_res = module_->execute(kEncoderForwardName, input_feature);
+  ET_CHECK_OK_OR_RETURN_ERROR(outputs_res.error());
+  ET_CHECK_MSG(
+      outputs_res.get().size() == 1,
+      "More than one output returned from executing encoder.");
+  ET_CHECK_MSG(
+      outputs_res.get()[0].isTensor(),
+      "Non-tensor output returned from executing encoder");
+
+  // Return the hidden state tensor
+  return outputs_res.get()[0].toTensor();
+}
+} // namespace example
diff --git a/examples/qualcomm/oss_scripts/whisper/runner/encoder.h b/examples/qualcomm/oss_scripts/whisper/runner/encoder.h
new file mode 100644
index 00000000000..90d0d43dfcd
--- /dev/null
+++ b/examples/qualcomm/oss_scripts/whisper/runner/encoder.h
@@ -0,0 +1,34 @@
+/*
+ * Copyright (c) Qualcomm Innovation Center, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+namespace example {
+
+class WhisperEncoder {
+ public:
+  explicit WhisperEncoder(const std::string& model_path);
+
+  bool is_method_loaded() const;
+  executorch::runtime::Error load();
+  executorch::runtime::Result<executorch::aten::Tensor> encode(
+      executorch::extension::TensorPtr& input_feature);
+
+ private:
+  std::unique_ptr<executorch::extension::Module> module_;
+  static constexpr const char* kEncoderForwardName = "encoder";
+};
+
+} // namespace example
diff --git a/examples/qualcomm/oss_scripts/whisper/runner/runner.cpp b/examples/qualcomm/oss_scripts/whisper/runner/runner.cpp
new file mode 100644
index 00000000000..8cd75f433f7
--- /dev/null
+++ b/examples/qualcomm/oss_scripts/whisper/runner/runner.cpp
@@ -0,0 +1,221 @@
+/*
+ * Copyright (c) Qualcomm Innovation Center, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */ + +#include +#include +#include +#include +#include +using executorch::aten::ScalarType; +using executorch::aten::Tensor; +using executorch::extension::from_blob; +using executorch::extension::make_tensor_ptr; +using executorch::extension::llm::Sampler; +using executorch::extension::llm::time_in_ms; +using executorch::llm::kTopp; +using executorch::runtime::Error; +using executorch::runtime::Result; + +namespace example { +namespace { +static constexpr auto kDecoderStartTokenId = "decoder_start_token_id"; +static constexpr auto kEosId = "get_eos_id"; +static constexpr auto kMaxContextLen = "get_max_context_len"; +} // namespace +Runner::Runner( + const std::string& model_path, + const std::string& tokenizer_json_path) + : tokenizer_json_path_(tokenizer_json_path) { + encoder_ = std::make_unique(model_path); + decoder_ = std::make_unique(model_path); + tokenizer_ = std::make_unique(); +} +bool Runner::is_loaded() const { + return encoder_->is_method_loaded() && decoder_->is_method_loaded() && + tokenizer_->is_loaded() && sampler_; +} + +Error Runner::load() { + if (is_loaded()) { + return Error::Ok; + } + ET_CHECK_OK_OR_RETURN_ERROR(encoder_->load()); + ET_CHECK_OK_OR_RETURN_ERROR(decoder_->load()); + if (tokenizer_->load(tokenizer_json_path_) != tokenizers::Error::Ok) { + ET_LOG( + Error, + "Failed to load tokenizer with %s", + tokenizer_json_path_.c_str()); + return Error::Internal; + } + eos_ids_ = std::make_unique>( + std::unordered_set{tokenizer_->eos_tok()}); + // create sampler + sampler_ = std::make_unique( + tokenizer_->vocab_size(), + 0, + kTopp, + static_cast(std::time(nullptr))); + + // Initialize metadata with default values + metadata_ = { + {kDecoderStartTokenId, 50258}, + {kMaxContextLen, 128}, + }; + + // Read metadata from the model + auto method_names_result = decoder_->method_names(); + if (method_names_result.error() != Error::Ok) { + ET_LOG(Error, "Failed reading method names"); + return Error::Internal; + } + const auto method_names = method_names_result.get(); + + for (auto& [method_name, value] : metadata_) { + if (method_names.count(method_name)) { + auto get_result = decoder_->get(method_name); + value = + get_result.get().toScalar().to(); + } else { + ET_LOG( + Info, + "Method %s not found, using the default value %" PRId64, + method_name.c_str(), + value); + } + ET_LOG(Info, "Metadata: %s = %" PRId64, method_name.c_str(), value); + } + + // Get EOS IDs if available + if (method_names.count(kEosId)) { + eos_ids_->clear(); + auto execute_result = decoder_->execute(kEosId); + if (execute_result.error() != Error::Ok) { + ET_LOG(Error, "Failed to execute %s", kEosId); + return Error::Internal; + } + for (const auto& eos_id : execute_result.get()) { + auto value = eos_id.toScalar().to(); + eos_ids_->emplace(value); + ET_LOG(Info, "eos_id = %" PRId64, value); + } + } + + return Error::Ok; +} +uint64_t Runner::logits_to_token( + const executorch::aten::Tensor& logits_tensor) { + return sampler_->sample(logits_tensor.data_ptr()); +} + +Error Runner::transcribe( + int32_t seq_len, + std::vector>& inputs, + std::function token_callback) { + if (!is_loaded()) { + stats_.model_load_start_ms = time_in_ms(); + ET_CHECK_OK_OR_RETURN_ERROR(load()); + stats_.model_load_end_ms = time_in_ms(); + } + ET_CHECK_MSG(inputs.size() == 1, "The input size of whisper should be one."); + + ET_LOG(Info, "Start Encoding"); + stats_.encoder_inference_start_ms = time_in_ms(); + auto input_features_tensor_ptr = from_blob( + inputs[0].data(), + // (1, processor.feature_extractor.feature_size, 
+ // processor.feature_extractor.nb_max_frames) + {1, 80, 3000}, + ScalarType::Float); + Result encoder_out = encoder_->encode(input_features_tensor_ptr); + auto encoder_out_tensor_ptr = make_tensor_ptr(encoder_out.get()); + stats_.encoder_inference_end_ms = time_in_ms(); + auto max_seq_len = metadata_.at(kMaxContextLen); + + seq_len = (seq_len > 0 && seq_len <= max_seq_len) ? seq_len : max_seq_len; + + int64_t pos = 0; + num_generated_token_ = 0; + uint64_t prev_token = metadata_.at(kDecoderStartTokenId), + cur_token = prev_token; + ET_LOG(Info, "Start Decoding"); + std::vector attention_mask_data(max_seq_len, -255.0); + stats_.decoder_inference_start_ms = time_in_ms(); + while (pos < seq_len) { + attention_mask_data[pos] = 0; + auto decoder_input_ids_tensor_ptr = + from_blob(&cur_token, {1, 1}, ScalarType::Long); + auto pos_tensor_ptr = from_blob(&pos, {1}, ScalarType::Long); + + auto attention_mask_tensor_ptr = from_blob( + attention_mask_data.data(), + {1, 1, 1, static_cast(max_seq_len)}, + ScalarType::Float); + Result logits = decoder_->step( + decoder_input_ids_tensor_ptr, + attention_mask_tensor_ptr, + encoder_out_tensor_ptr, + pos_tensor_ptr); + + prev_token = cur_token; + cur_token = logits_to_token(logits.get()); + ++pos; + + if (token_callback) { + token_callback( + ET_UNWRAP_TOKENIZER(tokenizer_->decode(prev_token, cur_token))); + } + if (eos_ids_->count(cur_token) > 0) { + ET_LOG(Info, "\nReached to the end of generation"); + break; + } + } + stats_.decoder_inference_end_ms = time_in_ms(); + if (pos == seq_len) { + ET_LOG(Info, "\nSequence length (%i tokens) reached!", seq_len); + } + num_generated_token_ = pos; + print_performance(); + return Error::Ok; +} + +Error Runner::print_performance() { + ET_LOG(Info, "\tTotal Generated token:\t\t\t\t%ld", num_generated_token_); + + ET_LOG( + Info, + "\tModel Load Time:\t\t\t\t%f (seconds)", + ((double)(stats_.model_load_end_ms - stats_.model_load_start_ms) / + stats_.SCALING_FACTOR_UNITS_PER_SECOND)); + + ET_LOG( + Info, + "\tEncoding Time:\t\t\t\t\t%f (seconds)", + ((double)(stats_.encoder_inference_end_ms - + stats_.encoder_inference_start_ms) / + stats_.SCALING_FACTOR_UNITS_PER_SECOND)); + + ET_LOG( + Info, + "\tDecoding Time:\t\t\t%f (seconds)", + ((double)(stats_.decoder_inference_end_ms - + stats_.decoder_inference_start_ms) / + stats_.SCALING_FACTOR_UNITS_PER_SECOND)); + + ET_LOG( + Info, + "\tAverage Decoding Time:\t\t\t%f (seconds)", + ((double)((stats_.decoder_inference_end_ms - + stats_.decoder_inference_start_ms) / + num_generated_token_) / + (stats_.SCALING_FACTOR_UNITS_PER_SECOND))); + + return Error::Ok; +} + +} // namespace example diff --git a/examples/qualcomm/oss_scripts/whisper/runner/runner.h b/examples/qualcomm/oss_scripts/whisper/runner/runner.h new file mode 100644 index 00000000000..de7c38d0e32 --- /dev/null +++ b/examples/qualcomm/oss_scripts/whisper/runner/runner.h @@ -0,0 +1,74 @@ +/* + * Copyright (c) Qualcomm Innovation Center, Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +// A simple whisper runner that includes preprocessing and post processing +// logic. 
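+// Transcription flow (see runner.cpp): the encoder runs once over the
+// (1, 80, 3000) input features, then the decoder is stepped one token at a
+// time with an additive attention mask (0 for visible positions, -255 for
+// masked ones) and a cache_position index, until an EOS id from the model
+// metadata or seq_len is reached.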
+ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace example { + +class Runner { + public: + explicit Runner( + const std::string& model_path, + const std::string& tokenizer_json_path); + + struct Stats { + // Scaling factor for timestamps - in this case, we use ms. + const long SCALING_FACTOR_UNITS_PER_SECOND = 1000; + // Time stamps for the different stages of the execution + // model_load_start_ms: Model loading time + long model_load_start_ms; + long model_load_end_ms; + + // encoder inference time + long encoder_inference_start_ms = 0; + long encoder_inference_end_ms = 0; + + // decoder inference time + long decoder_inference_start_ms = 0; + long decoder_inference_end_ms = 0; + }; + + bool is_loaded() const; + executorch::runtime::Error load(); + executorch::runtime::Error transcribe( + int32_t seq_len, + std::vector>& inputs, + std::function token_callback = {}); + + private: + executorch::runtime::Error print_performance(); + uint64_t logits_to_token(const executorch::aten::Tensor& logits_tensor); + // model + std::unique_ptr encoder_; + std::unique_ptr decoder_; + std::unique_ptr tokenizer_; + std::unique_ptr sampler_; + std::string tokenizer_json_path_; + + std::unordered_map metadata_; + std::unique_ptr> eos_ids_; + + int64_t num_generated_token_ = 0; + Stats stats_; +}; + +} // namespace example diff --git a/examples/qualcomm/oss_scripts/whisper/whisper.py b/examples/qualcomm/oss_scripts/whisper/whisper.py new file mode 100644 index 00000000000..4b0d681f6ec --- /dev/null +++ b/examples/qualcomm/oss_scripts/whisper/whisper.py @@ -0,0 +1,510 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
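+
+# Exports openai/whisper-tiny with 16a8w quantization to a QNN-delegated .pte and
+# runs it on device through qnn_whisper_runner, reporting WER on a LibriSpeech
+# dummy split. Example invocation (shared flags such as --build_folder, --device
+# and --model come from setup_common_args_and_variables):
+#   python whisper.py --build_folder <build_dir> --device <serial> --model SM8650 \
+#       --artifact ./whisper --max_seq_len 128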
+import getpass +import json +import logging +import os +import re +import subprocess +from functools import partial +from multiprocessing.connection import Client + +import torch +from executorch.backends.qualcomm._passes import TagQuantIO + +from executorch.backends.qualcomm._passes.qnn_pass_manager import ( + get_capture_program_passes, +) +from executorch.backends.qualcomm.builders.utils import is_graph_output + +from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype +from executorch.backends.qualcomm.serialization.qc_schema import QcomChipset +from executorch.backends.qualcomm.utils.constants import ( + QCOM_PASS_ACTIVATE_KEY, + QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY, +) +from executorch.backends.qualcomm.utils.utils import ( + convert_linear_to_conv2d, + generate_htp_compiler_spec, + generate_qnn_executorch_compiler_spec, + get_soc_to_chipset_map, + to_edge_transform_and_lower_to_qnn, +) + +from executorch.devtools.backend_debug import print_delegation_info +from executorch.examples.qualcomm.oss_scripts.whisper.whisper_model import ( + Seq2SeqLMDecoderExportableModuleWithStaticCache, + Seq2SeqLMEncoderExportableModule, +) + +from executorch.examples.qualcomm.utils import ( + make_output_dir, + make_quantizer, + parse_skip_delegation_node, + setup_common_args_and_variables, + SimpleADB, +) +from executorch.exir.capture._config import ExecutorchBackendConfig +from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass +from torchao.quantization.pt2e import MinMaxObserver +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e +from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, AutoTokenizer + +FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" +logging.basicConfig(level=logging.INFO, format=FORMAT) +logging.getLogger().setLevel(logging.INFO) + +WHISPER_PTE_FILENAME = "whisper_qnn_16a8w.pte" +ENCODER = "encoder" +DECODER = "decoder" + + +def get_dataset(data_size): + from datasets import load_dataset + + dataset = load_dataset( + "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation" + ) + processor = AutoProcessor.from_pretrained("openai/whisper-tiny") + + # prepare input data + inputs, target, input_list = [], [], "" + for index, data in enumerate(dataset): + if index >= data_size: + break + sample = data["audio"] + feature = processor( + sample["array"], + return_tensors="pt", + truncation=False, + sampling_rate=sample["sampling_rate"], + ).input_features + inputs.append((feature,)) + target.append(data["text"]) + input_list += f"input_{index}_0.raw\n" + + return inputs, input_list, target + + +def calibrate( + max_seq_length, + tokenizer, + whisper_decoder, + fx_graph_module_encoder, + fx_graph_module_decoder, + calibration_inputs, + decoder_start_token_id=50258, + eos_token_id=50257, +): + for i, calibration_input in enumerate(calibration_inputs): + generated_ids = [] + encoder_output = fx_graph_module_encoder(*calibration_input) + decoder_input_ids = torch.tensor([[decoder_start_token_id]], dtype=torch.long) + _, atten_mask, _, _ = whisper_decoder.get_example_inputs() + + # Generate tokens one by one + for j in range(max_seq_length - 1): + atten_mask[:, :, :, j] = 0 + # Run decoder for next token prediction + logits = fx_graph_module_decoder( + decoder_input_ids, + atten_mask, + encoder_output, + torch.tensor([j], dtype=torch.long), + ) + # Get next token + next_token = torch.argmax(logits[:, -1, :], dim=-1).item() + generated_ids.append(next_token) + # Update input for 
next iteration + decoder_input_ids = torch.tensor([[next_token]], dtype=torch.long) + # Check if EOS token + if next_token == eos_token_id: + break + # skip_special_tokens=False to align with the results of runner + logging.info( + f"Generated result for {i} calibration: {tokenizer.decode(generated_ids, skip_special_tokens=False)}" + ) + + +def eval_metric(preds, target_strs): + from torchmetrics.text import WordErrorRate + + def clean_text(rgx_list, text): + new_text = text + for rgx_match in rgx_list: + new_text = re.sub(rgx_match, "", new_text) + return new_text + + special_strs = ["<|en|>", "<|transcribe|>", "<|notimestamps|>", "<|endoftext|>"] + special_strs_escape = [re.escape(special_str) for special_str in special_strs] + pred_str = [clean_text(special_strs_escape, pred).upper() for pred in preds] + + wer = WordErrorRate() + return wer(pred_str, target_strs) + + +class Whisper: + def __init__( + self, whisper_model, batch_size=1, max_cache_length=1024, max_seq_length=None + ): + if max_seq_length is None: + # Default to max_cache_size if max_seq_len is not specified + self.max_seq_length = max_cache_length + elif max_seq_length > max_cache_length: + logging.warning( + f"max_seq_length={max_seq_length} is larger than max_cache_length={max_cache_length}. Generating tokens will be truncated to max_cache_length." + ) + self.max_seq_length = max_cache_length + else: + self.max_seq_length = max_seq_length + self.whisper_model = whisper_model + self.config = whisper_model.config + self.head_dim = ( + self.config.head_dim + if hasattr(self.config, "head_dim") + else self.config.hidden_size // self.config.num_attention_heads + ) + + self.whisper_encoder = ( + Seq2SeqLMEncoderExportableModule(whisper_model.get_encoder()) + .to("cpu") + .eval() + ) + self.encoder_passes_job = get_capture_program_passes() + + self.whisper_decoder = ( + Seq2SeqLMDecoderExportableModuleWithStaticCache( + whisper_model=whisper_model, + max_cache_length=self.max_seq_length, + batch_size=batch_size, + ) + .to("cpu") + .eval() + ) + # To improve the performance + self.whisper_decoder = convert_linear_to_conv2d(self.whisper_decoder) + self.decoder_passes_job = get_capture_program_passes() + self.exported_whisper_encoder = None + self.exported_whisper_decoder = None + self.has_quant_io = False + + def _tag_ios(self, node, fixed_point_type): + if not self.has_quant_io: + return + + quant_io_type = None + if node.op == "placeholder" and "static_cache_" in node.name: + quant_io_type = fixed_point_type + + if is_graph_output(node): + # shape of k caches and v caches + if node.meta["val"].size()[-2:] in { + (self.max_seq_length, self.head_dim), + }: + quant_io_type = fixed_point_type + + return quant_io_type + + def quantize( + self, calibration_inputs, quant_dtype, tokenizer, custom_annotations=() + ): + self.quant_dtype = quant_dtype + self.has_quant_io = True + + # Need to set per_channel_linear=True for encoder to enhance accuracy + quantizer = make_quantizer( + quant_dtype=quant_dtype, + per_channel_conv=True, + per_channel_linear=True, + act_observer=MinMaxObserver, + custom_annotations=custom_annotations, + ) + + with torch.no_grad(): + self.exported_whisper_encoder = torch.export.export( + self.whisper_encoder, + self.whisper_encoder.get_example_inputs(), + strict=True, + ).module() + self.exported_whisper_decoder = torch.export.export( + self.whisper_decoder, + self.whisper_decoder.get_example_inputs(), + strict=True, + ).module() + + self.exported_whisper_encoder = prepare_pt2e( + self.exported_whisper_encoder, 
quantizer + ) + self.exported_whisper_decoder = prepare_pt2e( + self.exported_whisper_decoder, quantizer + ) + + logging.info("Quantizing the model...") + + calibrate( + self.max_seq_length, + tokenizer, + self.whisper_decoder, + self.exported_whisper_encoder, + self.exported_whisper_decoder, + calibration_inputs, + decoder_start_token_id=getattr( + self.config, "decoder_start_token_id", None + ), + eos_token_id=getattr(self.config, "eos_token_id", None), + ) + + self.exported_whisper_encoder = convert_pt2e(self.exported_whisper_encoder) + self.exported_whisper_decoder = convert_pt2e(self.exported_whisper_decoder) + + self.decoder_passes_job[TagQuantIO][QCOM_PASS_ACTIVATE_KEY] = True + self.decoder_passes_job[TagQuantIO][QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY][ + "get_quant_io_dtype_fn" + ] = partial(self._tag_ios, fixed_point_type=torch.uint16) + + def lowering_modules( + self, + workspace, + use_fp16=False, + soc_model=QcomChipset.SM8650, + skip_node_id_set=None, + skip_node_op_set=None, + verbose=True, + ): + logging.info("Lowering the model...") + executorch_config = ExecutorchBackendConfig( + memory_planning_pass=MemoryPlanningPass( + alloc_graph_input=True, + alloc_graph_output=True, + ), + extract_delegate_segments=True, + ) + with torch.no_grad(): + # backend option + backend_options = generate_htp_compiler_spec(use_fp16=use_fp16) + compiler_specs = generate_qnn_executorch_compiler_spec( + soc_model=soc_model, + backend_options=backend_options, + ) + + whisper_edge_prog_mgr = to_edge_transform_and_lower_to_qnn( + { + ENCODER: self.exported_whisper_encoder, + DECODER: self.exported_whisper_decoder, + }, + { + ENCODER: self.whisper_encoder.get_example_inputs(), + DECODER: self.whisper_decoder.get_example_inputs(), + }, + {ENCODER: compiler_specs, DECODER: compiler_specs}, + constant_methods=self.whisper_decoder.get_metadata(), + passes_job={ + ENCODER: get_capture_program_passes(), + DECODER: self.decoder_passes_job, + }, + skip_node_id_set=skip_node_id_set, + skip_node_op_set=skip_node_op_set, + skip_mutable_buffer=False, + ) + + if verbose: + print_delegation_info( + whisper_edge_prog_mgr.exported_program(ENCODER).graph_module + ) + print_delegation_info( + whisper_edge_prog_mgr.exported_program(DECODER).graph_module + ) + whisper_edge_prog_mgr = whisper_edge_prog_mgr.to_executorch( + config=executorch_config + ) + with open(f"{workspace}/{WHISPER_PTE_FILENAME}", "wb") as file: + whisper_edge_prog_mgr.write_to_file(file) + + +def compile_whisper(args, inputs): + skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args) + + # ensure the working directory exist. + os.makedirs(args.artifact, exist_ok=True) + + if not args.compile_only and args.device is None: + raise RuntimeError( + "device serial is required if not compile only. " + "Please specify a device serial by -s/--device argument." 
+ ) + tokenizer = AutoTokenizer.from_pretrained("openai/whisper-tiny") + module = ( + AutoModelForSpeechSeq2Seq.from_pretrained("openai/whisper-tiny") + .to("cpu") + .eval() + ) + + max_cache_length = 1024 + batch_size = 1 + whisper = Whisper( + module, + batch_size=batch_size, + max_cache_length=max_cache_length, + max_seq_length=args.max_seq_len, + ) + + whisper.quantize(inputs, QuantDtype.use_16a8w, tokenizer) + whisper.lowering_modules( + args.artifact, + use_fp16=False, + soc_model=get_soc_to_chipset_map()[args.model], + skip_node_id_set=skip_node_id_set, + skip_node_op_set=skip_node_op_set, + ) + + +def inference_whisper(args, inputs, input_list, target): + workspace = f"/data/local/tmp/{getpass.getuser()}/executorch/whisper" + tokenizer = AutoTokenizer.from_pretrained("openai/whisper-tiny") + tokenizer_json = tokenizer.save_pretrained(args.artifact)[-1] + pte_path = ( + f"{args.pre_gen_pte}/{WHISPER_PTE_FILENAME}" + if args.pre_gen_pte + else f"{args.artifact}/{WHISPER_PTE_FILENAME}" + ) + + # collect output data + output_data_folder = f"{args.artifact}/outputs" + make_output_dir(output_data_folder) + outputs = [] + + def post_process(): + for i in range(len(inputs)): + with open(f"{args.artifact}/outputs/output_{i}.txt", "r") as f: + outputs.append(f.read()) + + seq_len = args.max_seq_len + runner_args = " ".join( + [ + f"--model_path {WHISPER_PTE_FILENAME}", + f"--tokenizer_json_path {os.path.basename(tokenizer_json)}", + "--input_list_path input_list.txt", + f"--seq_len {seq_len}", + "--output_folder_path outputs", + ] + ) + + if args.enable_x86_64: + # x86 emulator is intended for CI and not performance. + qnn_sdk = os.getenv("QNN_SDK_ROOT") + target = "x86_64-linux-clang" + runner_cmd = " ".join( + [ + f"export LD_LIBRARY_PATH={qnn_sdk}/lib/{target}/:{args.build_folder}/lib &&", + f"./{args.build_folder}/examples/qualcomm/oss_scripts/whisper/qnn_whisper_runner", + runner_args, + ] + ) + subprocess.run( + runner_cmd, + shell=True, + executable="/bin/bash", + capture_output=True, + ) + post_process() + else: + runner_cmd = " ".join( + [ + f"cd {workspace} &&", + "./qnn_whisper_runner", + runner_args, + ] + ) + + adb = SimpleADB( + qnn_sdk=os.getenv("QNN_SDK_ROOT"), + build_path=f"{args.build_folder}", + pte_path=pte_path, + workspace=workspace, + device_id=args.device, + host_id=args.host, + soc_model=args.model, + shared_buffer=args.shared_buffer, + runner="examples/qualcomm/oss_scripts/whisper/qnn_whisper_runner", + ) + # No pregen inputs, input_list is not required + adb.push(inputs=inputs, input_list=input_list, files=[tokenizer_json]) + adb.execute(custom_runner_cmd=runner_cmd) + + adb.pull(output_path=args.artifact, callback=post_process) + wer = eval_metric(outputs, target) + + if args.ip and args.port != -1: + with Client((args.ip, args.port)) as conn: + conn.send( + json.dumps( + { + "wer": float(wer), + } + ) + ) + else: + logging.info(f"Wer: {wer}") + for idx, output in enumerate(outputs): + logging.info(f"Results[{idx}]:\n{output}") + + +if __name__ == "__main__": + parser = setup_common_args_and_variables() + + parser.add_argument( + "-a", + "--artifact", + help="path for storing generated artifacts by this example. " + "Default ./whisper", + default="./whisper", + type=str, + ) + + parser.add_argument( + "--max_seq_len", + help="Maximum sequence length for the generated output. Defaults to use the model's `max_cache_size` attribute. 
Will be truncated to maximal cache size if larger than `max_cache_size`.", + default=1024, + type=int, + ) + + parser.add_argument( + "--pre_gen_pte", + help="Run the pre-generated llama in the given directory.", + type=str, + ) + + args = parser.parse_args() + + if args.compile_only and args.pre_gen_pte: + exit("Cannot set both compile_only and pre_gen_pte as true") + + data_num = 20 + if args.ci: + inputs = [(torch.rand(1, 80, 3000),)] + logging.warning( + "This option is for CI to verify the export flow. It uses random input and will result in poor accuracy." + ) + else: + inputs, input_list, target = get_dataset(data_num) + + if args.pre_gen_pte: + inference_whisper(args, inputs, input_list, target) + exit(f"Finish the running pre_gen_pte from {args.pre_gen_pte}") + + if args.compile_only: + compile_whisper(args, inputs) + exit(f"Finish compile_only and save to {args.artifact}") + + try: + compile_whisper(args, inputs) + inference_whisper(args, inputs, input_list, target) + except Exception as e: + if args.ip and args.port != -1: + with Client((args.ip, args.port)) as conn: + conn.send(json.dumps({"Error": str(e)})) + else: + raise Exception(e) diff --git a/examples/qualcomm/oss_scripts/whisper/whisper_model.py b/examples/qualcomm/oss_scripts/whisper/whisper_model.py new file mode 100644 index 00000000000..ec0e96cae12 --- /dev/null +++ b/examples/qualcomm/oss_scripts/whisper/whisper_model.py @@ -0,0 +1,101 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import torch +from transformers import StaticCache, WhisperForConditionalGeneration + + +class Seq2SeqLMEncoderExportableModule(torch.nn.Module): + """ + A wrapper module designed to make a Seq2Seq LM encoder exportable with `torch.export`. + This module ensures that the exported encoder model is compatible with ExecuTorch. + """ + + def __init__(self, encoder_model): + super().__init__() + self.encoder = encoder_model + + def forward(self, input_ids): + return self.encoder(input_ids).last_hidden_state + + def get_example_inputs(self): + return (torch.rand(1, 80, 3000),) + + def get_metadata(self): + return {} + + +class Seq2SeqLMDecoderExportableModuleWithStaticCache(torch.nn.Module): + """ + A wrapper module designed to make a Seq2Seq LM decoder exportable with `torch.export`, + specifically for use with static caching. This module ensures the exported decoder + is compatible with ExecuTorch. 
+ """ + + def __init__(self, whisper_model, max_cache_length, batch_size): + super().__init__() + + # Get the decoder component + self.decoder = whisper_model.get_decoder() + if isinstance(whisper_model, WhisperForConditionalGeneration): + self.proj_out = whisper_model.proj_out + else: + self.proj_out = whisper_model.lm_head + self.config = whisper_model.config + self.batch_size = batch_size + self.max_cache_length = max_cache_length + + # Initialize static cache + self.static_cache = StaticCache( + config=self.config, + max_batch_size=batch_size, + max_cache_len=max_cache_length, + device="cpu", + dtype=torch.float32, + ) + + # Register cache buffers to make them exportable + for i in range(len(self.static_cache.key_cache)): + self.register_buffer(f"key_cache_{i}", self.static_cache.key_cache[i]) + self.register_buffer(f"value_cache_{i}", self.static_cache.value_cache[i]) + + def forward( + self, decoder_input_ids, attention_mask, encoder_hidden_states, cache_position + ): + # Get outputs from decoder + outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_values=self.static_cache, + use_cache=True, + cache_position=cache_position, + ) + + # Apply linear projection (lm head) to obtain logits + logits = self.proj_out(outputs[0]) + return logits + + def get_example_inputs(self): + input_ids = torch.tensor([[0]], dtype=torch.long) + encoder_hidden_states = torch.rand(1, 1500, 384) + cache_position = torch.tensor([0], dtype=torch.long) + atten_mask = torch.full((1, self.max_cache_length), torch.tensor(-255.0)) + atten_mask *= torch.arange(self.max_cache_length) > cache_position.reshape( + -1, 1 + ) + atten_mask = atten_mask[None, None, :, :].expand(self.batch_size, 1, -1, -1) + return (input_ids, atten_mask, encoder_hidden_states, cache_position) + + def get_metadata(self): + return { + "get_eos_id": getattr(self.config, "eos_token_id", None), + "get_max_context_len": self.max_cache_length, + "decoder_start_token_id": getattr( + self.config, "decoder_start_token_id", None + ), + }