From 73e2f7cdc025acd41654fd5244fed5238089e974 Mon Sep 17 00:00:00 2001
From: shewu-quic
Date: Fri, 8 Nov 2024 12:00:41 +0800
Subject: [PATCH 1/2] Qualcomm AI Engine Direct - Optimize QNN embedding op for llama

summary:
- Change the dtype of the token from int64 to int32. Int32 is QNN HTP
  friendly and significantly speeds up the QNN embedding op because it
  matches the dtype the backend's optimizations expect.
---
 examples/models/llama/export_llama_lib.py   |  9 +++++++++
 examples/models/llama/runner/runner.cpp     |  4 ++++
 extension/llm/runner/text_prefiller.cpp     | 13 ++++++++-----
 extension/llm/runner/text_prefiller.h       |  2 ++
 extension/llm/runner/text_token_generator.h |  9 +++++++--
 5 files changed, 30 insertions(+), 7 deletions(-)

diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py
index 23b3589c2a0..d1933f3434e 100644
--- a/examples/models/llama/export_llama_lib.py
+++ b/examples/models/llama/export_llama_lib.py
@@ -519,6 +519,7 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
         checkpoint=checkpoint_path,
         checkpoint_dir=checkpoint_dir,
         params_path=params_path,
+        use_int32_token=True if args.qnn else False,
         use_kv_cache=args.use_kv_cache,
         use_sdpa_with_kv_cache=args.use_sdpa_with_kv_cache,
         generate_full_logits=args.generate_full_logits,
@@ -746,6 +747,7 @@ def _export_llama(modelname, args) -> LLMEdgeManager:  # noqa: C901
 
 def _load_llama_model_metadata(
     weight_type: WeightType,
+    use_int32_token: bool,
     use_kv_cache: bool,
     use_sdpa_with_kv_cache: bool,
     enable_dynamic_shape: bool,
@@ -759,6 +761,7 @@ def _load_llama_model_metadata(
         "get_max_seq_len": model_args.max_seq_len,
         "get_n_layers": model_args.n_layers,
         "get_vocab_size": model_args.vocab_size,
+        "use_int32_token": use_int32_token,
         "use_kv_cache": use_kv_cache,
         "use_sdpa_with_kv_cache": use_sdpa_with_kv_cache,
         "enable_dynamic_shape": enable_dynamic_shape,
@@ -779,6 +782,7 @@ def _load_llama_model(
     checkpoint: Optional[str] = None,
     checkpoint_dir: Optional[str] = None,
     params_path: str,
+    use_int32_token: bool = False,
     use_kv_cache: bool = False,
     use_sdpa_with_kv_cache: bool = False,
     generate_full_logits: bool = False,
@@ -852,6 +856,10 @@ def _load_llama_model(
     else:
         raise ValueError(f"Unsupported dtype {dtype}")
 
+    if use_int32_token:
+        token = example_inputs[0].to(torch.int32)
+        example_inputs = (token,) + example_inputs[1:]
+
     return LLMEdgeManager(
         model=model,
         modelname=modelname,
@@ -870,6 +878,7 @@ def _load_llama_model(
         verbose=verbose,
         metadata=_load_llama_model_metadata(
             weight_type,
+            use_int32_token,
             use_kv_cache,
             use_sdpa_with_kv_cache,
             enable_dynamic_shape,
diff --git a/examples/models/llama/runner/runner.cpp b/examples/models/llama/runner/runner.cpp
index 42a1a632dc6..191282af418 100644
--- a/examples/models/llama/runner/runner.cpp
+++ b/examples/models/llama/runner/runner.cpp
@@ -34,6 +34,7 @@ static constexpr auto kMaxSeqLen = "get_max_seq_len";
 static constexpr auto kVocabSize = "get_vocab_size";
 static constexpr auto kUseKVCache = "use_kv_cache";
 static constexpr auto kUseSDPAWithKVCache = "use_sdpa_with_kv_cache";
+static constexpr auto kUseInt32Token = "use_int32_token";
 } // namespace
 
 Runner::Runner(
@@ -51,6 +52,7 @@ Runner::Runner(
           {kMaxSeqLen, 128},
           {kUseKVCache, true},
           {kUseSDPAWithKVCache, false},
+          {kUseInt32Token, true},
       }) {
   ET_LOG(
       Info,
@@ -127,12 +129,14 @@ Error Runner::load() {
       temperature_);
   text_prefiller_ = std::make_unique<llm::TextPrefiller>(
       text_decoder_runner_.get(),
+      metadata_.at(kUseInt32Token),
       metadata_.at(kUseKVCache),
       metadata_.at(kEnableDynamicShape));
 
   text_token_generator_ = std::make_unique<llm::TextTokenGenerator>(
       tokenizer_.get(),
       text_decoder_runner_.get(),
+      metadata_.at(kUseInt32Token),
       metadata_.at(kUseKVCache),
       std::move(eos_ids),
       &stats_);
diff --git a/extension/llm/runner/text_prefiller.cpp b/extension/llm/runner/text_prefiller.cpp
index 705583d638b..8a5b6780e5a 100644
--- a/extension/llm/runner/text_prefiller.cpp
+++ b/extension/llm/runner/text_prefiller.cpp
@@ -17,9 +17,11 @@ namespace llm {
 
 TextPrefiller::TextPrefiller(
     TextDecoderRunner* text_decoder_runner,
+    bool use_int32_token,
     bool use_kv_cache,
     bool enable_parallel_prefill)
     : text_decoder_runner_(text_decoder_runner),
+      use_int32_token_(use_int32_token),
       use_kv_cache_(use_kv_cache),
       enable_parallel_prefill_(enable_parallel_prefill) {}
@@ -36,12 +38,13 @@ ::executorch::runtime::Result<uint64_t> TextPrefiller::prefill(
 
   // store the token
   uint64_t cur_token;
+  exec_aten::ScalarType token_type = use_int32_token_
+      ? exec_aten::ScalarType::Int
+      : exec_aten::ScalarType::Long;
   if (enable_parallel_prefill_ || !use_kv_cache_) {
     // initialize tensor wrappers
-    auto tokens = from_blob(
-        prompt_tokens.data(),
-        {1, num_prompt_tokens},
-        exec_aten::ScalarType::Long);
+    auto tokens =
+        from_blob(prompt_tokens.data(), {1, num_prompt_tokens}, token_type);
 
     auto start_pos_tensor =
         from_blob(&start_pos, {1}, exec_aten::ScalarType::Long);
@@ -60,7 +63,7 @@ ::executorch::runtime::Result<uint64_t> TextPrefiller::prefill(
     cur_token = prompt_tokens[0];
 
     // initialize tensor wrappers
-    auto tokens = from_blob(&cur_token, {1, 1}, exec_aten::ScalarType::Long);
+    auto tokens = from_blob(&cur_token, {1, 1}, token_type);
 
     auto start_pos_tensor =
         from_blob(&start_pos, {1}, exec_aten::ScalarType::Long);
diff --git a/extension/llm/runner/text_prefiller.h b/extension/llm/runner/text_prefiller.h
index 9dbaec40e63..a8d763737c8 100644
--- a/extension/llm/runner/text_prefiller.h
+++ b/extension/llm/runner/text_prefiller.h
@@ -24,6 +24,7 @@ class ET_EXPERIMENTAL TextPrefiller {
  public:
   TextPrefiller(
       TextDecoderRunner* text_decoder_runner,
+      bool use_int32_token,
       bool use_kv_cache_,
       bool enable_parallel_prefill);
   /**
@@ -40,6 +41,7 @@ class ET_EXPERIMENTAL TextPrefiller {
 
  private:
   TextDecoderRunner* text_decoder_runner_;
+  bool use_int32_token_;
   bool use_kv_cache_;
   bool enable_parallel_prefill_;
 };
diff --git a/extension/llm/runner/text_token_generator.h b/extension/llm/runner/text_token_generator.h
index 62b924a57d8..94dbd004aa8 100644
--- a/extension/llm/runner/text_token_generator.h
+++ b/extension/llm/runner/text_token_generator.h
@@ -23,12 +23,14 @@ class ET_EXPERIMENTAL TextTokenGenerator {
   TextTokenGenerator(
       Tokenizer* tokenizer,
       TextDecoderRunner* text_decoder_runner,
+      bool use_int32_token,
       bool use_kv_cache,
       std::unique_ptr<std::unordered_set<uint64_t>>&& eos_ids,
       Stats* stats)
       : tokenizer_(tokenizer),
         text_decoder_runner_(text_decoder_runner),
         eos_ids_(std::move(eos_ids)),
+        use_int32_token_(use_int32_token),
         use_kv_cache_(use_kv_cache),
         stats_(stats) {}
 
@@ -54,6 +56,9 @@ class ET_EXPERIMENTAL TextTokenGenerator {
     std::vector<uint64_t> token_data; // allocate space for the tokens
     std::vector<executorch::aten::SizesType> token_shape;
+    exec_aten::ScalarType token_type = use_int32_token_
+        ? exec_aten::ScalarType::Int
+        : exec_aten::ScalarType::Long;
 
     // Token after prefill
     uint64_t cur_token = tokens.back();
 
@@ -70,8 +75,7 @@ class ET_EXPERIMENTAL TextTokenGenerator {
     }
 
     // initialize tensor wrappers
-    auto tokens_managed = from_blob(
-        token_data.data(), token_shape, executorch::aten::ScalarType::Long);
+    auto tokens_managed = from_blob(token_data.data(), token_shape, token_type);
     auto start_pos_managed =
         from_blob(&pos, {1}, executorch::aten::ScalarType::Long);
 
@@ -133,6 +137,7 @@ class ET_EXPERIMENTAL TextTokenGenerator {
   Tokenizer* tokenizer_;
   TextDecoderRunner* text_decoder_runner_;
   std::unique_ptr<std::unordered_set<uint64_t>> eos_ids_;
+  bool use_int32_token_;
   bool use_kv_cache_;
 
   // state machine

From 5aac8b6a72b1dec93e1c026f8aea9fa40f870341 Mon Sep 17 00:00:00 2001
From: shewu-quic
Date: Mon, 11 Nov 2024 09:27:23 +0800
Subject: [PATCH 2/2] set default value for use_int32_token

---
 examples/models/llama/runner/runner.cpp     | 8 ++++----
 extension/llm/runner/text_prefiller.cpp     | 4 ++--
 extension/llm/runner/text_prefiller.h       | 4 ++--
 extension/llm/runner/text_token_generator.h | 4 ++--
 4 files changed, 10 insertions(+), 10 deletions(-)

diff --git a/examples/models/llama/runner/runner.cpp b/examples/models/llama/runner/runner.cpp
index 191282af418..b713ee6d697 100644
--- a/examples/models/llama/runner/runner.cpp
+++ b/examples/models/llama/runner/runner.cpp
@@ -129,17 +129,17 @@ Error Runner::load() {
       temperature_);
   text_prefiller_ = std::make_unique<llm::TextPrefiller>(
       text_decoder_runner_.get(),
-      metadata_.at(kUseInt32Token),
       metadata_.at(kUseKVCache),
-      metadata_.at(kEnableDynamicShape));
+      metadata_.at(kEnableDynamicShape),
+      metadata_.at(kUseInt32Token));
 
   text_token_generator_ = std::make_unique<llm::TextTokenGenerator>(
       tokenizer_.get(),
       text_decoder_runner_.get(),
-      metadata_.at(kUseInt32Token),
       metadata_.at(kUseKVCache),
       std::move(eos_ids),
-      &stats_);
+      &stats_,
+      metadata_.at(kUseInt32Token));
 
   return Error::Ok;
 }
diff --git a/extension/llm/runner/text_prefiller.cpp b/extension/llm/runner/text_prefiller.cpp
index 8a5b6780e5a..219a229f565 100644
--- a/extension/llm/runner/text_prefiller.cpp
+++ b/extension/llm/runner/text_prefiller.cpp
@@ -17,9 +17,9 @@ namespace llm {
 
 TextPrefiller::TextPrefiller(
     TextDecoderRunner* text_decoder_runner,
-    bool use_int32_token,
     bool use_kv_cache,
-    bool enable_parallel_prefill)
+    bool enable_parallel_prefill,
+    bool use_int32_token)
     : text_decoder_runner_(text_decoder_runner),
       use_int32_token_(use_int32_token),
       use_kv_cache_(use_kv_cache),
diff --git a/extension/llm/runner/text_prefiller.h b/extension/llm/runner/text_prefiller.h
index a8d763737c8..872bc3fbfaf 100644
--- a/extension/llm/runner/text_prefiller.h
+++ b/extension/llm/runner/text_prefiller.h
@@ -24,9 +24,9 @@ class ET_EXPERIMENTAL TextPrefiller {
  public:
   TextPrefiller(
       TextDecoderRunner* text_decoder_runner,
-      bool use_int32_token,
       bool use_kv_cache_,
-      bool enable_parallel_prefill);
+      bool enable_parallel_prefill,
+      bool use_int32_token = false);
   /**
    * Prefill an LLM Module with the given text input.
    * @param prompt_tokens The text prompt tokens to the LLM Module. Encoded by
diff --git a/extension/llm/runner/text_token_generator.h b/extension/llm/runner/text_token_generator.h
index 94dbd004aa8..ca21a0c7db5 100644
--- a/extension/llm/runner/text_token_generator.h
+++ b/extension/llm/runner/text_token_generator.h
@@ -23,10 +23,10 @@ class ET_EXPERIMENTAL TextTokenGenerator {
   TextTokenGenerator(
       Tokenizer* tokenizer,
       TextDecoderRunner* text_decoder_runner,
-      bool use_int32_token,
       bool use_kv_cache,
       std::unique_ptr<std::unordered_set<uint64_t>>&& eos_ids,
-      Stats* stats)
+      Stats* stats,
+      bool use_int32_token = false)
       : tokenizer_(tokenizer),
         text_decoder_runner_(text_decoder_runner),
         eos_ids_(std::move(eos_ids)),
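---
Note (not part of the patches): the core of this series is the dtype
switch around from_blob(). Below is a minimal C++ sketch of that pattern
for reference. It reuses executorch::extension::from_blob and
exec_aten::ScalarType exactly as the hunks above do; wrap_prompt_tokens
and its int32 staging buffer are hypothetical additions for illustration.
Unlike the hunks, which pass the existing uint64_t buffer to from_blob
directly, the sketch copies the ids into int32_t storage so the buffer's
element size matches the declared dtype.

// Sketch of the token-dtype selection added to TextPrefiller::prefill()
// and TextTokenGenerator::generate(). wrap_prompt_tokens is hypothetical.
#include <executorch/extension/tensor/tensor.h> // from_blob, TensorPtr
#include <cstdint>
#include <vector>

using executorch::extension::from_blob;
using executorch::extension::TensorPtr;

// Wraps prompt token ids for the decoder. i32_storage is caller-owned
// scratch so the narrowed copy outlives the tensor view created over it.
TensorPtr wrap_prompt_tokens(
    std::vector<uint64_t>& prompt_tokens,
    std::vector<int32_t>& i32_storage,
    bool use_int32_token) {
  const auto n = static_cast<exec_aten::SizesType>(prompt_tokens.size());
  if (use_int32_token) {
    // QNN HTP path: narrow each id to int32 (llama vocab ids fit easily).
    i32_storage.assign(prompt_tokens.begin(), prompt_tokens.end());
    return from_blob(i32_storage.data(), {1, n}, exec_aten::ScalarType::Int);
  }
  // Default path: wrap the int64 ids, as the runner did before this series.
  return from_blob(prompt_tokens.data(), {1, n}, exec_aten::ScalarType::Long);
}

A caller keeps the scratch vector alive in the same scope as the returned
TensorPtr for as long as the tensor is fed to the decoder.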