
Commit 73e2f7c

Qualcomm AI Engine Direct - Optimize QNN embedding op for llama
Summary: Change the dtype of the token from int64 to int32. Int32 is QNN HTP friendly and significantly speeds up QNN embedding operations because it matches the backend's optimizations.
1 parent 39e5b91 commit 73e2f7c
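
For reference, a minimal sketch of the export-side effect under a simplified setting: the token example input is cast to torch.int32 before export (mirroring the new branch in _load_llama_model in the diff below), and the flag is recorded in the exported metadata so the C++ Runner can pick the matching ScalarType at load time. The literal tensors and the standalone use_int32_token / metadata names here are illustrative only, not the actual export flow.

import torch

# Illustrative flag; in the diff below it is derived from args.qnn.
use_int32_token = True

# Illustrative example inputs: (prompt tokens, start position).
example_inputs = (
    torch.tensor([[1, 2, 3]], dtype=torch.int64),
    torch.tensor([0], dtype=torch.int64),
)

if use_int32_token:
    # Same cast as in _load_llama_model: only the token tensor changes dtype.
    token = example_inputs[0].to(torch.int32)
    example_inputs = (token,) + example_inputs[1:]

# Recorded alongside the other model metadata ("use_int32_token") so the
# runtime knows whether to wrap token buffers as Int or Long tensors.
metadata = {"use_int32_token": use_int32_token}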

File tree

5 files changed: +30 −7 lines changed

examples/models/llama/export_llama_lib.py

Lines changed: 9 additions & 0 deletions
@@ -519,6 +519,7 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager:
         checkpoint=checkpoint_path,
         checkpoint_dir=checkpoint_dir,
         params_path=params_path,
+        use_int32_token=True if args.qnn else False,
         use_kv_cache=args.use_kv_cache,
         use_sdpa_with_kv_cache=args.use_sdpa_with_kv_cache,
         generate_full_logits=args.generate_full_logits,
@@ -746,6 +747,7 @@ def _export_llama(modelname, args) -> LLMEdgeManager:  # noqa: C901
 
 def _load_llama_model_metadata(
     weight_type: WeightType,
+    use_int32_token: bool,
     use_kv_cache: bool,
     use_sdpa_with_kv_cache: bool,
     enable_dynamic_shape: bool,
@@ -759,6 +761,7 @@ def _load_llama_model_metadata(
         "get_max_seq_len": model_args.max_seq_len,
         "get_n_layers": model_args.n_layers,
         "get_vocab_size": model_args.vocab_size,
+        "use_int32_token": use_int32_token,
         "use_kv_cache": use_kv_cache,
         "use_sdpa_with_kv_cache": use_sdpa_with_kv_cache,
         "enable_dynamic_shape": enable_dynamic_shape,
@@ -779,6 +782,7 @@ def _load_llama_model(
     checkpoint: Optional[str] = None,
     checkpoint_dir: Optional[str] = None,
     params_path: str,
+    use_int32_token: bool = False,
     use_kv_cache: bool = False,
     use_sdpa_with_kv_cache: bool = False,
     generate_full_logits: bool = False,
@@ -852,6 +856,10 @@ def _load_llama_model(
     else:
         raise ValueError(f"Unsupported dtype {dtype}")
 
+    if use_int32_token:
+        token = example_inputs[0].to(torch.int32)
+        example_inputs = (token,) + example_inputs[1:]
+
     return LLMEdgeManager(
         model=model,
         modelname=modelname,
@@ -870,6 +878,7 @@ def _load_llama_model(
         verbose=verbose,
         metadata=_load_llama_model_metadata(
             weight_type,
+            use_int32_token,
             use_kv_cache,
             use_sdpa_with_kv_cache,
             enable_dynamic_shape,

examples/models/llama/runner/runner.cpp

Lines changed: 4 additions & 0 deletions
@@ -34,6 +34,7 @@ static constexpr auto kMaxSeqLen = "get_max_seq_len";
 static constexpr auto kVocabSize = "get_vocab_size";
 static constexpr auto kUseKVCache = "use_kv_cache";
 static constexpr auto kUseSDPAWithKVCache = "use_sdpa_with_kv_cache";
+static constexpr auto kUseInt32Token = "use_int32_token";
 } // namespace
 
 Runner::Runner(
@@ -51,6 +52,7 @@ Runner::Runner(
           {kMaxSeqLen, 128},
           {kUseKVCache, true},
           {kUseSDPAWithKVCache, false},
+          {kUseInt32Token, true},
       }) {
   ET_LOG(
       Info,
@@ -127,12 +129,14 @@ Error Runner::load() {
       temperature_);
   text_prefiller_ = std::make_unique<llm::TextPrefiller>(
       text_decoder_runner_.get(),
+      metadata_.at(kUseInt32Token),
       metadata_.at(kUseKVCache),
       metadata_.at(kEnableDynamicShape));
 
   text_token_generator_ = std::make_unique<llm::TextTokenGenerator>(
       tokenizer_.get(),
       text_decoder_runner_.get(),
+      metadata_.at(kUseInt32Token),
       metadata_.at(kUseKVCache),
       std::move(eos_ids),
       &stats_);

extension/llm/runner/text_prefiller.cpp

Lines changed: 8 additions & 5 deletions
@@ -17,9 +17,11 @@ namespace llm {
 
 TextPrefiller::TextPrefiller(
     TextDecoderRunner* text_decoder_runner,
+    bool use_int32_token,
     bool use_kv_cache,
     bool enable_parallel_prefill)
     : text_decoder_runner_(text_decoder_runner),
+      use_int32_token_(use_int32_token),
       use_kv_cache_(use_kv_cache),
       enable_parallel_prefill_(enable_parallel_prefill) {}
 
@@ -36,12 +38,13 @@ ::executorch::runtime::Result<uint64_t> TextPrefiller::prefill(
 
   // store the token
   uint64_t cur_token;
+  exec_aten::ScalarType token_type = use_int32_token_
+      ? exec_aten::ScalarType::Int
+      : exec_aten::ScalarType::Long;
   if (enable_parallel_prefill_ || !use_kv_cache_) {
     // initialize tensor wrappers
-    auto tokens = from_blob(
-        prompt_tokens.data(),
-        {1, num_prompt_tokens},
-        exec_aten::ScalarType::Long);
+    auto tokens =
+        from_blob(prompt_tokens.data(), {1, num_prompt_tokens}, token_type);
 
     auto start_pos_tensor =
         from_blob(&start_pos, {1}, exec_aten::ScalarType::Long);
@@ -60,7 +63,7 @@ ::executorch::runtime::Result<uint64_t> TextPrefiller::prefill(
     cur_token = prompt_tokens[0];
 
     // initialize tensor wrappers
-    auto tokens = from_blob(&cur_token, {1, 1}, exec_aten::ScalarType::Long);
+    auto tokens = from_blob(&cur_token, {1, 1}, token_type);
 
     auto start_pos_tensor =
         from_blob(&start_pos, {1}, exec_aten::ScalarType::Long);

extension/llm/runner/text_prefiller.h

Lines changed: 2 additions & 0 deletions
@@ -24,6 +24,7 @@ class ET_EXPERIMENTAL TextPrefiller {
  public:
   TextPrefiller(
       TextDecoderRunner* text_decoder_runner,
+      bool use_int32_token,
       bool use_kv_cache_,
       bool enable_parallel_prefill);
   /**
@@ -40,6 +41,7 @@ class ET_EXPERIMENTAL TextPrefiller {
 
  private:
   TextDecoderRunner* text_decoder_runner_;
+  bool use_int32_token_;
   bool use_kv_cache_;
   bool enable_parallel_prefill_;
 };

extension/llm/runner/text_token_generator.h

Lines changed: 7 additions & 2 deletions
@@ -23,12 +23,14 @@ class ET_EXPERIMENTAL TextTokenGenerator {
   TextTokenGenerator(
       Tokenizer* tokenizer,
       TextDecoderRunner* text_decoder_runner,
+      bool use_int32_token,
       bool use_kv_cache,
       std::unique_ptr<std::unordered_set<uint64_t>>&& eos_ids,
       Stats* stats)
       : tokenizer_(tokenizer),
         text_decoder_runner_(text_decoder_runner),
         eos_ids_(std::move(eos_ids)),
+        use_int32_token_(use_int32_token),
         use_kv_cache_(use_kv_cache),
         stats_(stats) {}
 
@@ -54,6 +56,9 @@ class ET_EXPERIMENTAL TextTokenGenerator {
 
   std::vector<uint64_t> token_data; // allocate space for the tokens
   std::vector<executorch::aten::SizesType> token_shape;
+  exec_aten::ScalarType token_type = use_int32_token_
+      ? exec_aten::ScalarType::Int
+      : exec_aten::ScalarType::Long;
 
   // Token after prefill
   uint64_t cur_token = tokens.back();
@@ -70,8 +75,7 @@ class ET_EXPERIMENTAL TextTokenGenerator {
   }
 
   // initialize tensor wrappers
-  auto tokens_managed = from_blob(
-      token_data.data(), token_shape, executorch::aten::ScalarType::Long);
+  auto tokens_managed = from_blob(token_data.data(), token_shape, token_type);
   auto start_pos_managed =
       from_blob(&pos, {1}, executorch::aten::ScalarType::Long);
 
@@ -133,6 +137,7 @@ class ET_EXPERIMENTAL TextTokenGenerator {
   Tokenizer* tokenizer_;
   TextDecoderRunner* text_decoder_runner_;
   std::unique_ptr<std::unordered_set<uint64_t>> eos_ids_;
+  bool use_int32_token_;
   bool use_kv_cache_;
 
   // state machine
