diff --git a/backends/qualcomm/quantizer/custom_annotation.py b/backends/qualcomm/quantizer/custom_annotation.py
index c468247b98a..5b69ae5ac3c 100644
--- a/backends/qualcomm/quantizer/custom_annotation.py
+++ b/backends/qualcomm/quantizer/custom_annotation.py
@@ -158,7 +158,6 @@ def annotate_prefill_kv_output(gm: torch.fx.GraphModule, kv_quant_attrs: dict):
 def annotate_matmul_16a8w(  # noqa: C901
     gm: torch.fx.GraphModule,
-    annotate_conv=True,
     is_qat=False,
 ) -> None:
     """
@@ -337,10 +336,9 @@ def annotate_matmul_input1(node: Node, is_qat: str):
             # The arguments of cat op: (the past kv cache, the new kv cache)
             node = node.args[0][1]
         elif node.target == torch.ops.aten.conv2d.default:
-            if annotate_conv:
-                annotate_conv2d(
-                    node, quantization_config=quantization_config_8a4w_per_channel
-                )
+            annotate_conv2d(
+                node, quantization_config=quantization_config_8a4w_per_channel
+            )
             break
         elif node.target in [torch.ops.aten.add.Tensor, torch.ops.aten.sub.Tensor]:
             break
diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py
index a4b0841ac3d..f7ded652799 100644
--- a/backends/qualcomm/tests/test_qnn_delegate.py
+++ b/backends/qualcomm/tests/test_qnn_delegate.py
@@ -4560,6 +4560,8 @@ def test_static_qwen2_5(self):
             "wikitext",
             "--limit",
             "1",
+            "--r3",
+            "--enable_masked_softmax",
         ]
         if self.compile_only:
             cmds.extend(["--compile_only"])
@@ -4581,7 +4583,7 @@ def test_static_qwen2_5(self):
                 self.fail(msg["Error"])
             else:
                 inference_speed_ref = {"SM8650": 110, "SM8750": 130}
-                self.assertLessEqual(msg["wiki_ppl"], 25)
+                self.assertLessEqual(msg["wiki_ppl"], 15)
                 self.assertLessEqual(msg["pte_size"], 800000000)  # 800mb
                 if self.model in inference_speed_ref:
                     self.assertGreaterEqual(
diff --git a/examples/qualcomm/oss_scripts/llama/README.md b/examples/qualcomm/oss_scripts/llama/README.md
index fea550bb51b..a45c0756f1b 100644
--- a/examples/qualcomm/oss_scripts/llama/README.md
+++ b/examples/qualcomm/oss_scripts/llama/README.md
@@ -59,13 +59,19 @@ At the end of this step, users should have the following files ready: `consolida
 ### Step 3: Run default examples using hybrid mode.
 #### LLAMA2
 ```bash
-python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint stories110M.pt --params params.json --tokenizer_model tokenizer.model --tokenizer_bin tokenizer.bin --llama_model stories110m --model_mode hybrid --prefill_ar_len 32 --max_seq_len 128 --prompt "Once upon a time"
+python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint stories110M.pt --params params.json --tokenizer_model tokenizer.model --tokenizer_bin tokenizer.bin --decoder_model stories110m --model_mode hybrid --prefill_ar_len 32 --max_seq_len 128 --prompt "Once upon a time"
 ```
 
 #### LLAMA3.2
 Default example using hybrid mode.
 ```bash
-python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --llama_model llama3_2 --model_mode hybrid --prefill_ar_len 32 --max_seq_len 128 --prompt "what is 1+1"
+python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --decoder_model llama3_2 --model_mode hybrid --prefill_ar_len 32 --max_seq_len 128 --prompt "what is 1+1"
+```
+
+#### QWEN2.5 0.5B
+Default example using hybrid mode.
+```bash
+python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --temperature 0 --model_mode hybrid --max_seq_len 1024 --prefill_ar_len 128 --ptq 16a8w --enable_masked_softmax --r3 --decoder_model qwen2_5 --prompt "I would like to learn python, could you teach me with a simple example?"
 ```
 
 ### KV Cache update mechanism
@@ -120,13 +126,13 @@ We have two distinct mechanisms for updating the key-value (KV) cache, which can
 #### Compile Only
 If you would like to compile the model only, we have provided the flag `--compile_only`. Taking LLAMA3.2 as an example:
 ```bash
-python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --llama_model llama3_2 --model_mode hybrid --prefill_ar_len 32 --max_seq_len 128 --prompt "what is 1+1" --compile_only
+python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --decoder_model llama3_2 --model_mode hybrid --prefill_ar_len 32 --max_seq_len 128 --prompt "what is 1+1" --compile_only
 ```
 
 #### Pre Generated PTE
 On the other hand, if you already have a pre-compiled .pte model, you can perform inference by providing the flag `--pre_gen_pte` and specifying the folder that contains the .pte model. Taking LLAMA3.2 as an example:
 ```bash
-python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --llama_model llama3_2 --model_mode hybrid --prefill_ar_len 32 --max_seq_len 128 --prompt "what is 1+1" --pre_gen_pte ${FOLDER_TO_PRE_GEN_PTE}
+python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --decoder_model llama3_2 --model_mode hybrid --prefill_ar_len 32 --max_seq_len 128 --prompt "what is 1+1" --pre_gen_pte ${FOLDER_TO_PRE_GEN_PTE}
 ```
 
 #### KV Cache Updater
@@ -134,7 +140,7 @@ python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL
 You can select the KV Cache update mechanism at runtime by setting the `KV_UPDATER` variable to either "shift_pointer" or "smart_mask". By default, it is set to "smart_mask".
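+In short: Smart Mask keeps every cache buffer at a fixed address and copies each newly produced entry into place, relying on the attention mask to expose the new positions, while Shift Pointer avoids most of that copying by sliding the cache input pointer so that past and new entries stay contiguous. A minimal sketch of the contrast (simplified names and a plain `uint8_t` cache; this is an illustration, not the actual `KVManager` API):
+```cpp
+#include <cstdint>
+#include <cstring>
+
+struct CacheView {
+  uint8_t* in; // region the model reads
+  uint8_t* out; // entries produced by the last forward pass
+};
+
+// Smart Mask: fixed base address; append new entries after the past ones and
+// let the attention mask make them visible on the next step.
+void smart_mask_update(CacheView& c, int n_past, int n_update) {
+  std::memcpy(c.in + n_past, c.out, n_update);
+}
+
+// Shift Pointer: no copy; slide the read pointer forward so the window the
+// model sees already ends at the freshly written entries.
+void shift_pointer_update(CacheView& c, int n_update) {
+  c.in += n_update;
+}
+```
+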
 `KV_UPDATER` = "shift_pointer"
 ```bash
-python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --llama_model llama3_2 --model_mode hybrid --prefill_ar_len 32 --max_seq_len 128 --prompt "what is 1+1" --kv_updator ${KV_UPDATER}
+python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --decoder_model llama3_2 --model_mode hybrid --prefill_ar_len 32 --max_seq_len 128 --prompt "what is 1+1" --kv_updator ${KV_UPDATER}
 ```
 
 #### Lookahead Decoding Mode
@@ -147,7 +153,7 @@ You can choose the lookahead mode to enhance decoding speed. To use this mode, y
 For more details, please refer to the paper ["Break the Sequential Dependency of LLM Inference Using Lookahead Decoding"](https://arxiv.org/abs/2402.02057)
 
 ```bash
-python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --llama_model llama3_2 --model_mode lookahead --prefill_ar_len 32 --max_seq_len 128 --prompt "what is 1+1" --ngram 3 --window 2 --gcap 2
+python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --decoder_model llama3_2 --model_mode lookahead --prefill_ar_len 32 --max_seq_len 128 --prompt "what is 1+1" --ngram 3 --window 2 --gcap 2
 ```
 
 #### Masked Softmax
diff --git a/examples/qualcomm/oss_scripts/llama/llama.py b/examples/qualcomm/oss_scripts/llama/llama.py
index e36b3442100..3988ea33c4e 100755
--- a/examples/qualcomm/oss_scripts/llama/llama.py
+++ b/examples/qualcomm/oss_scripts/llama/llama.py
@@ -264,8 +264,8 @@ def quantize(
 
         self.llama_graph_module = convert_pt2e(fx_graph_module)
 
-        logging.info("Verifying the QDQ model...")
        if args.eval_perplexity:
+            logging.info("Verifying the QDQ model...")
             # Check qdq cpu results
             graph_module_inference(
                 args=args,
@@ -362,6 +362,7 @@ def compile(args, pte_filename, tokenizer):
     kv_config.use_kv_cache = True
     kv_config.enable_masked_softmax = args.enable_masked_softmax
     kv_config.enable_r3 = args.r3
+    kv_config.kv_io_bit_width = 16 if args.ptq == "16a8w" else 8
 
     prefill_config = copy.copy(kv_config)
     prefill_config.use_kv_cache = (
@@ -535,11 +536,15 @@ def permute(w, heads):
     fixed_point_type = {"kv_type": torch.float32, "io_type": torch.float32}
     if args.ptq:
         use_fp16 = False
-        fixed_point_type["kv_type"] = torch.uint8
         if args.ptq == "8a8w":
             fixed_point_type["io_type"] = torch.uint8
+            fixed_point_type["kv_type"] = torch.uint8
-        elif args.ptq in ("16a4w", "16a4w_block", "16a8w"):
+        elif args.ptq in ("16a4w", "16a4w_block"):
             fixed_point_type["io_type"] = torch.uint16
+            fixed_point_type["kv_type"] = torch.uint8
+        elif args.ptq == "16a8w":
+            fixed_point_type["io_type"] = torch.uint16
+            fixed_point_type["kv_type"] = torch.uint16
         else:
             assert args.ptq in [
                 "8a8w",
@@ -572,13 +577,10 @@ def permute(w, heads):
 
     if args.ptq:
         start_quantize_ts = time.time()
-        custom_annotations = (
-            # For qwen2.5, skip annotate_conv can improve result.
-            partial(
-                annotate_matmul_16a8w,
-                annotate_conv=args.ptq != "16a8w",
-            ),
-        )
+        custom_annotations = ()
+        if args.ptq != "16a8w":
+            # 16a8w uses 16-bit kv io, so skip this custom annotation.
+            custom_annotations = custom_annotations + (annotate_matmul_16a8w,)
         if args.decoder_model in {"stories110m", "stories260k"}:
             custom_annotations = custom_annotations + (
                 annotate_linear_16a8w_in_affine_layer,
diff --git a/examples/qualcomm/oss_scripts/llama/model/static_llama.py b/examples/qualcomm/oss_scripts/llama/model/static_llama.py
index d1063d053b4..83b2777d14c 100755
--- a/examples/qualcomm/oss_scripts/llama/model/static_llama.py
+++ b/examples/qualcomm/oss_scripts/llama/model/static_llama.py
@@ -444,6 +444,7 @@ def __init__(
         self.output_new_cache_only = output_new_cache_only
         self.use_i64_token = use_i64_token
         self.output_cache = output_cache
+        self.kv_io_bit_width = config.kv_io_bit_width
 
         self.layers = nn.ModuleList(
             [
@@ -607,4 +608,5 @@ def get_metadata(self):
             "get_n_layers": self.n_layers,
             "get_vocab_size": self.vocab_size,
             "get_use_kv_cache": self.use_kv_cache,
+            "get_kv_io_bit_width": self.kv_io_bit_width,
         }
diff --git a/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp b/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp
index 78e6a0a4245..6afeca0ca95 100644
--- a/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp
+++ b/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp
@@ -133,24 +133,16 @@ std::string get_formatted_prompt(
   return formatted_prompt;
 }
 
-int main(int argc, char** argv) {
-  std::vector<std::string> prompts = CollectPrompts(argc, argv);
-  gflags::ParseCommandLineFlags(&argc, &argv, true);
-  if (!gflags::GetCommandLineFlagInfoOrDie("prompt").is_default &&
-      !gflags::GetCommandLineFlagInfoOrDie("tokenized_prompt").is_default) {
-    ET_CHECK_MSG(false, "Only provide prompt or tokenized_input but not both.");
-  }
-  if (!gflags::GetCommandLineFlagInfoOrDie("dump_logits_path").is_default &&
-      FLAGS_eval_mode != 0) {
-    ET_CHECK_MSG(
-        false, "Only TokenGenerator(kv) mode is supported to dump all logits.");
-  }
-
+template <typename T>
+void start_runner(
+    std::unique_ptr<executorch::extension::Module> module,
+    std::vector<std::string>& prompts) {
   bool use_tokenized_prompt =
       gflags::GetCommandLineFlagInfoOrDie("tokenized_prompt").is_default
       ? false
       : true;
   // create llama runner
-  example::Runner runner(
+  example::Runner<T> runner(
+      std::move(module),
       FLAGS_decoder_model_version.c_str(),
       FLAGS_model_path.c_str(),
       FLAGS_tokenizer_path.c_str(),
@@ -196,5 +188,43 @@ int main(int argc, char** argv) {
 
   fout.write(buf.data(), buf.size());
   fout.close();
+}
+
+int main(int argc, char** argv) {
+  std::vector<std::string> prompts = CollectPrompts(argc, argv);
+  gflags::ParseCommandLineFlags(&argc, &argv, true);
+  if (!gflags::GetCommandLineFlagInfoOrDie("prompt").is_default &&
+      !gflags::GetCommandLineFlagInfoOrDie("tokenized_prompt").is_default) {
+    ET_CHECK_MSG(false, "Only provide prompt or tokenized_input but not both.");
+  }
+  if (!gflags::GetCommandLineFlagInfoOrDie("dump_logits_path").is_default &&
+      FLAGS_eval_mode != 0) {
+    ET_CHECK_MSG(
+        false, "Only TokenGenerator(kv) mode is supported to dump all logits.");
+  }
+
+  std::unique_ptr<executorch::extension::Module> module =
+      std::make_unique<executorch::extension::Module>(
+          FLAGS_model_path.c_str(),
+          executorch::extension::Module::LoadMode::MmapUseMlockIgnoreErrors);
+  // Use 8-bit as the default, since this metadata was introduced together
+  // with 16-bit kv io support and older models only have 8-bit kv io.
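+  // The "get_kv_io_bit_width" method itself is exported through
+  // get_metadata() in static_llama.py, so it can be read here before
+  // choosing which Runner<T> instantiation to start.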
+  example::KvBitWidth kv_bitwidth = example::KvBitWidth::kWidth8;
+  if (module->method_names()->count("get_kv_io_bit_width") > 0) {
+    kv_bitwidth = static_cast<example::KvBitWidth>(
+        module->get("get_kv_io_bit_width").get().toScalar().to<int64_t>());
+  }
+
+  if (kv_bitwidth == example::KvBitWidth::kWidth8) {
+    start_runner<uint8_t>(std::move(module), prompts);
+  } else if (kv_bitwidth == example::KvBitWidth::kWidth16) {
+    start_runner<uint16_t>(std::move(module), prompts);
+  } else {
+    ET_CHECK_MSG(
+        false,
+        "Unsupported kv bitwidth: %ld",
+        static_cast<int64_t>(kv_bitwidth));
+  }
+
   return 0;
 }
diff --git a/examples/qualcomm/oss_scripts/llama/runner/kv_manager.cpp b/examples/qualcomm/oss_scripts/llama/runner/kv_manager.cpp
index b563049eb8d..9ce1abafa04 100644
--- a/examples/qualcomm/oss_scripts/llama/runner/kv_manager.cpp
+++ b/examples/qualcomm/oss_scripts/llama/runner/kv_manager.cpp
@@ -9,34 +9,35 @@
 #include
 #include
 namespace example {
-KVManager::KVManager(KVManagerMode kv_updater, Metadata metadata)
+template <typename T>
+KVManager<T>::KVManager(KVManagerMode kv_updater, Metadata metadata)
     : kv_updater_(kv_updater), metadata_(metadata) {
   k_cache_.resize(
-      metadata_.num_layers, std::vector<KVCache>(metadata_.num_heads));
+      metadata_.num_layers, std::vector<KVCache<T>>(metadata_.num_heads));
   v_cache_.resize(
-      metadata_.num_layers, std::vector<KVCache>(metadata_.num_heads));
+      metadata_.num_layers, std::vector<KVCache<T>>(metadata_.num_heads));
 
   // Calculate cache size
   switch (kv_updater_) {
     case KVManagerMode::SMART_MASK: {
       size_t cache_in_bytes = metadata_.num_layers * metadata_.num_heads *
-          metadata_.head_dim * metadata_.max_cache_len * sizeof(uint8_t);
+          metadata_.head_dim * metadata_.max_cache_len * sizeof(T);
       size_t cache_out_bytes = metadata_.num_layers * metadata_.num_heads *
-          metadata_.head_dim * metadata_.max_ar_len * sizeof(uint8_t);
+          metadata_.head_dim * metadata_.max_ar_len * sizeof(T);
       total_cache_size_ = 2 * (cache_in_bytes + cache_out_bytes);
       break;
     }
     case KVManagerMode::SHIFT_POINTER: {
       size_t k_cache_in_bytes = metadata_.num_layers * metadata_.num_heads *
-          (metadata_.head_dim + 1) * metadata_.max_cache_len * sizeof(uint8_t);
+          (metadata_.head_dim + 1) * metadata_.max_cache_len * sizeof(T);
       size_t k_cache_out_bytes = metadata_.num_layers * metadata_.num_heads *
-          metadata_.head_dim * metadata_.max_ar_len * sizeof(uint8_t);
+          metadata_.head_dim * metadata_.max_ar_len * sizeof(T);
       // Use the same memory for input and output of value cache in shift
       // pointer mode. Note that using context length to prevent exceeding the
       // range when the AR-N model updates the last block in shift pointer
       // mode.
       size_t v_cache_bytes = metadata_.num_layers * (metadata_.num_heads + 1) *
-          metadata_.head_dim * metadata_.context_len * sizeof(uint8_t);
+          metadata_.head_dim * metadata_.context_len * sizeof(T);
       total_cache_size_ = k_cache_in_bytes + k_cache_out_bytes + v_cache_bytes;
       break;
     }
@@ -45,7 +46,8 @@ KVManager::KVManager(KVManagerMode kv_updater, Metadata metadata)
   }
 };
 
-void KVManager::init_attention_mask(
+template <typename T>
+void KVManager<T>::init_attention_mask(
     uint16_t* attention_mask,
     const std::vector<int32_t>& attention_map,
     int32_t ar_len,
@@ -114,7 +116,8 @@ void KVManager::init_attention_mask(
   }
 }
 
-void KVManager::update_attention_mask(
+template <typename T>
+void KVManager<T>::update_attention_mask(
     uint16_t* attention_mask,
     int32_t ar_len,
     int32_t n_past,
@@ -132,12 +135,12 @@ void KVManager::update_attention_mask(
   }
 }
 
-void KVManager::init_cache(IMemAlloc* buffer_manager, int32_t ar_len) {
+template <typename T>
+void KVManager<T>::init_cache(IMemAlloc* buffer_manager, int32_t ar_len) {
   cur_ar_len_ = ar_len;
   const size_t max_in_cache_block_in_bytes =
-      metadata_.max_cache_len * sizeof(uint8_t);
-  const size_t max_out_cache_block_in_bytes =
-      metadata_.max_ar_len * sizeof(uint8_t);
+      metadata_.max_cache_len * sizeof(T);
+  const size_t max_out_cache_block_in_bytes = metadata_.max_ar_len * sizeof(T);
 
   switch (kv_updater_) {
     case KVManagerMode::SMART_MASK: {
@@ -148,14 +151,14 @@ void KVManager::init_cache(IMemAlloc* buffer_manager, int32_t ar_len) {
       for (int layer = 0; layer < metadata_.num_layers; ++layer) {
         for (int head = 0; head < metadata_.num_heads; ++head) {
           // Allocate buffer for key cache and value cache
-          uint8_t* single_layer_k_cache_in = reinterpret_cast<uint8_t*>(
-              buffer_manager->allocate(cache_in_bytes));
-          uint8_t* single_layer_k_cache_out = reinterpret_cast<uint8_t*>(
-              buffer_manager->allocate(cache_out_bytes));
-          uint8_t* single_layer_v_cache_in = reinterpret_cast<uint8_t*>(
-              buffer_manager->allocate(cache_in_bytes));
-          uint8_t* single_layer_v_cache_out = reinterpret_cast<uint8_t*>(
-              buffer_manager->allocate(cache_out_bytes));
+          T* single_layer_k_cache_in =
+              reinterpret_cast<T*>(buffer_manager->allocate(cache_in_bytes));
+          T* single_layer_k_cache_out =
+              reinterpret_cast<T*>(buffer_manager->allocate(cache_out_bytes));
+          T* single_layer_v_cache_in =
+              reinterpret_cast<T*>(buffer_manager->allocate(cache_in_bytes));
+          T* single_layer_v_cache_out =
+              reinterpret_cast<T*>(buffer_manager->allocate(cache_out_bytes));
 
           k_cache_[layer][head].buffer = single_layer_k_cache_in;
           k_cache_[layer][head].output_buffer = single_layer_k_cache_out;
@@ -171,20 +174,20 @@ void KVManager::init_cache(IMemAlloc* buffer_manager, int32_t ar_len) {
       const size_t k_cache_out_size_in_bytes = metadata_.num_heads *
          metadata_.head_dim * max_out_cache_block_in_bytes;
       const size_t v_cache_size_in_bytes = (metadata_.num_heads + 1) *
-          metadata_.head_dim * metadata_.context_len * sizeof(uint8_t);
+          metadata_.head_dim * metadata_.context_len * sizeof(T);
       const int32_t single_head_size_in =
          metadata_.head_dim * metadata_.max_cache_len;
       const int32_t single_head_size_out =
          metadata_.head_dim * metadata_.max_ar_len;
       for (int layer = 0; layer < metadata_.num_layers; ++layer) {
         // Allocate buffer for key cache and value cache
-        uint8_t* single_layer_k_cache_in = reinterpret_cast<uint8_t*>(
+        T* single_layer_k_cache_in = reinterpret_cast<T*>(
             buffer_manager->allocate(k_cache_in_size_in_bytes));
-        uint8_t* single_layer_k_cache_out = reinterpret_cast<uint8_t*>(
+        T* single_layer_k_cache_out = reinterpret_cast<T*>(
            buffer_manager->allocate(k_cache_out_size_in_bytes));
         // Note that using context length to prevent exceeding the range when
        // the AR-N model updates the last block in shift pointer mode.
-        uint8_t* single_layer_v_cache = reinterpret_cast<uint8_t*>(
+        T* single_layer_v_cache = reinterpret_cast<T*>(
             buffer_manager->allocate(v_cache_size_in_bytes));
 
         for (int head = 0; head < metadata_.num_heads; ++head) {
           k_cache_[layer][head].buffer = single_layer_k_cache_in +
@@ -211,7 +214,8 @@ void KVManager::init_cache(IMemAlloc* buffer_manager, int32_t ar_len) {
   }
 }
 
-void KVManager::rearrange_cache(int32_t ar_len_dst) {
+template <typename T>
+void KVManager<T>::rearrange_cache(int32_t ar_len_dst) {
   // Don't need to rearrange if cur_ar_len_ is equal to target ar_len
   if (cur_ar_len_ == ar_len_dst)
     return;
@@ -225,15 +229,16 @@ void KVManager::rearrange_cache(int32_t ar_len_dst) {
   cur_ar_len_ = ar_len_dst;
 }
 
-void KVManager::rearrange_key(KVCache& k_cache, int32_t ar_len_dst) {
+template <typename T>
+void KVManager<T>::rearrange_key(KVCache<T>& k_cache, int32_t ar_len_dst) {
   // The output of key cache doesn't need to rearrange for both of SMART_MASK
   // and SHIFT_POINTER
   const int32_t src_cache_num = (cur_ar_len_ == metadata_.context_len)
      ? metadata_.context_len
      : metadata_.context_len - cur_ar_len_;
   const int32_t dst_cache_num = metadata_.context_len - ar_len_dst;
-  uint8_t* k_cache_in_read_ptr = k_cache.buffer;
-  uint8_t* k_cache_in_write_ptr = k_cache.buffer;
+  T* k_cache_in_read_ptr = k_cache.buffer;
+  T* k_cache_in_write_ptr = k_cache.buffer;
 
   if (src_cache_num > dst_cache_num) {
     if (kv_updater_ == KVManagerMode::SHIFT_POINTER) {
@@ -263,7 +268,8 @@ void KVManager::rearrange_key(KVCache& k_cache, int32_t ar_len_dst) {
   }
 }
 
-void KVManager::rearrange_value(KVCache& v_cache, int32_t ar_len_dst) {
+template <typename T>
+void KVManager<T>::rearrange_value(KVCache<T>& v_cache, int32_t ar_len_dst) {
   // The input and output of the value cache don't need to rearrange for both
   // SMART_MASK and SHIFT_POINTER. However, the input pointer of the value cache
   // needs to be reset by ar_len_dst in SHIFT_POINTER mode. The output pointer
@@ -276,7 +282,8 @@ void KVManager::rearrange_value(KVCache& v_cache, int32_t ar_len_dst) {
   }
 }
 
-bool KVManager::update_cache_tensor(
+template <typename T>
+bool KVManager<T>::update_cache_tensor(
     std::vector<std::vector<std::unique_ptr<executorch::aten::TensorImpl>>>&
         k_cache_in,
     std::vector<std::vector<std::unique_ptr<executorch::aten::TensorImpl>>>&
@@ -313,7 +320,8 @@ bool KVManager::update_cache_tensor(
   return updated;
 }
 
-void KVManager::update_cache(
+template <typename T>
+void KVManager<T>::update_cache(
     int32_t ar_len,
     int32_t n_past,
     int32_t n_update,
@@ -331,14 +339,15 @@ void KVManager::update_cache(
   }
 }
 
-void KVManager::update_key(
-    KVCache& k_cache,
+template <typename T>
+void KVManager<T>::update_key(
+    KVCache<T>& k_cache,
     int32_t n_past,
     int32_t n_update,
     const std::vector<bool>& selected) {
-  uint8_t* write_ptr = k_cache.buffer;
-  uint8_t* read_ptr = k_cache.output_buffer;
-  const int32_t copy_size = n_update * sizeof(uint8_t);
+  T* write_ptr = k_cache.buffer;
+  T* read_ptr = k_cache.output_buffer;
+  const int32_t copy_size = n_update * sizeof(T);
   const int32_t iter_size = (cur_ar_len_ == metadata_.context_len)
      ? metadata_.context_len
      : metadata_.context_len - cur_ar_len_;
@@ -374,14 +383,15 @@ void KVManager::update_key(
   }
 }
 
-void KVManager::update_value(
-    KVCache& v_cache,
+template <typename T>
+void KVManager<T>::update_value(
+    KVCache<T>& v_cache,
     int32_t n_past,
     int32_t n_update,
     const std::vector<bool>& selected) {
-  uint8_t* write_ptr = v_cache.buffer;
-  uint8_t* read_ptr = v_cache.output_buffer;
-  const int32_t copy_size = n_update * metadata_.head_dim * sizeof(uint8_t);
+  T* write_ptr = v_cache.buffer;
+  T* read_ptr = v_cache.output_buffer;
+  const int32_t copy_size = n_update * metadata_.head_dim * sizeof(T);
   const int32_t past_size = n_past * metadata_.head_dim;
 
   if (kv_updater_ == KVManagerMode::SMART_MASK)
@@ -403,7 +413,7 @@ void KVManager::update_value(
     auto wp = write_ptr, rp = read_ptr;
     for (auto sel : selected) {
       if (sel) {
-        std::memcpy(wp, rp, metadata_.head_dim * sizeof(uint8_t));
+        std::memcpy(wp, rp, metadata_.head_dim * sizeof(T));
         wp += metadata_.head_dim;
         update_times--;
         if (update_times == 0)
@@ -414,4 +424,8 @@ void KVManager::update_value(
   }
 }
 
+// Explicit instantiations
+template class KVManager<uint8_t>;
+template class KVManager<uint16_t>;
+
 } // namespace example
diff --git a/examples/qualcomm/oss_scripts/llama/runner/kv_manager.h b/examples/qualcomm/oss_scripts/llama/runner/kv_manager.h
index e1a756d1215..c20a5a1ab60 100644
--- a/examples/qualcomm/oss_scripts/llama/runner/kv_manager.h
+++ b/examples/qualcomm/oss_scripts/llama/runner/kv_manager.h
@@ -15,9 +15,10 @@ namespace example {
 
 // Structure to hold key-value cache buffers
+template <typename T>
 struct KVCache {
-  uint8_t* buffer;
-  uint8_t* output_buffer;
+  T* buffer;
+  T* output_buffer;
 };
 
 // Enumeration for key-value manager modes
@@ -26,6 +27,7 @@ enum KVManagerMode { SMART_MASK = 0x0, SHIFT_POINTER = 0x1 };
  * @class KVManager
 * @brief Class for kv cache update, rearrangement, and buffer allocation.
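+ * @tparam T Element type of the cache buffers: uint8_t for 8-bit kv io and
+ *     uint16_t for 16-bit kv io, matching the explicit instantiations at the
+ *     end of kv_manager.cpp.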
 */
+template <typename T>
 class KVManager {
  public:
  struct Metadata {
@@ -128,10 +130,10 @@ class KVManager {
       int32_t n_update,
       const std::vector<bool>& selected);
 
-  const std::vector<std::vector<KVCache>>& get_k_cache_() const {
+  const std::vector<std::vector<KVCache<T>>>& get_k_cache_() const {
     return k_cache_;
   }
-  const std::vector<std::vector<KVCache>>& get_v_cache_() const {
+  const std::vector<std::vector<KVCache<T>>>& get_v_cache_() const {
     return v_cache_;
   }
 
@@ -141,15 +143,15 @@ class KVManager {
 
  private:
  // Helper functions to rearrange and update key and value caches
-  void rearrange_key(KVCache& k_cache, int32_t ar_len_dst);
-  void rearrange_value(KVCache& v_cache, int32_t ar_len_dst);
+  void rearrange_key(KVCache<T>& k_cache, int32_t ar_len_dst);
+  void rearrange_value(KVCache<T>& v_cache, int32_t ar_len_dst);
   void update_key(
-      KVCache& k_cache,
+      KVCache<T>& k_cache,
       int32_t n_past,
       int32_t n_update,
       const std::vector<bool>& selected);
   void update_value(
-      KVCache& v_cache,
+      KVCache<T>& v_cache,
       int32_t n_past,
       int32_t n_update,
       const std::vector<bool>& selected);
@@ -162,7 +164,7 @@ class KVManager {
   // Store start pointer of k and v cache for input and output
   // input: layer -> head -> head_dim * max_cache_len
   // output: layer -> head -> head_dim * max_ar_len
-  std::vector<std::vector<KVCache>> k_cache_;
-  std::vector<std::vector<KVCache>> v_cache_;
+  std::vector<std::vector<KVCache<T>>> k_cache_;
+  std::vector<std::vector<KVCache<T>>> v_cache_;
 };
 } // namespace example
diff --git a/examples/qualcomm/oss_scripts/llama/runner/lhd_token_generator.cpp b/examples/qualcomm/oss_scripts/llama/runner/lhd_token_generator.cpp
index 9b5030c461c..1692caa2756 100644
--- a/examples/qualcomm/oss_scripts/llama/runner/lhd_token_generator.cpp
+++ b/examples/qualcomm/oss_scripts/llama/runner/lhd_token_generator.cpp
@@ -13,28 +13,31 @@
 using executorch::runtime::Result;
 
 namespace example {
-void LhdTokenGenerator::prepare_io(
+template <typename T>
+void LhdTokenGenerator<T>::prepare_io(
     std::vector<uint64_t> input_tokens,
     std::vector<int32_t> input_pos) {
   for (int i = 0; i < metadata_.ar_len; i++) {
     if (i < input_tokens.size()) {
       // Prepare pos data
-      input_pos_.data[i] = input_pos[i];
+      this->input_pos_.data[i] = input_pos[i];
 
       // Support CPU 4-bit embedding, which requires int64 input.
       // However, for QNN embedding, only int32 input is needed.
       // Therefore, we need to cast to the correct type to write the data.
       if (metadata_.use_int64_token) {
-        input_toks_.data[i] = input_tokens[i];
+        this->input_toks_.data[i] = input_tokens[i];
       } else {
-        int32_t* input_toks_ptr = reinterpret_cast<int32_t*>(input_toks_.data);
+        int32_t* input_toks_ptr =
+            reinterpret_cast<int32_t*>(this->input_toks_.data);
         input_toks_ptr[i] = static_cast<int32_t>(input_tokens[i]);
       }
     }
   }
 }
 
-void LhdTokenGenerator::init_attention_mask(int32_t n_past) {
+template <typename T>
+void LhdTokenGenerator<T>::init_attention_mask(int32_t n_past) {
   std::vector<int32_t> attention_map;
   attention_map.reserve(metadata_.ar_len);
   // Initialize attention mask with current position
@@ -56,11 +59,12 @@ void LhdTokenGenerator::init_attention_mask(int32_t n_past) {
     }
   }
 
-  kv_manager_->init_attention_mask(
-      attention_mask_.data, attention_map, metadata_.ar_len, n_past);
+  this->kv_manager_->init_attention_mask(
+      this->attention_mask_.data, attention_map, metadata_.ar_len, n_past);
 }
 
-void LhdTokenGenerator::init_lookahead_branch(
+template <typename T>
+void LhdTokenGenerator<T>::init_lookahead_branch(
     const std::vector<uint64_t>& tokens) {
   for (int i = 0; i < metadata_.ngram - 1; ++i) {
     for (int j = 0; j < metadata_.window; ++j) {
@@ -77,7 +81,8 @@ void LhdTokenGenerator::init_lookahead_branch(
   is_lhd_branch_initialized_ = true;
 }
 
-void LhdTokenGenerator::init_verification_branch(uint64_t cur_token) {
+template <typename T>
+void LhdTokenGenerator<T>::init_verification_branch(uint64_t cur_token) {
   const int g_cur = ngrams_pool_.cnt[cur_token];
 
   v_branch_.resize(g_cur);
@@ -101,7 +106,8 @@ void LhdTokenGenerator::init_verification_branch(uint64_t cur_token) {
   }
 }
 
-void LhdTokenGenerator::update_ngrams_pool() {
+template <typename T>
+void LhdTokenGenerator<T>::update_ngrams_pool() {
   std::vector<uint64_t> ngram(metadata_.ngram - 1);
   // n-gram pool generation
   for (int f = 0; f < metadata_.window; ++f) {
@@ -154,7 +160,8 @@ void LhdTokenGenerator::update_ngrams_pool() {
   }
 }
 
-void LhdTokenGenerator::update_lookahead_branch(
+template <typename T>
+void LhdTokenGenerator<T>::update_lookahead_branch(
     const executorch::aten::Tensor& logits_tensor) {
   for (int i = 0; i < metadata_.window; i++) {
     lhd_branch_prev_[i] = lhd_branch_[0][i];
@@ -168,11 +175,12 @@ void LhdTokenGenerator::update_lookahead_branch(
   for (int i = 0; i < metadata_.window; i++) {
     size_t sample_idx = (metadata_.ngram - 2) * metadata_.window + i;
     lhd_branch_[metadata_.ngram - 2][i] =
-        decoder_runner_->logits_to_token(logits_tensor, sample_idx);
+        this->decoder_runner_->logits_to_token(logits_tensor, sample_idx);
   }
 }
 
-Result<int64_t> LhdTokenGenerator::generate(
+template <typename T>
+Result<int64_t> LhdTokenGenerator<T>::generate(
     std::vector<uint64_t> tokens,
     int64_t start_pos,
     int32_t seq_len,
@@ -197,7 +205,7 @@ Result<int64_t> LhdTokenGenerator::generate(
   input_pos.reserve(metadata_.ar_len);
 
   // Rearrange KV cache first and initialize the input and output of KV cache
-  kv_manager_->rearrange_cache(metadata_.ar_len);
+  this->kv_manager_->rearrange_cache(metadata_.ar_len);
 
   // Initialize attention mask with pos
   init_attention_mask(pos);
@@ -210,10 +218,11 @@ Result<int64_t> LhdTokenGenerator::generate(
 
   // Initialize the output of the module
   ET_CHECK_MSG(
-      decoder_runner_->set_outputs(method_name_, output_tensors_) ==
+      this->decoder_runner_->set_outputs(
+          this->method_name_, this->output_tensors_) ==
           executorch::runtime::Error::Ok,
       "Failed to set output tensor for module %s",
-      method_name_.c_str());
+      this->method_name_.c_str());
 
   // Generate tokens
   while (pos < seq_len - 1) {
@@ -252,25 +261,27 @@ Result<int64_t> LhdTokenGenerator::generate(
       prepare_io(input_tokens, input_pos);
     // Only update data pointer of the cache to the tensor for SHIFT_POINTER
     // mode
-    bool updated = kv_manager_->update_cache_tensor(
-        k_cache_in_,
-        k_cache_out_,
-        v_cache_in_,
-        v_cache_out_,
+    bool updated = this->kv_manager_->update_cache_tensor(
+        this->k_cache_in_,
+        this->k_cache_out_,
+        this->v_cache_in_,
+        this->v_cache_out_,
         metadata_.ar_len,
         pos);
     // Only update the output of module for SHIFT_POINTER mode
     if (updated) {
       // Update the output of the module
       ET_CHECK_MSG(
-          decoder_runner_->set_outputs(method_name_, output_tensors_) ==
+          this->decoder_runner_->set_outputs(
+              this->method_name_, this->output_tensors_) ==
              executorch::runtime::Error::Ok,
          "Failed to set output tensor for module %s",
-          method_name_.c_str());
+          this->method_name_.c_str());
     }
 
     // Run inference
-    auto logits_res = decoder_runner_->step(method_name_, inputs_);
+    auto logits_res =
+        this->decoder_runner_->step(this->method_name_, this->inputs_);
     ET_CHECK_OK_OR_RETURN_ERROR(logits_res.error());
     executorch::aten::Tensor& logits_tensor = logits_res.get();
     prev_pos = pos;
@@ -313,18 +324,19 @@ Result<int64_t> LhdTokenGenerator::generate(
         prev_token = cur_token;
 
         // sampler from logits all
-        stats_->on_sampling_begin();
-        cur_token = decoder_runner_->logits_to_token(logits_tensor, sample_idx);
-        stats_->on_sampling_end();
+        this->stats_->on_sampling_begin();
+        cur_token =
+            this->decoder_runner_->logits_to_token(logits_tensor, sample_idx);
+        this->stats_->on_sampling_end();
         result_tokens.push_back(cur_token);
         pos++;
 
         // print the token as string, decode it with the Tokenizer object
         token_callback(
-            ET_UNWRAP_TOKENIZER(tokenizer_->decode(prev_token, cur_token)));
+            ET_UNWRAP_TOKENIZER(this->tokenizer_->decode(prev_token, cur_token)));
 
         // data-dependent terminating condition: we have n_eos_ number of EOS
-        if (eos_ids_->count(cur_token) > 0) {
+        if (this->eos_ids_->count(cur_token) > 0) {
           printf("\n");
           ET_LOG(Info, "\nReached to the end of generation");
           break;
@@ -360,14 +372,15 @@ Result<int64_t> LhdTokenGenerator::generate(
     }
     // Update KV Cache with the output results
     int32_t n_update = pos - prev_pos;
-    kv_manager_->update_cache(metadata_.ar_len, prev_pos, n_update, selected);
+    this->kv_manager_->update_cache(
+        metadata_.ar_len, prev_pos, n_update, selected);
 
     // Update attention mask with current position
-    kv_manager_->update_attention_mask(
-        attention_mask_.data, metadata_.ar_len, prev_pos, n_update);
+    this->kv_manager_->update_attention_mask(
+        this->attention_mask_.data, metadata_.ar_len, prev_pos, n_update);
 
     // data-dependent terminating condition: we have n_eos_ number of EOS
-    if (eos_ids_->count(cur_token) > 0) {
+    if (this->eos_ids_->count(cur_token) > 0) {
       printf("\n");
       ET_LOG(Info, "\nReached to the end of generation");
       break;
@@ -381,4 +394,9 @@ Result<int64_t> LhdTokenGenerator::generate(
 
   return pos - start_pos;
 }
+
+// Explicit instantiations
+template class LhdTokenGenerator<uint8_t>;
+template class LhdTokenGenerator<uint16_t>;
+
 } // namespace example
diff --git a/examples/qualcomm/oss_scripts/llama/runner/lhd_token_generator.h b/examples/qualcomm/oss_scripts/llama/runner/lhd_token_generator.h
index fde50972f06..174c7f7504f 100644
--- a/examples/qualcomm/oss_scripts/llama/runner/lhd_token_generator.h
+++ b/examples/qualcomm/oss_scripts/llama/runner/lhd_token_generator.h
@@ -15,7 +15,8 @@ namespace example {
 * @brief Class for generating the token using decoder and key-value manager
 * with lookahead decoding.
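+ * @tparam T KV cache element type (uint8_t or uint16_t), matching the
+ *     KVManager<T> instance this generator drives.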
 */
-class LhdTokenGenerator : public TokenGenerator {
+template <typename T>
+class LhdTokenGenerator : public TokenGenerator<T> {
  public:
  struct Metadata {
    int32_t context_len;
@@ -31,18 +32,18 @@ class LhdTokenGenerator : public TokenGenerator {
   LhdTokenGenerator(
       tokenizers::Tokenizer* tokenizer,
       DecoderRunner* decoder_runner,
-      KVManager* kv_manager,
+      KVManager<T>* kv_manager,
       const std::string& forward_name,
       std::unique_ptr<std::unordered_set<uint64_t>>&& eos_ids,
       Metadata metadata,
       executorch::llm::Stats* stats)
-      : TokenGenerator(
+      : TokenGenerator<T>(
            tokenizer,
            decoder_runner,
            kv_manager,
            forward_name,
            std::move(eos_ids),
-            TokenGenerator::Metadata{
+            typename TokenGenerator<T>::Metadata{
                metadata.context_len,
                metadata.num_heads,
                metadata.num_layers,
diff --git a/examples/qualcomm/oss_scripts/llama/runner/prompt_processor.cpp b/examples/qualcomm/oss_scripts/llama/runner/prompt_processor.cpp
index 8794a1651da..787185c2249 100644
--- a/examples/qualcomm/oss_scripts/llama/runner/prompt_processor.cpp
+++ b/examples/qualcomm/oss_scripts/llama/runner/prompt_processor.cpp
@@ -14,9 +14,11 @@
 using executorch::runtime::Result;
 using executorch::runtime::TensorInfo;
 namespace example {
-PromptProcessor::PromptProcessor(
+
+template <typename T>
+PromptProcessor<T>::PromptProcessor(
     DecoderRunner* decoder_runner,
-    KVManager* kv_manager,
+    KVManager<T>* kv_manager,
     const std::string& method_name,
     Metadata metadata)
     : decoder_runner_(decoder_runner),
@@ -37,7 +39,9 @@ PromptProcessor::PromptProcessor(
       metadata_.ar_len * metadata_.context_len * sizeof(uint16_t);
   logits_.size = metadata_.ar_len * metadata_.vocab_size * sizeof(uint16_t);
 };
-void PromptProcessor::init_io(
+
+template <typename T>
+void PromptProcessor<T>::init_io(
     IMemAlloc* buffer_manager,
     Result<MethodMeta> method_meta) {
   input_tensors_.reserve(method_meta->num_inputs());
@@ -91,14 +95,14 @@ void PromptProcessor::init_io(
   for (int cache_group = 0; cache_group < 2; ++cache_group) {
     std::vector<std::vector<std::unique_ptr<TensorImpl>>>& cache =
         (cache_group == 0 ? k_cache_in_ : v_cache_in_);
-    std::vector<std::vector<KVCache>> cache_ptrs = (cache_group == 0)
+    std::vector<std::vector<KVCache<T>>> cache_ptrs = (cache_group == 0)
        ? kv_manager_->get_k_cache_()
        : kv_manager_->get_v_cache_();
     for (int layer = 0; layer < metadata_.num_layers; ++layer) {
       for (int head = 0; head < metadata_.num_heads; ++head, ++index) {
         Result<TensorInfo> kv_cache = method_meta->input_tensor_meta(index);
-        uint8_t* cache_ptr = cache_ptrs[layer][head].buffer;
+        T* cache_ptr = cache_ptrs[layer][head].buffer;
 
         cache[layer].emplace_back(std::make_unique<TensorImpl>(
             kv_cache->scalar_type(),
@@ -133,13 +137,13 @@ void PromptProcessor::init_io(
   for (int cache_group = 0; cache_group < 2; ++cache_group) {
     std::vector<std::vector<std::unique_ptr<TensorImpl>>>& cache =
         (cache_group == 0 ? k_cache_out_ : v_cache_out_);
-    std::vector<std::vector<KVCache>> cache_ptrs = (cache_group == 0)
+    std::vector<std::vector<KVCache<T>>> cache_ptrs = (cache_group == 0)
        ? kv_manager_->get_k_cache_()
        : kv_manager_->get_v_cache_();
     for (int layer = 0; layer < metadata_.num_layers; ++layer) {
       for (int head = 0; head < metadata_.num_heads; ++head, ++index) {
         Result<TensorInfo> kv_cache = method_meta->output_tensor_meta(index);
-        uint8_t* cache_ptr = cache_ptrs[layer][head].output_buffer;
+        T* cache_ptr = cache_ptrs[layer][head].output_buffer;
         cache[layer].emplace_back(std::make_unique<TensorImpl>(
             kv_cache->scalar_type(),
             kv_cache->sizes().size(),
@@ -160,11 +164,13 @@ void PromptProcessor::init_io(
   }
 }
 
-const std::vector<uint16_t>& PromptProcessor::get_all_logits() {
+template <typename T>
+const std::vector<uint16_t>& PromptProcessor<T>::get_all_logits() {
   return prompt_all_logits_;
 }
 
-void PromptProcessor::prepare_io(
+template <typename T>
+void PromptProcessor<T>::prepare_io(
     const std::vector<uint64_t>& prompt_tokens,
     int64_t prompt_pos,
     int64_t start_pos) {
@@ -189,7 +195,8 @@ void PromptProcessor::prepare_io(
   }
 }
 
-Result<uint64_t> PromptProcessor::prefill(
+template <typename T>
+Result<uint64_t> PromptProcessor<T>::prefill(
     std::vector<uint64_t> prompt_tokens,
     int64_t start_pos,
     bool dump_logits) {
@@ -281,4 +288,8 @@ Result<uint64_t> PromptProcessor::prefill(
   return cur_token;
 }
 
+// Explicit instantiations
+template class PromptProcessor<uint8_t>;
+template class PromptProcessor<uint16_t>;
+
 } // namespace example
diff --git a/examples/qualcomm/oss_scripts/llama/runner/prompt_processor.h b/examples/qualcomm/oss_scripts/llama/runner/prompt_processor.h
index 244e26577e9..04945558ae5 100644
--- a/examples/qualcomm/oss_scripts/llama/runner/prompt_processor.h
+++ b/examples/qualcomm/oss_scripts/llama/runner/prompt_processor.h
@@ -19,6 +19,7 @@ namespace example {
 * @class PromptProcessor
 * @brief Class for processing prompts using decoder and key-value manager.
 */
+template <typename T>
 class PromptProcessor {
  public:
  struct Metadata {
@@ -31,7 +32,7 @@ class PromptProcessor {
   };
   PromptProcessor(
       DecoderRunner* decoder_runner,
-      KVManager* kv_manager,
+      KVManager<T>* kv_manager,
       const std::string& method_name,
       Metadata metadata);
 
@@ -92,7 +93,7 @@ class PromptProcessor {
       int64_t prompt_pos,
       int64_t start_pos);
   DecoderRunner* decoder_runner_;
-  KVManager* kv_manager_;
+  KVManager<T>* kv_manager_;
   std::string method_name_;
 
   // metadata
diff --git a/examples/qualcomm/oss_scripts/llama/runner/runner.cpp b/examples/qualcomm/oss_scripts/llama/runner/runner.cpp
index 6f4a57880b0..df2e2d96041 100644
--- a/examples/qualcomm/oss_scripts/llama/runner/runner.cpp
+++ b/examples/qualcomm/oss_scripts/llama/runner/runner.cpp
@@ -21,7 +21,6 @@
 #include
 #include
 #include
-
 #include
 #include
 
@@ -91,7 +90,9 @@ std::unique_ptr<::tokenizers::Tokenizer> load_llama_tokenizer(
   return llm::load_tokenizer(tokenizer_path, std::move(special_tokens));
 }
 
-Runner::Runner(
+template <typename T>
+Runner<T>::Runner(
+    std::unique_ptr<executorch::extension::Module> module,
     const std::string& decoder_model_version,
     const std::string& model_path,
     const std::string& tokenizer_path,
@@ -104,7 +105,8 @@ Runner::Runner(
     const int window,
     const int gcap,
     std::unique_ptr<tokenizers::Tokenizer> tokenizer)
-    : ngram_(ngram),
+    : module_(std::move(module)),
+      ngram_(ngram),
       window_(window),
       gcap_(gcap),
       tokenizer_path_(tokenizer_path),
@@ -113,8 +115,6 @@ Runner::Runner(
       temperature_(temperature),
       eval_mode_(static_cast<EvalMode>(eval_mode)),
       tokenizer_(std::move(tokenizer)) {
-  module_ = std::make_unique<Module>(
-      model_path, Module::LoadMode::MmapUseMlockIgnoreErrors);
   stats_.reset();
   if (kv_updater == "SmartMask") {
     kv_updater_ = KVManagerMode::SMART_MASK;
@@ -142,12 +142,14 @@ Runner::Runner(
   ET_LOG(Info, "kv updater=%s", kv_updater.c_str());
 }
 
-bool Runner::is_loaded() const {
+template <typename T>
+bool Runner<T>::is_loaded() const {
   return module_->is_loaded() &&
      tokenizer_ && decoder_runner_ && prompt_processor_ &&
      token_generator_ && kv_manager_ && buffer_manager_;
 }
 
-Error Runner::load() {
+template <typename T>
+Error Runner<T>::load() {
   if (is_loaded()) {
     return Error::Ok;
   }
@@ -207,6 +209,7 @@ Error Runner::load() {
   // retrieve any method meta, can be either prefill or kv
   int64_t num_layers =
      ET_UNWRAP(module_->get("get_n_layers")).toScalar().to<int64_t>();
+  ET_CHECK_MSG(num_layers != -1, "Could not retrieve num layers");
 
   // k_cache: [1, head_dim, seq_len]
   int64_t head_dim = method_meta->output_tensor_meta(1)->sizes()[1];
@@ -241,9 +244,9 @@ Error Runner::load() {
      std::min(token_generator_ar_len, prompt_processor_ar_len);
   max_ar_len = std::max(token_generator_ar_len, prompt_processor_ar_len);
 
-  kv_manager_ = std::make_unique<KVManager>(
+  kv_manager_ = std::make_unique<KVManager<T>>(
       kv_updater_,
-      KVManager::Metadata{
+      typename KVManager<T>::Metadata{
          context_len_,
          head_dim,
          max_ar_len,
@@ -251,11 +254,11 @@ Error Runner::load() {
          num_heads,
          num_layers});
 
-  prompt_processor_ = std::make_unique<PromptProcessor>(
+  prompt_processor_ = std::make_unique<PromptProcessor<T>>(
      decoder_runner_.get(),
      kv_manager_.get(),
      prompt_processor_method_name,
-      PromptProcessor::Metadata{
+      typename PromptProcessor<T>::Metadata{
          context_len_,
          num_heads,
          num_layers,
@@ -263,13 +266,13 @@ Error Runner::load() {
          vocab_size,
          use_int64_token});
   if (eval_mode_ == EvalMode::kLookaheadDecoding) {
-    token_generator_ = std::make_unique<LhdTokenGenerator>(
+    token_generator_ = std::make_unique<LhdTokenGenerator<T>>(
        tokenizer_.get(),
        decoder_runner_.get(),
        kv_manager_.get(),
        token_generator_method_name,
        std::move(eos_ids),
-        LhdTokenGenerator::Metadata{
+        typename LhdTokenGenerator<T>::Metadata{
            context_len_,
            num_heads,
            num_layers,
@@ -281,13 +284,13 @@ Error Runner::load() {
            ngram_,
            window_,
            gcap_},
        &stats_);
   } else {
-    token_generator_ = std::make_unique<TokenGenerator>(
+    token_generator_ = std::make_unique<TokenGenerator<T>>(
        tokenizer_.get(),
        decoder_runner_.get(),
        kv_manager_.get(),
        token_generator_method_name,
        std::move(eos_ids),
-        TokenGenerator::Metadata{
+        typename TokenGenerator<T>::Metadata{
            context_len_,
            num_heads,
            num_layers,
@@ -316,7 +319,8 @@ Error Runner::load() {
   return Error::Ok;
 }
 
-Error Runner::generate(
+template <typename T>
+Error Runner<T>::generate(
     const std::string& prompt,
     bool tokenized_prompt,
     int32_t seq_len,
@@ -422,7 +426,8 @@ Error Runner::generate(
   return Error::Ok;
 }
 
-Result<DecoderModelVersion> Runner::get_decoder_model_version() {
+template <typename T>
+Result<DecoderModelVersion> Runner<T>::get_decoder_model_version() {
   if (!is_loaded()) {
     stats_.model_load_start_ms = time_in_ms();
     ET_CHECK_OK_OR_RETURN_ERROR(load());
@@ -431,4 +436,8 @@ Result<DecoderModelVersion> Runner::get_decoder_model_version() {
   return decoder_model_version_;
 }
 
+// Explicit instantiations
+template class Runner<uint8_t>;
+template class Runner<uint16_t>;
+
 } // namespace example
diff --git a/examples/qualcomm/oss_scripts/llama/runner/runner.h b/examples/qualcomm/oss_scripts/llama/runner/runner.h
index fe59049a9d8..6cc1f68d9a8 100644
--- a/examples/qualcomm/oss_scripts/llama/runner/runner.h
+++ b/examples/qualcomm/oss_scripts/llama/runner/runner.h
@@ -33,9 +33,17 @@ enum DecoderModelVersion {
   kQwen2_5,
   kPhi4,
 };
+
+enum KvBitWidth {
+  kWidth8 = 8,
+  kWidth16 = 16,
+};
+
+template <typename T>
 class Runner {
  public:
   explicit Runner(
+      std::unique_ptr<executorch::extension::Module> module,
       const std::string& decoder_model,
       const std::string& model_path,
       const std::string& tokenizer_path,
@@ -87,11 +95,11 @@ class Runner {
   DecoderModelVersion decoder_model_version_;
   KVManagerMode kv_updater_;
   std::unique_ptr<IMemAlloc> buffer_manager_;
-  std::unique_ptr<KVManager> kv_manager_;
+  std::unique_ptr<KVManager<T>> kv_manager_;
   std::unique_ptr<tokenizers::Tokenizer> tokenizer_;
   std::unique_ptr<DecoderRunner> decoder_runner_;
-  std::unique_ptr<PromptProcessor> prompt_processor_;
-  std::unique_ptr<TokenGenerator> token_generator_;
+  std::unique_ptr<PromptProcessor<T>> prompt_processor_;
+  std::unique_ptr<TokenGenerator<T>> token_generator_;
 
   // stats
   executorch::llm::Stats stats_;
diff --git a/examples/qualcomm/oss_scripts/llama/runner/token_generator.cpp b/examples/qualcomm/oss_scripts/llama/runner/token_generator.cpp
index bacff94f594..b04d3e4486d 100644
--- a/examples/qualcomm/oss_scripts/llama/runner/token_generator.cpp
+++ b/examples/qualcomm/oss_scripts/llama/runner/token_generator.cpp
@@ -14,10 +14,11 @@
 using executorch::runtime::Result;
 using executorch::runtime::TensorInfo;
 namespace example {
-TokenGenerator::TokenGenerator(
+template <typename T>
+TokenGenerator<T>::TokenGenerator(
     tokenizers::Tokenizer* tokenizer,
     DecoderRunner* decoder_runner,
-    KVManager* kv_manager,
+    KVManager<T>* kv_manager,
     const std::string& method_name,
     std::unique_ptr<std::unordered_set<uint64_t>>&& eos_ids,
     Metadata metadata,
@@ -41,7 +42,9 @@ TokenGenerator::TokenGenerator(
       metadata_.ar_len * metadata_.context_len * sizeof(uint16_t);
   logits_.size = metadata_.ar_len * metadata_.vocab_size * sizeof(uint16_t);
 }
-void TokenGenerator::init_io(
+
+template <typename T>
+void TokenGenerator<T>::init_io(
     IMemAlloc* buffer_manager,
     Result<MethodMeta> method_meta) {
   input_tensors_.reserve(method_meta->num_inputs());
@@ -94,14 +97,14 @@ void TokenGenerator::init_io(
   for (int cache_group = 0; cache_group < 2; ++cache_group) {
     std::vector<std::vector<std::unique_ptr<TensorImpl>>>& cache =
         (cache_group == 0 ? k_cache_in_ : v_cache_in_);
-    std::vector<std::vector<KVCache>> cache_ptrs = (cache_group == 0)
+    std::vector<std::vector<KVCache<T>>> cache_ptrs = (cache_group == 0)
        ? kv_manager_->get_k_cache_()
        : kv_manager_->get_v_cache_();
     for (int layer = 0; layer < metadata_.num_layers; ++layer) {
       for (int head = 0; head < metadata_.num_heads; ++head, ++index) {
         Result<TensorInfo> kv_cache = method_meta->input_tensor_meta(index);
-        uint8_t* cache_ptr = cache_ptrs[layer][head].buffer;
+        T* cache_ptr = cache_ptrs[layer][head].buffer;
 
         cache[layer].emplace_back(std::make_unique<TensorImpl>(
             kv_cache->scalar_type(),
@@ -135,13 +138,13 @@ void TokenGenerator::init_io(
   for (int cache_group = 0; cache_group < 2; ++cache_group) {
     std::vector<std::vector<std::unique_ptr<TensorImpl>>>& cache =
         (cache_group == 0 ? k_cache_out_ : v_cache_out_);
-    std::vector<std::vector<KVCache>> cache_ptrs = (cache_group == 0)
+    std::vector<std::vector<KVCache<T>>> cache_ptrs = (cache_group == 0)
        ? kv_manager_->get_k_cache_()
        : kv_manager_->get_v_cache_();
     for (int layer = 0; layer < metadata_.num_layers; ++layer) {
       for (int head = 0; head < metadata_.num_heads; ++head, ++index) {
         Result<TensorInfo> kv_cache = method_meta->output_tensor_meta(index);
-        uint8_t* cache_ptr = cache_ptrs[layer][head].output_buffer;
+        T* cache_ptr = cache_ptrs[layer][head].output_buffer;
         cache[layer].emplace_back(std::make_unique<TensorImpl>(
             kv_cache->scalar_type(),
             kv_cache->sizes().size(),
@@ -162,12 +165,14 @@ void TokenGenerator::init_io(
   }
 }
 
-const std::vector<uint16_t>& TokenGenerator::get_all_logits() {
+template <typename T>
+const std::vector<uint16_t>& TokenGenerator<T>::get_all_logits() {
   return token_all_logits_;
 }
 
 // This function only considers the case where token_generator_ar_len equals 1.
-void TokenGenerator::prepare_io(uint64_t cur_token, int64_t start_pos) {
+template <typename T>
+void TokenGenerator<T>::prepare_io(uint64_t cur_token, int64_t start_pos) {
   // update input_tok
   *input_toks_.data = metadata_.use_int64_token
      ? cur_token
      : static_cast<int32_t>(cur_token);
@@ -175,7 +180,8 @@ void TokenGenerator::prepare_io(uint64_t cur_token, int64_t start_pos) {
   // update position
   *input_pos_.data = static_cast<int32_t>(start_pos);
 }
 
-Result<int64_t> TokenGenerator::generate(
+template <typename T>
+Result<int64_t> TokenGenerator<T>::generate(
     std::vector<uint64_t> tokens,
     int64_t start_pos,
     int32_t seq_len,
@@ -261,4 +267,9 @@ Result<int64_t> TokenGenerator::generate(
   }
   return pos - start_pos;
 }
+
+// Explicit instantiations
+template class TokenGenerator<uint8_t>;
+template class TokenGenerator<uint16_t>;
+
 } // namespace example
diff --git a/examples/qualcomm/oss_scripts/llama/runner/token_generator.h b/examples/qualcomm/oss_scripts/llama/runner/token_generator.h
index f76340d4d87..682c1531b88 100644
--- a/examples/qualcomm/oss_scripts/llama/runner/token_generator.h
+++ b/examples/qualcomm/oss_scripts/llama/runner/token_generator.h
@@ -20,6 +20,7 @@ namespace example {
 * @class TokenGenerator
 * @brief Class for generating the token using decoder and key-value manager.
 */
+template <typename T>
 class TokenGenerator {
  public:
  struct Metadata {
@@ -33,7 +34,7 @@ class TokenGenerator {
   TokenGenerator(
       tokenizers::Tokenizer* tokenizer,
       DecoderRunner* decoder_runner,
-      KVManager* kv_manager,
+      KVManager<T>* kv_manager,
       const std::string& method_name,
       std::unique_ptr<std::unordered_set<uint64_t>>&& eos_ids,
       Metadata metadata,
@@ -79,7 +80,7 @@ class TokenGenerator {
  protected:
   tokenizers::Tokenizer* tokenizer_;
   DecoderRunner* decoder_runner_;
-  KVManager* kv_manager_;
+  KVManager<T>* kv_manager_;
   std::string method_name_;
 
   std::unique_ptr<std::unordered_set<uint64_t>> eos_ids_;
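One note on the template mechanics used across these runner files: the member functions of `KVManager<T>`, `PromptProcessor<T>`, `TokenGenerator<T>`, `LhdTokenGenerator<T>`, and `Runner<T>` remain defined in their .cpp files, so each .cpp ends with explicit instantiations for the two supported element types; adding another KV bit width would mean extending both those instantiation lists and the dispatch in qnn_llama_runner.cpp. A minimal self-contained sketch of the pattern, using a hypothetical `Cache` class rather than anything from this diff:

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// cache.h -- users of the class only see this declaration.
template <typename T>
class Cache {
 public:
  void push(T value);
  std::size_t size() const;

 private:
  std::vector<T> data_;
};

// cache.cpp -- out-of-line member definitions.
template <typename T>
void Cache<T>::push(T value) {
  data_.push_back(value);
}

template <typename T>
std::size_t Cache<T>::size() const {
  return data_.size();
}

// Explicit instantiations: without these, a translation unit that only
// includes cache.h would hit linker errors, since no object file would
// otherwise emit Cache<uint8_t>/Cache<uint16_t> member definitions.
template class Cache<uint8_t>;
template class Cache<uint16_t>;
```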