
Commit 53146a4

QNN Llama Runner implement IRunner (#13171)
Summary: This PR makes the Runner for running Qualcomm Llama models implement the IRunner interface. Using this, static Llama models can now run inside the LlamaDemo Android app. The default eval mode is switched to hybrid everywhere.

Differential Revision: D79759817
1 parent cf4f3b9 commit 53146a4
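
For orientation, the sketch below shows the kind of call path this change enables: because example::Runner<T> now derives from executorch::extension::llm::IRunner, callers such as the Android JNI layer can drive it through the generic interface with a GenerationConfig. This is a minimal illustration, not code from the diff; the helper name run_prompt, the prompt handling, and the seq_len value are made up for the example, while the config fields echo and seq_len are the ones the updated runner.cpp actually reads.

// Minimal sketch of driving a runner through the llm::IRunner interface that
// example::Runner<T> implements after this change. Helper name and values are
// illustrative only.
#include <executorch/extension/llm/runner/irunner.h>
#include <executorch/extension/llm/runner/stats.h>
#include <executorch/runtime/core/error.h>

#include <iostream>
#include <string>

namespace llm = executorch::extension::llm;

executorch::runtime::Error run_prompt(llm::IRunner& runner, const std::string& prompt) {
  // Load once; both the QNN example Runner and the existing TextLLMRunner expose this.
  if (!runner.is_loaded()) {
    auto err = runner.load();
    if (err != executorch::runtime::Error::Ok) {
      return err;
    }
  }

  llm::GenerationConfig config;
  config.seq_len = 128;  // total tokens (prompt + generated); read by the QNN runner
  config.echo = true;    // echo the prompt through the token callback

  return runner.generate(
      prompt,
      config,
      [](const std::string& piece) { std::cout << piece; },  // stream decoded tokens
      [](const llm::Stats& stats) { (void)stats; });         // performance stats
}
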

File tree: 9 files changed, +127 -36 lines changed

examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ModelUtils.java

Lines changed: 5 additions & 0 deletions
@@ -21,6 +21,9 @@ public class ModelUtils {
   // MediaTek
   static final int MEDIATEK_TEXT_MODEL = 3;
 
+  // QNN static llama
+  static final int QNN_TEXT_MODEL = 4;
+
   public static int getModelCategory(ModelType modelType, BackendType backendType) {
     if (backendType.equals(BackendType.XNNPACK)) {
       switch (modelType) {
@@ -35,6 +38,8 @@ public static int getModelCategory(ModelType modelType, BackendType backendType)
       }
     } else if (backendType.equals(BackendType.MEDIATEK)) {
       return MEDIATEK_TEXT_MODEL;
+    } else if (backendType.equals(BackendType.QUALCOMM)) {
+      return QNN_TEXT_MODEL;
     }
 
     return TEXT_MODEL; // default

examples/models/llama/runner/runner.h

Lines changed: 3 additions & 1 deletion
@@ -18,6 +18,7 @@
 #include <string>
 #include <unordered_map>
 
+#include <executorch/examples/models/llama/tokenizer/llama_tiktoken.h>
 #include <executorch/extension/llm/runner/irunner.h>
 #include <executorch/extension/llm/runner/text_llm_runner.h>
 #include <pytorch/tokenizers/tokenizer.h>
@@ -33,6 +34,7 @@ std::unique_ptr<llm::TextLLMRunner> create_llama_runner(
     float temperature = -1.0f);
 
 std::unique_ptr<tokenizers::Tokenizer> load_llama_tokenizer(
-    const std::string& tokenizer_path);
+    const std::string& tokenizer_path,
+    Version version = Version::Default);
 
 } // namespace example

examples/qualcomm/oss_scripts/llama/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
@@ -42,6 +42,8 @@ list(
     ${CMAKE_CURRENT_LIST_DIR}/runner/rpc_mem.h
     ${CMAKE_CURRENT_LIST_DIR}/runner/kv_manager.cpp
     ${CMAKE_CURRENT_LIST_DIR}/runner/kv_manager.h
+    ${EXECUTORCH_SOURCE_DIR}/examples/models/llama/runner/runner.cpp
+    ${EXECUTORCH_SOURCE_DIR}/examples/models/llama/runner/runner.h
 )
 
 list(APPEND _llama_runner__srcs)

examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp

Lines changed: 14 additions & 12 deletions
@@ -16,6 +16,7 @@
 
 #include <executorch/backends/qualcomm/runtime/QnnExecuTorch.h>
 #include <executorch/examples/qualcomm/oss_scripts/llama/runner/runner.h>
+#include <executorch/extension/llm/runner/irunner.h>
 #include <executorch/runtime/platform/log.h>
 #include <gflags/gflags.h>
 #include <fstream>
@@ -61,7 +62,7 @@ DEFINE_int32(
     "Total number of tokens to generate (prompt + output).");
 DEFINE_int32(
     eval_mode,
-    0,
+    1,
     "0: TokenGenerator(kv) / 1: HybridMode (prefill+kv) / 2: Lookahead Decoding");
 DEFINE_string(
     kv_updater,
@@ -172,25 +173,26 @@ void start_runner(
       buf.push_back(c);
     }
   };
-
+  executorch::extension::llm::GenerationConfig config{
+      true,
+      -1,
+      false,
+      FLAGS_seq_len,
+      static_cast<float>(FLAGS_temperature),
+      0,
+      0};
   if (use_tokenized_prompt) {
-    runner.generate(
-        FLAGS_tokenized_prompt.c_str(),
-        use_tokenized_prompt,
-        FLAGS_seq_len,
-        callback);
+    runner.generate_from_prompt_or_file(
+        FLAGS_tokenized_prompt.c_str(), use_tokenized_prompt, config, callback);
   } else {
     // generate tokens & store inference output
     for (int i = 0; i < FLAGS_num_iters; i++) {
       for (const auto& prompt : prompts) {
         std::string formatted_prompt;
         formatted_prompt = get_formatted_prompt(
            prompt, FLAGS_system_prompt, decoder_model_version.get());
-        runner.generate(
-            formatted_prompt.c_str(),
-            use_tokenized_prompt,
-            FLAGS_seq_len,
-            callback);
+        runner.generate_from_prompt_or_file(
+            formatted_prompt.c_str(), use_tokenized_prompt, config, callback);
       }
     }
   }
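
The GenerationConfig in the new qnn_llama_runner.cpp code above is built with positional aggregate initialization, which hides what each value means. Below is a hedged sketch of the same construction with the presumed fields spelled out: only echo and seq_len appear in the runner.cpp changes in this commit; the other field names are assumptions about GenerationConfig in executorch/extension/llm/runner/irunner.h, and the make_config helper is hypothetical.

// Hypothetical helper: builds the same config as the brace-initializer above,
// with presumed field meanings spelled out. Field names other than echo and
// seq_len are assumptions, not confirmed by this diff.
#include <executorch/extension/llm/runner/irunner.h>

#include <cstdint>

namespace llm = executorch::extension::llm;

llm::GenerationConfig make_config(int32_t seq_len, double temperature) {
  llm::GenerationConfig config;
  config.echo = true;          // position 1: echo the prompt via the token callback
  config.max_new_tokens = -1;  // position 2: no explicit cap on new tokens (assumed name)
  config.warming = false;      // position 3: not a warm-up run (assumed name)
  config.seq_len = seq_len;    // position 4: total tokens, prompt + generated
  config.temperature = static_cast<float>(temperature);  // position 5 (assumed name)
  config.num_bos = 0;          // position 6 (assumed name)
  config.num_eos = 0;          // position 7 (assumed name)
  return config;
}

With such a helper, the call site above would read runner.generate_from_prompt_or_file(formatted_prompt.c_str(), use_tokenized_prompt, make_config(FLAGS_seq_len, FLAGS_temperature), callback).
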

examples/qualcomm/oss_scripts/llama/runner/runner.cpp

Lines changed: 29 additions & 14 deletions
@@ -9,6 +9,7 @@
 // A llama 3.2 runner that includes preprocessing and post processing
 // logic. The module takes in a string as input and emits a string as output.
 
+#include <executorch/examples/models/llama/runner/runner.h>
 #include <executorch/examples/models/llama/tokenizer/llama_tiktoken.h>
 #include <executorch/examples/qualcomm/oss_scripts/llama/runner/client_mem.h>
 #include <executorch/examples/qualcomm/oss_scripts/llama/runner/lhd_token_generator.h>
@@ -58,7 +59,7 @@ void print_performance_report(
     outfile << num_tok;
     outfile.close();
   } else {
-    ET_CHECK_MSG(false, "Error saving the inference speed file");
+    ET_LOG(Error, "Error saving the inference speed file");
   }
 }
 
@@ -83,13 +84,6 @@ void save_logits(
 
 } // namespace
 
-std::unique_ptr<::tokenizers::Tokenizer> load_llama_tokenizer(
-    const std::string& tokenizer_path,
-    Version version) {
-  auto special_tokens = get_special_tokens(version);
-  return llm::load_tokenizer(tokenizer_path, std::move(special_tokens));
-}
-
 template <typename T>
 Runner<T>::Runner(
     std::unique_ptr<executorch::extension::Module> module,
@@ -181,7 +175,8 @@ Error Runner<T>::load() {
     eos_ids->insert(tokenizer_->encode("<|eot|>", 0, 0).get()[0]);
     eos_ids->insert(tokenizer_->encode("<|end_of_text|>", 0, 0).get()[0]);
   } else {
-    tokenizer_ = load_llama_tokenizer(tokenizer_path_, Version::Default);
+    tokenizer_ =
+        example::load_llama_tokenizer(tokenizer_path_, Version::Default);
     if (tokenizer_ == nullptr) {
       ET_LOG(
           Error, "Failed to load tokenizer with %s", tokenizer_path_.c_str());
@@ -323,13 +318,32 @@ Error Runner<T>::load() {
 
 template <typename T>
 Error Runner<T>::generate(
+    const std::string& prompt,
+    const llm::GenerationConfig& config,
+    std::function<void(const std::string&)> token_callback,
+    std::function<void(const Stats&)> stats_callback) {
+  return generate_from_pos(prompt, 0, config, token_callback, stats_callback);
+}
+
+template <typename T>
+Error Runner<T>::generate_from_pos(
+    const std::string& prompt,
+    int64_t start_pos,
+    const llm::GenerationConfig& config,
+    std::function<void(const std::string&)> token_callback,
+    std::function<void(const Stats&)> stats_callback) {
+  // TODO: currently only support start_pos == 0
+  return generate_from_prompt_or_file(
+      prompt, false, config, token_callback, stats_callback);
+}
+
+template <typename T>
+Error Runner<T>::generate_from_prompt_or_file(
     const std::string& prompt,
     bool tokenized_prompt,
-    int32_t seq_len,
+    const llm::GenerationConfig& config,
     std::function<void(const std::string&)> token_callback,
-    std::function<void(const Stats&)> stats_callback,
-    bool echo,
-    bool warming) {
+    std::function<void(const Stats&)> stats_callback) {
   ET_CHECK_MSG(!prompt.empty(), "prompt cannot be null");
   if (!is_loaded()) {
     stats_.model_load_start_ms = time_in_ms();
@@ -338,6 +352,7 @@ Error Runner<T>::generate(
   }
   stats_.inference_start_ms = time_in_ms();
 
+  int32_t seq_len = config.seq_len;
   seq_len = (seq_len > 0 && seq_len <= context_len_) ? seq_len : context_len_;
   int32_t n_bos = (cur_pos_ == 0) ? 1 : 0;
 
@@ -376,7 +391,7 @@
       "sequence length exceeded - please increase the seq_len value");
 
   // Prompt Processor first
-  if (token_callback) {
+  if (token_callback && config.echo) {
     token_callback(prompt);
   }
   bool dump_logits = dump_logits_path_.empty() ? false : true;

examples/qualcomm/oss_scripts/llama/runner/runner.h

Lines changed: 21 additions & 9 deletions
@@ -21,6 +21,7 @@
 #include <executorch/examples/qualcomm/oss_scripts/llama/runner/kv_manager.h>
 #include <executorch/examples/qualcomm/oss_scripts/llama/runner/prompt_processor.h>
 #include <executorch/examples/qualcomm/oss_scripts/llama/runner/token_generator.h>
+#include <executorch/extension/llm/runner/irunner.h>
 #include <executorch/extension/llm/runner/stats.h>
 #include <executorch/extension/module/module.h>
 #include <pytorch/tokenizers/tokenizer.h>
@@ -41,7 +42,7 @@ enum KvBitWidth {
 };
 
 template <typename T>
-class Runner {
+class Runner : public executorch::extension::llm::IRunner {
  public:
  explicit Runner(
      std::unique_ptr<executorch::extension::Module> module,
@@ -51,25 +52,36 @@ class Runner {
      const std::string& performance_output_path,
      const std::string& dump_logits_path,
      const float temperature = 0.8f,
-      const int eval_mode = EvalMode::kKVCached,
+      const int eval_mode = EvalMode::kHybrid,
      const std::string& kv_updater = "SmartMask",
      const int ngram = 0,
      const int window = 0,
      const int gcap = 0,
      std::unique_ptr<tokenizers::Tokenizer> tokenizer = nullptr);
 
-  bool is_loaded() const;
-  executorch::runtime::Error load();
+  bool is_loaded() const override;
+  executorch::runtime::Error load() override;
   // TODO: Support echo and warming
   executorch::runtime::Error generate(
+      const std::string& prompt,
+      const executorch::extension::llm::GenerationConfig& config,
+      std::function<void(const std::string&)> token_callback = {},
+      std::function<void(const executorch::llm::Stats&)> stats_callback = {})
+      override;
+  executorch::runtime::Error generate_from_pos(
+      const std::string& prompt,
+      int64_t start_pos,
+      const executorch::extension::llm::GenerationConfig& config,
+      std::function<void(const std::string&)> token_callback = {},
+      std::function<void(const executorch::llm::Stats&)> stats_callback = {})
+      override;
+  executorch::runtime::Error generate_from_prompt_or_file(
      const std::string& prompt,
      bool tokenized_prompt,
-      int32_t seq_len,
+      const executorch::extension::llm::GenerationConfig& config,
      std::function<void(const std::string&)> token_callback = {},
-      std::function<void(const executorch::llm::Stats&)> stats_callback = {},
-      bool echo = true,
-      bool warming = false);
-  void stop() {};
+      std::function<void(const executorch::llm::Stats&)> stats_callback = {});
+  void stop() override {};
  executorch::runtime::Result<DecoderModelVersion> get_decoder_model_version();
 
 private:

examples/qualcomm/oss_scripts/llama/targets.bzl

Lines changed: 1 addition & 0 deletions
@@ -29,6 +29,7 @@ def define_common_targets():
         exported_deps = [
             "//executorch/extension/module:module",
             "//executorch/extension/llm/sampler:sampler",
+            "//executorch/examples/models/llama/runner:runner",
             "//executorch/examples/models/llama/tokenizer:tiktoken",
             "//executorch/extension/evalue_util:print_evalue",
             "//executorch/backends/qualcomm/runtime:runtime",

extension/android/CMakeLists.txt

Lines changed: 29 additions & 0 deletions
@@ -179,6 +179,35 @@ if(EXECUTORCH_BUILD_LLAMA_JNI)
     ${CMAKE_CURRENT_BINARY_DIR}/../../examples/models/llama/runner
   )
 
+  target_sources(
+    executorch_jni
+    PRIVATE ${EXECUTORCH_ROOT}/extension/llm/runner/llm_runner_helper.cpp
+  )
+
+  target_include_directories(
+    executorch_jni PRIVATE ${EXECUTORCH_ROOT}/extension/llm/runner
+  )
+
+  if(QNN_SDK_ROOT)
+    target_sources(
+      executorch_jni
+      PRIVATE
+        ${EXECUTORCH_ROOT}/examples/qualcomm/oss_scripts/llama/runner/runner.cpp
+        ${EXECUTORCH_ROOT}/examples/qualcomm/oss_scripts/llama/runner/decoder_runner.cpp
+        ${EXECUTORCH_ROOT}/examples/qualcomm/oss_scripts/llama/runner/prompt_processor.cpp
+        ${EXECUTORCH_ROOT}/examples/qualcomm/oss_scripts/llama/runner/token_generator.cpp
+        ${EXECUTORCH_ROOT}/examples/qualcomm/oss_scripts/llama/runner/lhd_token_generator.cpp
+        ${EXECUTORCH_ROOT}/examples/qualcomm/oss_scripts/llama/runner/rpc_mem.cpp
+        ${EXECUTORCH_ROOT}/examples/qualcomm/oss_scripts/llama/runner/kv_manager.cpp
+    )
+
+    target_include_directories(
+      executorch_jni
+      PRIVATE ${EXECUTORCH_ROOT}/examples/qualcomm/oss_scripts/llama/runner
+    )
+    target_compile_definitions(executorch_jni PRIVATE EXECUTORCH_BUILD_QNN=1)
+  endif()
+
   if(NEURON_BUFFER_ALLOCATOR_LIB)
     target_sources(
       executorch_jni

extension/android/jni/jni_layer_llama.cpp

Lines changed: 23 additions & 0 deletions
@@ -15,6 +15,7 @@
 
 #include <executorch/examples/models/llama/runner/runner.h>
 #include <executorch/examples/models/llava/runner/llava_runner.h>
+#include <executorch/examples/qualcomm/oss_scripts/llama/runner/runner.h>
 #include <executorch/extension/llm/runner/image.h>
 #include <executorch/extension/llm/runner/irunner.h>
 #include <executorch/runtime/platform/log.h>
@@ -29,6 +30,10 @@
 #include <fbjni/ByteBuffer.h>
 #include <fbjni/fbjni.h>
 
+#if defined(EXECUTORCH_BUILD_QNN)
+#include <executorch/examples/qualcomm/oss_scripts/llama/runner/runner.h>
+#endif
+
 #if defined(EXECUTORCH_BUILD_MEDIATEK)
 #include <executorch/examples/mediatek/executor_runner/mtk_llama_runner.h>
 #endif
@@ -124,6 +129,7 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
   constexpr static int MODEL_TYPE_CATEGORY_LLM = 1;
   constexpr static int MODEL_TYPE_CATEGORY_MULTIMODAL = 2;
   constexpr static int MODEL_TYPE_MEDIATEK_LLAMA = 3;
+  constexpr static int MODEL_TYPE_QNN_LLAMA = 4;
 
   static facebook::jni::local_ref<jhybriddata> initHybrid(
       facebook::jni::alias_ref<jclass>,
@@ -174,6 +180,22 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
           model_path->toStdString(),
           tokenizer_path->toStdString(),
           data_path_str);
+#if defined(EXECUTORCH_BUILD_QNN)
+    } else if (model_type_category == MODEL_TYPE_QNN_LLAMA) {
+      std::unique_ptr<executorch::extension::Module> module = std::make_unique<
+          executorch::extension::Module>(
+          model_path->toStdString().c_str(),
+          executorch::extension::Module::LoadMode::MmapUseMlockIgnoreErrors);
+      std::string decoder_model = "llama3"; // use llama3 for now
+      runner_ = std::make_unique<example::Runner<uint16_t>>( // QNN runner
+          std::move(module),
+          decoder_model.c_str(),
+          model_path->toStdString().c_str(),
+          tokenizer_path->toStdString().c_str(),
+          data_path->toStdString().c_str(),
+          "");
+      model_type_category_ = MODEL_TYPE_CATEGORY_LLM;
+#endif
 #if defined(EXECUTORCH_BUILD_MEDIATEK)
     } else if (model_type_category == MODEL_TYPE_MEDIATEK_LLAMA) {
       runner_ = std::make_unique<MTKLlamaRunner>(
@@ -318,6 +340,7 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
           [callback](std::string result) { callback->onResult(result); },
          [callback](const llm::Stats& stats) { callback->onStats(stats); }));
    }
+    return static_cast<jint>(executorch::runtime::Error::InvalidArgument);
  }
 
  void stop() {
