
Commit 0fc31e2

rohansjoshi authored and facebook-github-bot committed
QNN Llama Runner implement IRunner (#13171)
Summary: This PR makes the Runner for running Qualcomm Llama models implement the IRunner interface. Using this, it enables running static Llama models inside the LlamaDemo Android app. It also switches the default eval mode to hybrid everywhere. Differential Revision: D79759817
1 parent afdbb85 commit 0fc31e2

File tree

7 files changed (+87 / -24 lines)


examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ModelUtils.java

Lines changed: 5 additions & 0 deletions
@@ -21,6 +21,9 @@ public class ModelUtils {
   // MediaTek
   static final int MEDIATEK_TEXT_MODEL = 3;
 
+  // QNN static llama
+  static final int QNN_TEXT_MODEL = 4;
+
   public static int getModelCategory(ModelType modelType, BackendType backendType) {
     if (backendType.equals(BackendType.XNNPACK)) {
       switch (modelType) {
@@ -35,6 +38,8 @@ public static int getModelCategory(ModelType modelType, BackendType backendType)
       }
     } else if (backendType.equals(BackendType.MEDIATEK)) {
       return MEDIATEK_TEXT_MODEL;
+    } else if (backendType.equals(BackendType.QUALCOMM)) {
+      return QNN_TEXT_MODEL;
     }
 
     return TEXT_MODEL; // default

examples/qualcomm/oss_scripts/llama/llama.py

Lines changed: 1 addition & 1 deletion
@@ -1038,7 +1038,7 @@ def _build_parser():
     parser.add_argument(
         "--model_mode",
         help="Export and inference kv mode, hybrid mode, or lookahead decoding mode",
-        default="kv",
+        default="hybrid",
         choices=["kv", "hybrid", "lookahead"],
         type=str,
     )

examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp

Lines changed: 8 additions & 6 deletions
@@ -16,6 +16,7 @@
 
 #include <executorch/backends/qualcomm/runtime/QnnExecuTorch.h>
 #include <executorch/examples/qualcomm/oss_scripts/llama/runner/runner.h>
+#include <executorch/extension/llm/runner/irunner.h>
 #include <executorch/runtime/platform/log.h>
 #include <gflags/gflags.h>
 #include <fstream>
@@ -61,7 +62,7 @@ DEFINE_int32(
     "Total number of tokens to generate (prompt + output).");
 DEFINE_int32(
     eval_mode,
-    0,
+    1,
     "0: TokenGenerator(kv) / 1: HybridMode (prefill+kv) / 2: Lookahead Decoding");
 DEFINE_string(
     kv_updater,
@@ -161,12 +162,13 @@ int main(int argc, char** argv) {
       buf.push_back(c);
     }
   };
-
+  executorch::extension::llm::GenerationConfig config{
+      true, -1, false, FLAGS_seq_len, static_cast<float>(FLAGS_temperature), 0, 0};
   if (use_tokenized_prompt) {
-    runner.generate(
+    runner.generate_from_prompt_or_file(
         FLAGS_tokenized_prompt.c_str(),
         use_tokenized_prompt,
-        FLAGS_seq_len,
+        config,
         callback);
   } else {
     // generate tokens & store inference output
@@ -175,10 +177,10 @@ int main(int argc, char** argv) {
     std::string formatted_prompt;
     formatted_prompt = get_formatted_prompt(
         prompt, FLAGS_system_prompt, decoder_model_version.get());
-    runner.generate(
+    runner.generate_from_prompt_or_file(
         formatted_prompt.c_str(),
         use_tokenized_prompt,
-        FLAGS_seq_len,
+        config,
         callback);
   }
 }
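For readers tracking the API change in main() above: a minimal sketch, not part of the commit, of driving the runner through the new config-based entry point. Only the echo and seq_len fields are read by the QNN runner in this diff; the helper name run_once, the literal values, and the printf callback are illustrative.

// Illustrative sketch only; assumes a Runner constructed as in main() above.
#include <cstdio>
#include <string>
#include <executorch/examples/qualcomm/oss_scripts/llama/runner/runner.h>
#include <executorch/extension/llm/runner/irunner.h>

void run_once(example::Runner& runner, const std::string& prompt) {
  executorch::extension::llm::GenerationConfig config;
  config.echo = true;    // echo the prompt through the token callback before generation
  config.seq_len = 128;  // total tokens (prompt + output); the runner clamps this to the context length
  runner.generate_from_prompt_or_file(
      prompt,
      /*tokenized_prompt=*/false,  // treat `prompt` as raw text rather than a token file
      config,
      [](const std::string& piece) { printf("%s", piece.c_str()); });
}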

examples/qualcomm/oss_scripts/llama/runner/runner.cpp

Lines changed: 25 additions & 8 deletions
@@ -59,7 +59,7 @@ void print_performance_report(
     outfile << num_tok;
     outfile.close();
   } else {
-    ET_CHECK_MSG(false, "Error saving the inference speed file");
+    ET_LOG(Error, "Error saving the inference speed file");
   }
 }
 
@@ -84,7 +84,7 @@ void save_logits(
 
 } // namespace
 
-std::unique_ptr<::tokenizers::Tokenizer> load_llama_tokenizer(
+std::unique_ptr<::tokenizers::Tokenizer> load_qnn_llama_tokenizer(
     const std::string& tokenizer_path,
     Version version) {
   auto special_tokens = get_special_tokens(version);
@@ -175,7 +175,7 @@ Error Runner::load() {
     eos_ids->insert(tokenizer_->encode("<|eot|>", 0, 0).get()[0]);
     eos_ids->insert(tokenizer_->encode("<|end_of_text|>", 0, 0).get()[0]);
   } else {
-    tokenizer_ = load_llama_tokenizer(tokenizer_path_, Version::Default);
+    tokenizer_ = load_qnn_llama_tokenizer(tokenizer_path_, Version::Default);
     if (tokenizer_ == nullptr) {
       ET_LOG(
           Error, "Failed to load tokenizer with %s", tokenizer_path_.c_str());
@@ -313,13 +313,29 @@ Error Runner::load() {
 }
 
 Error Runner::generate(
+    const std::string& prompt,
+    const executorch::extension::llm::GenerationConfig& config,
+    std::function<void(const std::string&)> token_callback,
+    std::function<void(const Stats&)> stats_callback) {
+  return generate_from_pos(prompt, 0, config, token_callback, stats_callback);
+}
+
+Error Runner::generate_from_pos(
+    const std::string& prompt,
+    int64_t start_pos,
+    const executorch::extension::llm::GenerationConfig& config,
+    std::function<void(const std::string&)> token_callback,
+    std::function<void(const Stats&)> stats_callback) {
+  // TODO: currently only support start_pos == 0
+  return generate_from_prompt_or_file(
+      prompt, false, config, token_callback, stats_callback);
+}
+
+Error Runner::generate_from_prompt_or_file(
     const std::string& prompt,
     bool tokenized_prompt,
-    int32_t seq_len,
+    const executorch::extension::llm::GenerationConfig& config,
    std::function<void(const std::string&)> token_callback,
-    std::function<void(const Stats&)> stats_callback,
-    bool echo,
-    bool warming) {
+    std::function<void(const Stats&)> stats_callback) {
   ET_CHECK_MSG(!prompt.empty(), "prompt cannot be null");
   if (!is_loaded()) {
     stats_.model_load_start_ms = time_in_ms();
@@ -328,6 +344,7 @@ Error Runner::generate(
   }
   stats_.inference_start_ms = time_in_ms();
 
+  int32_t seq_len = config.seq_len;
   seq_len = (seq_len > 0 && seq_len <= context_len_) ? seq_len : context_len_;
   int32_t n_bos = (cur_pos_ == 0) ? 1 : 0;
 
@@ -366,7 +383,7 @@
       "sequence length exceeded - please increase the seq_len value");
 
   // Prompt Processor first
-  if (token_callback) {
+  if (token_callback && config.echo) {
     token_callback(prompt);
   }
   bool dump_logits = dump_logits_path_.empty() ? false : true;
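A brief, illustrative mapping (not from the commit) of where the old positional arguments went and how the three entry points above chain together; the literal values are placeholders.

// Illustrative only: where the removed positional arguments now live.
#include <string>
#include <executorch/examples/qualcomm/oss_scripts/llama/runner/runner.h>
#include <executorch/extension/llm/runner/irunner.h>

void generate_compat(example::Runner& runner, const std::string& prompt) {
  // Old (removed) call shape:
  //   runner.generate(prompt, /*tokenized_prompt=*/false, /*seq_len=*/128,
  //                   token_cb, stats_cb, /*echo=*/true, /*warming=*/false);
  executorch::extension::llm::GenerationConfig config;
  config.seq_len = 128;  // replaces the positional seq_len argument
  config.echo = true;    // replaces the trailing echo flag; warming is not carried over
  // New IRunner-style call; internally generate() -> generate_from_pos(start_pos = 0)
  // -> generate_from_prompt_or_file(tokenized_prompt = false).
  runner.generate(prompt, config);
}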

examples/qualcomm/oss_scripts/llama/runner/runner.h

Lines changed: 19 additions & 9 deletions
@@ -21,6 +21,7 @@
 #include <executorch/examples/qualcomm/oss_scripts/llama/runner/kv_manager.h>
 #include <executorch/examples/qualcomm/oss_scripts/llama/runner/prompt_processor.h>
 #include <executorch/examples/qualcomm/oss_scripts/llama/runner/token_generator.h>
+#include <executorch/extension/llm/runner/irunner.h>
 #include <executorch/extension/llm/runner/stats.h>
 #include <executorch/extension/module/module.h>
 #include <pytorch/tokenizers/tokenizer.h>
@@ -32,7 +33,7 @@ enum DecoderModelVersion {
   kLlama3,
   kQwen2_5,
 };
-class Runner {
+class Runner : public executorch::extension::llm::IRunner {
  public:
   explicit Runner(
       const std::string& decoder_model,
@@ -41,25 +42,34 @@ class Runner {
       const std::string& performance_output_path,
       const std::string& dump_logits_path,
       const float temperature = 0.8f,
-      const int eval_mode = EvalMode::kKVCached,
+      const int eval_mode = EvalMode::kHybrid,
       const std::string& kv_updater = "SmartMask",
       const int ngram = 0,
       const int window = 0,
       const int gcap = 0,
       std::unique_ptr<tokenizers::Tokenizer> tokenizer = nullptr);
 
-  bool is_loaded() const;
-  executorch::runtime::Error load();
+  bool is_loaded() const override;
+  executorch::runtime::Error load() override;
   // TODO: Support echo and warming
   executorch::runtime::Error generate(
+      const std::string& prompt,
+      const executorch::extension::llm::GenerationConfig& config,
+      std::function<void(const std::string&)> token_callback = {},
+      std::function<void(const executorch::llm::Stats&)> stats_callback = {}) override;
+  executorch::runtime::Error generate_from_pos(
+      const std::string& prompt,
+      int64_t start_pos,
+      const executorch::extension::llm::GenerationConfig& config,
+      std::function<void(const std::string&)> token_callback = {},
+      std::function<void(const executorch::llm::Stats&)> stats_callback = {}) override;
+  executorch::runtime::Error generate_from_prompt_or_file(
       const std::string& prompt,
       bool tokenized_prompt,
-      int32_t seq_len,
+      const executorch::extension::llm::GenerationConfig& config,
      std::function<void(const std::string&)> token_callback = {},
-      std::function<void(const executorch::llm::Stats&)> stats_callback = {},
-      bool echo = true,
-      bool warming = false);
-  void stop() {};
+      std::function<void(const executorch::llm::Stats&)> stats_callback = {});
+  void stop() override {};
   executorch::runtime::Result<DecoderModelVersion> get_decoder_model_version();
 
  private:
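Since Runner now derives from executorch::extension::llm::IRunner, generic code can load and drive it through the base interface alone. A minimal sketch under that assumption; run_prompt and the lambda are illustrative and not part of the commit.

// Minimal sketch: code written only against the IRunner surface declared above.
#include <string>
#include <executorch/extension/llm/runner/irunner.h>

namespace llm = executorch::extension::llm;

executorch::runtime::Error run_prompt(llm::IRunner& runner, const std::string& prompt) {
  if (!runner.is_loaded()) {
    auto err = runner.load();  // virtual; resolves to the QNN Runner::load() above
    if (err != executorch::runtime::Error::Ok) {
      return err;
    }
  }
  llm::GenerationConfig config;  // defaults; the QNN runner reads echo and seq_len
  return runner.generate(
      prompt, config, [](const std::string& piece) { /* stream piece to the caller */ });
}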

extension/android/CMakeLists.txt

Lines changed: 18 additions & 0 deletions
@@ -166,6 +166,24 @@ if(EXECUTORCH_BUILD_LLAMA_JNI)
     ${CMAKE_CURRENT_BINARY_DIR}/../../examples/models/llama/runner
   )
 
+  if (QNN_SDK_ROOT)
+    target_sources(
+      executorch_jni PRIVATE
+      ${EXECUTORCH_ROOT}/examples/qualcomm/oss_scripts/llama/runner/runner.cpp
+      ${EXECUTORCH_ROOT}/examples/qualcomm/oss_scripts/llama/runner/decoder_runner.cpp
+      ${EXECUTORCH_ROOT}/examples/qualcomm/oss_scripts/llama/runner/prompt_processor.cpp
+      ${EXECUTORCH_ROOT}/examples/qualcomm/oss_scripts/llama/runner/token_generator.cpp
+      ${EXECUTORCH_ROOT}/examples/qualcomm/oss_scripts/llama/runner/lhd_token_generator.cpp
+      ${EXECUTORCH_ROOT}/examples/qualcomm/oss_scripts/llama/runner/rpc_mem.cpp
+      ${EXECUTORCH_ROOT}/examples/qualcomm/oss_scripts/llama/runner/kv_manager.cpp
+    )
+
+    target_include_directories(
+      executorch_jni PRIVATE
+      ${EXECUTORCH_ROOT}/examples/qualcomm/oss_scripts/llama/runner
+    )
+  endif()
+
   if(NEURON_BUFFER_ALLOCATOR_LIB)
     target_sources(
       executorch_jni PRIVATE

extension/android/jni/jni_layer_llama.cpp

Lines changed: 11 additions & 0 deletions
@@ -15,6 +15,7 @@
 
 #include <executorch/examples/models/llama/runner/runner.h>
 #include <executorch/examples/models/llava/runner/llava_runner.h>
+#include <executorch/examples/qualcomm/oss_scripts/llama/runner/runner.h>
 #include <executorch/extension/llm/runner/image.h>
 #include <executorch/extension/llm/runner/irunner.h>
 #include <executorch/runtime/platform/log.h>
@@ -124,6 +125,7 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
   constexpr static int MODEL_TYPE_CATEGORY_LLM = 1;
   constexpr static int MODEL_TYPE_CATEGORY_MULTIMODAL = 2;
   constexpr static int MODEL_TYPE_MEDIATEK_LLAMA = 3;
+  constexpr static int MODEL_TYPE_QNN_LLAMA = 4;
 
   static facebook::jni::local_ref<jhybriddata> initHybrid(
       facebook::jni::alias_ref<jclass>,
@@ -174,6 +176,14 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
           model_path->toStdString(),
           tokenizer_path->toStdString(),
           data_path_str);
+    } else if (model_type_category == MODEL_TYPE_QNN_LLAMA) {
+      std::string decoder_model = "llama3"; // use llama3 for now
+      runner_ = std::make_unique<example::Runner>( // QNN runner
+          decoder_model.c_str(),
+          model_path->toStdString().c_str(),
+          tokenizer_path->toStdString().c_str(),
+          data_path->toStdString().c_str());
+      model_type_category_ = MODEL_TYPE_CATEGORY_LLM;
 #if defined(EXECUTORCH_BUILD_MEDIATEK)
     } else if (model_type_category == MODEL_TYPE_MEDIATEK_LLAMA) {
       runner_ = std::make_unique<MTKLlamaRunner>(
@@ -318,6 +328,7 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
         [callback](std::string result) { callback->onResult(result); },
         [callback](const llm::Stats& stats) { callback->onStats(stats); }));
     }
+    return static_cast<jint>(executorch::runtime::Error::InvalidArgument);
   }
 
   void stop() {
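To trace the wiring end to end: ModelUtils.java (first file above) returns QNN_TEXT_MODEL = 4 for the QUALCOMM backend, which matches MODEL_TYPE_QNN_LLAMA = 4 here, and the branch then folds the runner back into the generic LLM category. Below is a hypothetical helper equivalent to that branch, with the reasoning spelled out; the real code is inline in initHybrid, and make_qnn_llama_runner is not part of the commit.

// Hypothetical restatement of the dispatch added above; names follow the diff.
#include <memory>
#include <string>
#include <executorch/examples/qualcomm/oss_scripts/llama/runner/runner.h>
#include <executorch/extension/llm/runner/irunner.h>

std::unique_ptr<executorch::extension::llm::IRunner> make_qnn_llama_runner(
    const std::string& model_path,
    const std::string& tokenizer_path,
    const std::string& data_path) {
  // Java side: ModelUtils.getModelCategory(..., BackendType.QUALCOMM) returns
  // QNN_TEXT_MODEL (4); JNI side: MODEL_TYPE_QNN_LLAMA (4) selects this path.
  std::string decoder_model = "llama3";  // only llama3 is wired up for now
  // Because example::Runner implements IRunner, the caller can store it in the
  // same runner_ member as other LLM backends and reset the category to
  // MODEL_TYPE_CATEGORY_LLM, so generate()/stop() need no QNN-specific branches.
  return std::make_unique<example::Runner>(
      decoder_model.c_str(),
      model_path.c_str(),
      tokenizer_path.c_str(),
      data_path.c_str());
}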
