
Commit ebb9c3a

rohansjoshi authored and facebook-github-bot committed
QNN Llama Runner implement IRunner
Summary: This PR makes the Runner for running Qualcomm Llama models implement the IRunner interface. Using this, it enables running static Llama models with the QNN backend inside the LlamaDemo Android app. It also switches the default eval mode to hybrid everywhere.

Reviewed By: cccclai

Differential Revision: D79759817
1 parent c5db75b commit ebb9c3a
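Because the QNN Runner now implements the executorch::extension::llm::IRunner interface, a host application can hold it behind that common interface and drive it with a GenerationConfig, which is how the LlamaDemo JNI layer uses it in this commit. Below is a minimal usage sketch, not taken from this commit: the constructor argument order follows the Runner declaration in the examples/qualcomm/oss_scripts/llama/runner/runner.h diff below, and the file paths and the "llama3" decoder name are placeholders.

#include <executorch/examples/qualcomm/oss_scripts/llama/runner/runner.h>
#include <executorch/extension/llm/runner/irunner.h>

#include <memory>
#include <string>

namespace llm = executorch::extension::llm;

int main() {
  // Placeholder arguments; see the Runner constructor in the runner.h diff below.
  std::unique_ptr<llm::IRunner> runner = std::make_unique<example::Runner>(
      "llama3",                      // decoder model version
      "/path/to/static_llama.pte",   // model path
      "/path/to/tokenizer.bin",      // tokenizer path
      "outputs/inference_speed.txt", // performance output path
      "");                           // dump-logits path (empty disables dumping)

  if (runner->load() != executorch::runtime::Error::Ok) {
    return 1;
  }

  // Options that used to be separate arguments now travel in GenerationConfig.
  llm::GenerationConfig config;
  config.seq_len = 128;
  config.echo = true;

  runner->generate(
      "What is the capital of France?",
      config,
      [](const std::string& piece) { /* stream each decoded piece */ },
      [](const llm::Stats& stats) { /* inspect timing stats */ });
  return 0;
}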

10 files changed (+120, −37 lines)


examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ModelUtils.java

Lines changed: 5 additions & 0 deletions

@@ -21,6 +21,9 @@ public class ModelUtils {
   // MediaTek
   static final int MEDIATEK_TEXT_MODEL = 3;

+  // QNN static llama
+  static final int QNN_TEXT_MODEL = 4;
+
   public static int getModelCategory(ModelType modelType, BackendType backendType) {
     if (backendType.equals(BackendType.XNNPACK)) {
       switch (modelType) {
@@ -35,6 +38,8 @@ public static int getModelCategory(ModelType modelType, BackendType backendType)
       }
     } else if (backendType.equals(BackendType.MEDIATEK)) {
       return MEDIATEK_TEXT_MODEL;
+    } else if (backendType.equals(BackendType.QUALCOMM)) {
+      return QNN_TEXT_MODEL;
     }

     return TEXT_MODEL; // default

examples/models/llama/runner/runner.h

Lines changed: 3 additions & 1 deletion

@@ -18,6 +18,7 @@
 #include <string>
 #include <unordered_map>

+#include <executorch/examples/models/llama/tokenizer/llama_tiktoken.h>
 #include <executorch/extension/llm/runner/irunner.h>
 #include <executorch/extension/llm/runner/text_llm_runner.h>
 #include <pytorch/tokenizers/tokenizer.h>
@@ -33,6 +34,7 @@ std::unique_ptr<llm::TextLLMRunner> create_llama_runner(
     float temperature = -1.0f);

 std::unique_ptr<tokenizers::Tokenizer> load_llama_tokenizer(
-    const std::string& tokenizer_path);
+    const std::string& tokenizer_path,
+    Version version = Version::Default);

 } // namespace example

examples/qualcomm/oss_scripts/llama/CMakeLists.txt

Lines changed: 2 additions & 0 deletions

@@ -42,6 +42,8 @@ list(
     ${CMAKE_CURRENT_LIST_DIR}/runner/rpc_mem.h
     ${CMAKE_CURRENT_LIST_DIR}/runner/kv_manager.cpp
     ${CMAKE_CURRENT_LIST_DIR}/runner/kv_manager.h
+    ${EXECUTORCH_SOURCE_DIR}/examples/models/llama/runner/runner.cpp
+    ${EXECUTORCH_SOURCE_DIR}/examples/models/llama/runner/runner.h
 )

 list(APPEND _llama_runner__srcs)

examples/qualcomm/oss_scripts/llama/llama.py

Lines changed: 1 addition & 1 deletion

@@ -1038,7 +1038,7 @@ def _build_parser():
     parser.add_argument(
         "--model_mode",
         help="Export and inference kv mode, hybrid mode, or lookahead decoding mode",
-        default="kv",
+        default="hybrid",
         choices=["kv", "hybrid", "lookahead"],
         type=str,
     )

examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp

Lines changed: 14 additions & 12 deletions

@@ -16,6 +16,7 @@

 #include <executorch/backends/qualcomm/runtime/QnnExecuTorch.h>
 #include <executorch/examples/qualcomm/oss_scripts/llama/runner/runner.h>
+#include <executorch/extension/llm/runner/irunner.h>
 #include <executorch/runtime/platform/log.h>
 #include <gflags/gflags.h>
 #include <fstream>
@@ -61,7 +62,7 @@ DEFINE_int32(
     "Total number of tokens to generate (prompt + output).");
 DEFINE_int32(
     eval_mode,
-    0,
+    1,
     "0: TokenGenerator(kv) / 1: HybridMode (prefill+kv) / 2: Lookahead Decoding");
 DEFINE_string(
     kv_updater,
@@ -171,25 +172,26 @@ int main(int argc, char** argv) {
       buf.push_back(c);
     }
   };
-
+  executorch::extension::llm::GenerationConfig config{
+      true,
+      -1,
+      false,
+      FLAGS_seq_len,
+      static_cast<float>(FLAGS_temperature),
+      0,
+      0};
   if (use_tokenized_prompt) {
-    runner.generate(
-        FLAGS_tokenized_prompt.c_str(),
-        use_tokenized_prompt,
-        FLAGS_seq_len,
-        callback);
+    runner.generate_from_prompt_or_file(
+        FLAGS_tokenized_prompt.c_str(), use_tokenized_prompt, config, callback);
   } else {
     // generate tokens & store inference output
     for (int i = 0; i < FLAGS_num_iters; i++) {
       for (const auto& prompt : prompts) {
         std::string formatted_prompt;
         formatted_prompt = get_formatted_prompt(
             prompt, FLAGS_system_prompt, decoder_model_version.get());
-        runner.generate(
-            formatted_prompt.c_str(),
-            use_tokenized_prompt,
-            FLAGS_seq_len,
-            callback);
+        runner.generate_from_prompt_or_file(
+            formatted_prompt.c_str(), use_tokenized_prompt, config, callback);
       }
     }
   }
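The flag-driven runner now packs its generation options into a single GenerationConfig and hands it to generate_from_prompt_or_file(). The positional braced initializer above relies on member order; as a readability sketch (not from this commit), the same config could be written with C++20 designated initializers, setting only the members this commit demonstrably reads back (config.echo and config.seq_len) plus the sampling temperature. The field names and their declaration order are inferred from the positional initializer and may not match the struct exactly; unspecified members keep their defaults.

#include <executorch/extension/llm/runner/irunner.h>

#include <cstdint>

namespace llm = executorch::extension::llm;

// Hedged sketch: mirrors the positional {true, -1, false, seq_len, temperature, 0, 0}
// above, assuming the first, fourth, and fifth members are echo, seq_len, and
// temperature.
llm::GenerationConfig make_config(int32_t seq_len, double temperature) {
  return llm::GenerationConfig{
      .echo = true,
      .seq_len = seq_len,
      .temperature = static_cast<float>(temperature),
  };
}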

examples/qualcomm/oss_scripts/llama/runner/runner.cpp

Lines changed: 27 additions & 14 deletions

@@ -9,6 +9,7 @@
 // A llama 3.2 runner that includes preprocessing and post processing
 // logic. The module takes in a string as input and emits a string as output.

+#include <executorch/examples/models/llama/runner/runner.h>
 #include <executorch/examples/models/llama/tokenizer/llama_tiktoken.h>
 #include <executorch/examples/qualcomm/oss_scripts/llama/runner/client_mem.h>
 #include <executorch/examples/qualcomm/oss_scripts/llama/runner/lhd_token_generator.h>
@@ -59,7 +60,7 @@ void print_performance_report(
     outfile << num_tok;
     outfile.close();
   } else {
-    ET_CHECK_MSG(false, "Error saving the inference speed file");
+    ET_LOG(Error, "Error saving the inference speed file");
   }
 }

@@ -84,13 +85,6 @@ void save_logits(

 } // namespace

-std::unique_ptr<::tokenizers::Tokenizer> load_llama_tokenizer(
-    const std::string& tokenizer_path,
-    Version version) {
-  auto special_tokens = get_special_tokens(version);
-  return llm::load_tokenizer(tokenizer_path, std::move(special_tokens));
-}
-
 Runner::Runner(
     const std::string& decoder_model_version,
     const std::string& model_path,
@@ -177,7 +171,8 @@ Error Runner::load() {
     eos_ids->insert(tokenizer_->encode("<|eot|>", 0, 0).get()[0]);
     eos_ids->insert(tokenizer_->encode("<|end_of_text|>", 0, 0).get()[0]);
   } else {
-    tokenizer_ = load_llama_tokenizer(tokenizer_path_, Version::Default);
+    tokenizer_ =
+        example::load_llama_tokenizer(tokenizer_path_, Version::Default);
     if (tokenizer_ == nullptr) {
       ET_LOG(
           Error, "Failed to load tokenizer with %s", tokenizer_path_.c_str());
@@ -317,13 +312,30 @@
 }

 Error Runner::generate(
+    const std::string& prompt,
+    const llm::GenerationConfig& config,
+    std::function<void(const std::string&)> token_callback,
+    std::function<void(const Stats&)> stats_callback) {
+  return generate_from_pos(prompt, 0, config, token_callback, stats_callback);
+}
+
+Error Runner::generate_from_pos(
+    const std::string& prompt,
+    int64_t start_pos,
+    const llm::GenerationConfig& config,
+    std::function<void(const std::string&)> token_callback,
+    std::function<void(const Stats&)> stats_callback) {
+  // TODO: currently only support start_pos == 0
+  return generate_from_prompt_or_file(
+      prompt, false, config, token_callback, stats_callback);
+}
+
+Error Runner::generate_from_prompt_or_file(
     const std::string& prompt,
     bool tokenized_prompt,
-    int32_t seq_len,
+    const llm::GenerationConfig& config,
     std::function<void(const std::string&)> token_callback,
-    std::function<void(const Stats&)> stats_callback,
-    bool echo,
-    bool warming) {
+    std::function<void(const Stats&)> stats_callback) {
   ET_CHECK_MSG(!prompt.empty(), "prompt cannot be null");
   if (!is_loaded()) {
     stats_.model_load_start_ms = time_in_ms();
@@ -332,6 +344,7 @@ Error Runner::generate(
   }
   stats_.inference_start_ms = time_in_ms();

+  int32_t seq_len = config.seq_len;
   seq_len = (seq_len > 0 && seq_len <= context_len_) ? seq_len : context_len_;
   int32_t n_bos = (cur_pos_ == 0) ? 1 : 0;

@@ -370,7 +383,7 @@
       "sequence length exceeded - please increase the seq_len value");

   // Prompt Processor first
-  if (token_callback) {
+  if (token_callback && config.echo) {
     token_callback(prompt);
   }
   bool dump_logits = dump_logits_path_.empty() ? false : true;

examples/qualcomm/oss_scripts/llama/runner/runner.h

Lines changed: 21 additions & 9 deletions

@@ -21,6 +21,7 @@
 #include <executorch/examples/qualcomm/oss_scripts/llama/runner/kv_manager.h>
 #include <executorch/examples/qualcomm/oss_scripts/llama/runner/prompt_processor.h>
 #include <executorch/examples/qualcomm/oss_scripts/llama/runner/token_generator.h>
+#include <executorch/extension/llm/runner/irunner.h>
 #include <executorch/extension/llm/runner/stats.h>
 #include <executorch/extension/module/module.h>
 #include <pytorch/tokenizers/tokenizer.h>
@@ -33,7 +34,7 @@ enum DecoderModelVersion {
   kQwen2_5,
   kPhi4,
 };
-class Runner {
+class Runner : public executorch::extension::llm::IRunner {
  public:
  explicit Runner(
      const std::string& decoder_model,
@@ -42,25 +43,36 @@ class Runner {
      const std::string& performance_output_path,
      const std::string& dump_logits_path,
      const float temperature = 0.8f,
-      const int eval_mode = EvalMode::kKVCached,
+      const int eval_mode = EvalMode::kHybrid,
      const std::string& kv_updater = "SmartMask",
      const int ngram = 0,
      const int window = 0,
      const int gcap = 0,
      std::unique_ptr<tokenizers::Tokenizer> tokenizer = nullptr);

-  bool is_loaded() const;
-  executorch::runtime::Error load();
+  bool is_loaded() const override;
+  executorch::runtime::Error load() override;
   // TODO: Support echo and warming
   executorch::runtime::Error generate(
+      const std::string& prompt,
+      const executorch::extension::llm::GenerationConfig& config,
+      std::function<void(const std::string&)> token_callback = {},
+      std::function<void(const executorch::llm::Stats&)> stats_callback = {})
+      override;
+  executorch::runtime::Error generate_from_pos(
+      const std::string& prompt,
+      int64_t start_pos,
+      const executorch::extension::llm::GenerationConfig& config,
+      std::function<void(const std::string&)> token_callback = {},
+      std::function<void(const executorch::llm::Stats&)> stats_callback = {})
+      override;
+  executorch::runtime::Error generate_from_prompt_or_file(
      const std::string& prompt,
      bool tokenized_prompt,
-      int32_t seq_len,
+      const executorch::extension::llm::GenerationConfig& config,
      std::function<void(const std::string&)> token_callback = {},
-      std::function<void(const executorch::llm::Stats&)> stats_callback = {},
-      bool echo = true,
-      bool warming = false);
-  void stop() {};
+      std::function<void(const executorch::llm::Stats&)> stats_callback = {});
+  void stop() override {};
   executorch::runtime::Result<DecoderModelVersion> get_decoder_model_version();

  private:
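For call sites migrating from the previous Runner::generate(prompt, tokenized_prompt, seq_len, token_callback, stats_callback, echo, warming) signature, seq_len and echo now travel inside GenerationConfig, while warming has no direct replacement in this diff (the header keeps a TODO for echo and warming support). A hedged migration sketch, not taken from this commit:

#include <executorch/examples/qualcomm/oss_scripts/llama/runner/runner.h>
#include <executorch/extension/llm/runner/irunner.h>

#include <functional>
#include <string>
#include <utility>

namespace llm = executorch::extension::llm;

// Previously (old Runner API):
//   runner.generate(prompt, /*tokenized_prompt=*/false, /*seq_len=*/128,
//                   token_cb, stats_cb, /*echo=*/true, /*warming=*/false);
// With the IRunner-style signature, those options move into GenerationConfig.
executorch::runtime::Error run_once(
    example::Runner& runner,
    const std::string& prompt,
    std::function<void(const std::string&)> token_cb) {
  llm::GenerationConfig config;
  config.seq_len = 128;  // was the explicit seq_len argument
  config.echo = true;    // was the trailing echo flag
  return runner.generate(prompt, config, std::move(token_cb));
}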

examples/qualcomm/oss_scripts/llama/targets.bzl

Lines changed: 1 addition & 0 deletions

@@ -29,6 +29,7 @@ def define_common_targets():
         exported_deps = [
             "//executorch/extension/module:module",
             "//executorch/extension/llm/sampler:sampler",
+            "//executorch/examples/models/llama/runner:runner",
             "//executorch/examples/models/llama/tokenizer:tiktoken",
             "//executorch/extension/evalue_util:print_evalue",
             "//executorch/backends/qualcomm/runtime:runtime",

extension/android/CMakeLists.txt

Lines changed: 29 additions & 0 deletions

@@ -179,6 +179,35 @@ if(EXECUTORCH_BUILD_LLAMA_JNI)
     ${CMAKE_CURRENT_BINARY_DIR}/../../examples/models/llama/runner
   )

+  target_sources(
+    executorch_jni PRIVATE
+    ${EXECUTORCH_ROOT}/extension/llm/runner/llm_runner_helper.cpp
+  )
+
+  target_include_directories(
+    executorch_jni PRIVATE
+    ${EXECUTORCH_ROOT}/extension/llm/runner
+  )
+
+  if(QNN_SDK_ROOT)
+    target_sources(
+      executorch_jni PRIVATE
+      ${EXECUTORCH_ROOT}/examples/qualcomm/oss_scripts/llama/runner/runner.cpp
+      ${EXECUTORCH_ROOT}/examples/qualcomm/oss_scripts/llama/runner/decoder_runner.cpp
+      ${EXECUTORCH_ROOT}/examples/qualcomm/oss_scripts/llama/runner/prompt_processor.cpp
+      ${EXECUTORCH_ROOT}/examples/qualcomm/oss_scripts/llama/runner/token_generator.cpp
+      ${EXECUTORCH_ROOT}/examples/qualcomm/oss_scripts/llama/runner/lhd_token_generator.cpp
+      ${EXECUTORCH_ROOT}/examples/qualcomm/oss_scripts/llama/runner/rpc_mem.cpp
+      ${EXECUTORCH_ROOT}/examples/qualcomm/oss_scripts/llama/runner/kv_manager.cpp
+    )
+
+    target_include_directories(
+      executorch_jni PRIVATE
+      ${EXECUTORCH_ROOT}/examples/qualcomm/oss_scripts/llama/runner
+    )
+    target_compile_definitions(executorch_jni PRIVATE EXECUTORCH_BUILD_QNN=1)
+  endif()
+
   if(NEURON_BUFFER_ALLOCATOR_LIB)
     target_sources(
       executorch_jni

extension/android/jni/jni_layer_llama.cpp

Lines changed: 17 additions & 0 deletions

@@ -29,6 +29,10 @@
 #include <fbjni/ByteBuffer.h>
 #include <fbjni/fbjni.h>

+#if defined(EXECUTORCH_BUILD_QNN)
+#include <executorch/examples/qualcomm/oss_scripts/llama/runner/runner.h>
+#endif
+
 #if defined(EXECUTORCH_BUILD_MEDIATEK)
 #include <executorch/examples/mediatek/executor_runner/mtk_llama_runner.h>
 #endif
@@ -124,6 +128,7 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
   constexpr static int MODEL_TYPE_CATEGORY_LLM = 1;
   constexpr static int MODEL_TYPE_CATEGORY_MULTIMODAL = 2;
   constexpr static int MODEL_TYPE_MEDIATEK_LLAMA = 3;
+  constexpr static int MODEL_TYPE_QNN_LLAMA = 4;

   static facebook::jni::local_ref<jhybriddata> initHybrid(
       facebook::jni::alias_ref<jclass>,
@@ -174,6 +179,17 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
           model_path->toStdString(),
           tokenizer_path->toStdString(),
           data_path_str);
+#if defined(EXECUTORCH_BUILD_QNN)
+    } else if (model_type_category == MODEL_TYPE_QNN_LLAMA) {
+      std::string decoder_model = "llama3"; // use llama3 for now
+      runner_ = std::make_unique<example::Runner>( // QNN runner
+          decoder_model.c_str(),
+          model_path->toStdString().c_str(),
+          tokenizer_path->toStdString().c_str(),
+          data_path->toStdString().c_str(),
+          "");
+      model_type_category_ = MODEL_TYPE_CATEGORY_LLM;
+#endif
 #if defined(EXECUTORCH_BUILD_MEDIATEK)
     } else if (model_type_category == MODEL_TYPE_MEDIATEK_LLAMA) {
       runner_ = std::make_unique<MTKLlamaRunner>(
@@ -318,6 +334,7 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
         [callback](std::string result) { callback->onResult(result); },
         [callback](const llm::Stats& stats) { callback->onStats(stats); }));
     }
+    return static_cast<jint>(executorch::runtime::Error::InvalidArgument);
   }

   void stop() {
