
Commit 5b377a2

rohansjoshi authored and facebook-github-bot committed
QNN Llama Runner implement IRunner (#13171)
Summary: This PR makes the runner for Qualcomm Llama models implement the IRunner interface. Using this, it enables running static Llama models inside the LlamaDemo Android app. It also switches the default eval mode to hybrid everywhere.

Reviewed By: cccclai

Differential Revision: D79759817
1 parent bba378c commit 5b377a2

File tree

10 files changed (+125, -37 lines)
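For orientation, here is a minimal sketch (not part of the diff) of what the IRunner change means for callers: code that already drives a TextLLMRunner through the llm::IRunner interface can now drive the QNN static-Llama runner the same way. The prompt, seq_len value, and the helper function name are illustrative; the sketch assumes GenerationConfig's remaining members carry in-class defaults as declared in irunner.h.

#include <executorch/extension/llm/runner/irunner.h>
#include <string>

namespace llm = executorch::extension::llm;

// Hypothetical helper: runs any IRunner-conforming runner with a fixed prompt.
executorch::runtime::Error run(llm::IRunner& runner) {
  llm::GenerationConfig config;  // other fields keep their declared defaults (assumption)
  config.seq_len = 128;          // total tokens: prompt + generated
  config.echo = true;            // echo the prompt through the token callback

  auto err = runner.load();
  if (err != executorch::runtime::Error::Ok) {
    return err;
  }
  return runner.generate(
      "What is the capital of France?",
      config,
      [](const std::string& piece) { /* stream generated text */ },
      {} /* no stats callback */);
}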

examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ModelUtils.java

Lines changed: 5 additions & 0 deletions
@@ -21,6 +21,9 @@ public class ModelUtils {
   // MediaTek
   static final int MEDIATEK_TEXT_MODEL = 3;

+  // QNN static llama
+  static final int QNN_TEXT_MODEL = 4;
+
   public static int getModelCategory(ModelType modelType, BackendType backendType) {
     if (backendType.equals(BackendType.XNNPACK)) {
       switch (modelType) {
@@ -35,6 +38,8 @@ public static int getModelCategory(ModelType modelType, BackendType backendType)
       }
     } else if (backendType.equals(BackendType.MEDIATEK)) {
       return MEDIATEK_TEXT_MODEL;
+    } else if (backendType.equals(BackendType.QUALCOMM)) {
+      return QNN_TEXT_MODEL;
     }

     return TEXT_MODEL; // default

examples/models/llama/runner/runner.h

Lines changed: 3 additions & 1 deletion
@@ -18,6 +18,7 @@
 #include <string>
 #include <unordered_map>

+#include <executorch/examples/models/llama/tokenizer/llama_tiktoken.h>
 #include <executorch/extension/llm/runner/irunner.h>
 #include <executorch/extension/llm/runner/text_llm_runner.h>
 #include <pytorch/tokenizers/tokenizer.h>
@@ -33,6 +34,7 @@ std::unique_ptr<llm::TextLLMRunner> create_llama_runner(
     float temperature = -1.0f);

 std::unique_ptr<tokenizers::Tokenizer> load_llama_tokenizer(
-    const std::string& tokenizer_path);
+    const std::string& tokenizer_path,
+    Version version = Version::Default);

 } // namespace example
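With the extra parameter, call sites outside this file (such as the QNN runner below) can pick a special-token set explicitly; omitting the argument keeps the old behavior. A hypothetical call site, assuming Version is the enum exported by llama_tiktoken.h in the example namespace and with an illustrative tokenizer path:

// Sketch only; path and error handling are placeholders.
auto tokenizer = example::load_llama_tokenizer(
    "/data/local/tmp/llama/tokenizer.model", example::Version::Default);
if (tokenizer == nullptr) {
  // handle tokenizer load failure
}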

examples/qualcomm/oss_scripts/llama/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
@@ -42,6 +42,8 @@ list(
     ${CMAKE_CURRENT_LIST_DIR}/runner/rpc_mem.h
     ${CMAKE_CURRENT_LIST_DIR}/runner/kv_manager.cpp
     ${CMAKE_CURRENT_LIST_DIR}/runner/kv_manager.h
+    ${EXECUTORCH_SOURCE_DIR}/examples/models/llama/runner/runner.cpp
+    ${EXECUTORCH_SOURCE_DIR}/examples/models/llama/runner/runner.h
 )

 list(APPEND _llama_runner__srcs)

examples/qualcomm/oss_scripts/llama/llama.py

Lines changed: 1 addition & 1 deletion
@@ -1040,7 +1040,7 @@ def _build_parser():
     parser.add_argument(
         "--model_mode",
         help="Export and inference kv mode, hybrid mode, or lookahead decoding mode",
-        default="kv",
+        default="hybrid",
         choices=["kv", "hybrid", "lookahead"],
         type=str,
     )

examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp

Lines changed: 14 additions & 12 deletions
@@ -16,6 +16,7 @@

 #include <executorch/backends/qualcomm/runtime/QnnExecuTorch.h>
 #include <executorch/examples/qualcomm/oss_scripts/llama/runner/runner.h>
+#include <executorch/extension/llm/runner/irunner.h>
 #include <executorch/runtime/platform/log.h>
 #include <gflags/gflags.h>
 #include <fstream>
@@ -61,7 +62,7 @@ DEFINE_int32(
     "Total number of tokens to generate (prompt + output).");
 DEFINE_int32(
     eval_mode,
-    0,
+    1,
     "0: TokenGenerator(kv) / 1: HybridMode (prefill+kv) / 2: Lookahead Decoding");
 DEFINE_string(
     kv_updater,
@@ -163,25 +164,26 @@ void start_runner(
       buf.push_back(c);
     }
   };
-
+  executorch::extension::llm::GenerationConfig config{
+      true,
+      -1,
+      false,
+      FLAGS_seq_len,
+      static_cast<float>(FLAGS_temperature),
+      0,
+      0};
   if (use_tokenized_prompt) {
-    runner.generate(
-        FLAGS_tokenized_prompt.c_str(),
-        use_tokenized_prompt,
-        FLAGS_seq_len,
-        callback);
+    runner.generate_from_prompt_or_file(
+        FLAGS_tokenized_prompt.c_str(), use_tokenized_prompt, config, callback);
   } else {
     // generate tokens & store inference output
     for (int i = 0; i < FLAGS_num_iters; i++) {
       for (const auto& prompt : prompts) {
         std::string formatted_prompt;
         formatted_prompt = get_formatted_prompt(
             prompt, FLAGS_system_prompt, decoder_model_version.get());
-        runner.generate(
-            formatted_prompt.c_str(),
-            use_tokenized_prompt,
-            FLAGS_seq_len,
-            callback);
+        runner.generate_from_prompt_or_file(
+            formatted_prompt.c_str(), use_tokenized_prompt, config, callback);
       }
     }
   }
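The brace-initialized config above is positional, so the meaning of each literal depends on GenerationConfig's declaration order. As a reading aid, here is the same construction with the fields annotated; the names assume the order currently declared in extension/llm/runner/irunner.h and should be checked against that header rather than taken as part of this commit.

// Annotated sketch of the config built in start_runner(); field names are
// assumptions about GenerationConfig's declaration order, not part of the diff.
executorch::extension::llm::GenerationConfig config{
    /*echo=*/true,
    /*max_new_tokens=*/-1,
    /*warming=*/false,
    /*seq_len=*/FLAGS_seq_len,
    /*temperature=*/static_cast<float>(FLAGS_temperature),
    /*num_bos=*/0,
    /*num_eos=*/0};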

examples/qualcomm/oss_scripts/llama/runner/runner.cpp

Lines changed: 27 additions & 14 deletions
@@ -9,6 +9,7 @@
 // A llama 3.2 runner that includes preprocessing and post processing
 // logic. The module takes in a string as input and emits a string as output.

+#include <executorch/examples/models/llama/runner/runner.h>
 #include <executorch/examples/models/llama/tokenizer/llama_tiktoken.h>
 #include <executorch/examples/qualcomm/oss_scripts/llama/runner/client_mem.h>
 #include <executorch/examples/qualcomm/oss_scripts/llama/runner/lhd_token_generator.h>
@@ -58,7 +59,7 @@ void print_performance_report(
     outfile << num_tok;
     outfile.close();
   } else {
-    ET_CHECK_MSG(false, "Error saving the inference speed file");
+    ET_LOG(Error, "Error saving the inference speed file");
   }
 }

@@ -83,13 +84,6 @@ void save_logits(

 } // namespace

-std::unique_ptr<::tokenizers::Tokenizer> load_llama_tokenizer(
-    const std::string& tokenizer_path,
-    Version version) {
-  auto special_tokens = get_special_tokens(version);
-  return llm::load_tokenizer(tokenizer_path, std::move(special_tokens));
-}
-
 template <typename T>
 Runner<T>::Runner(
     std::unique_ptr<executorch::extension::Module> module,
@@ -179,7 +173,8 @@ Error Runner<T>::load() {
     eos_ids->insert(tokenizer_->encode("<|eot|>", 0, 0).get()[0]);
     eos_ids->insert(tokenizer_->encode("<|end_of_text|>", 0, 0).get()[0]);
   } else {
-    tokenizer_ = load_llama_tokenizer(tokenizer_path_, Version::Default);
+    tokenizer_ =
+        example::load_llama_tokenizer(tokenizer_path_, Version::Default);
     if (tokenizer_ == nullptr) {
       ET_LOG(
           Error, "Failed to load tokenizer with %s", tokenizer_path_.c_str());
@@ -321,13 +316,30 @@ Error Runner<T>::load() {

 template <typename T>
 Error Runner<T>::generate(
+    const std::string& prompt,
+    const llm::GenerationConfig& config,
+    std::function<void(const std::string&)> token_callback,
+    std::function<void(const Stats&)> stats_callback) {
+  return generate_from_pos(prompt, 0, config, token_callback, stats_callback);
+}
+
+template <typename T>
+Error Runner<T>::generate_from_pos(
+    const std::string& prompt,
+    int64_t start_pos,
+    const llm::GenerationConfig& config,
+    std::function<void(const std::string&)> token_callback,
+    std::function<void(const Stats&)> stats_callback) {
+  // TODO: currently only support start_pos == 0
+  return generate_from_prompt_or_file(
+      prompt, false, config, token_callback, stats_callback);
+}
+
+template <typename T>
+Error Runner<T>::generate_from_prompt_or_file(
     const std::string& prompt,
     bool tokenized_prompt,
-    int32_t seq_len,
+    const llm::GenerationConfig& config,
     std::function<void(const std::string&)> token_callback,
-    std::function<void(const Stats&)> stats_callback,
-    bool echo,
-    bool warming) {
+    std::function<void(const Stats&)> stats_callback) {
   ET_CHECK_MSG(!prompt.empty(), "prompt cannot be null");
   if (!is_loaded()) {
     stats_.model_load_start_ms = time_in_ms();
@@ -336,6 +348,7 @@ Error Runner<T>::generate(
   }
   stats_.inference_start_ms = time_in_ms();

+  int32_t seq_len = config.seq_len;
   seq_len = (seq_len > 0 && seq_len <= context_len_) ? seq_len : context_len_;
   int32_t n_bos = (cur_pos_ == 0) ? 1 : 0;

@@ -374,7 +387,7 @@ Error Runner<T>::generate(
       "sequence length exceeded - please increase the seq_len value");

   // Prompt Processor first
-  if (token_callback) {
+  if (token_callback && config.echo) {
     token_callback(prompt);
   }
   bool dump_logits = dump_logits_path_.empty() ? false : true;

examples/qualcomm/oss_scripts/llama/runner/runner.h

Lines changed: 21 additions & 9 deletions
@@ -21,6 +21,7 @@
 #include <executorch/examples/qualcomm/oss_scripts/llama/runner/kv_manager.h>
 #include <executorch/examples/qualcomm/oss_scripts/llama/runner/prompt_processor.h>
 #include <executorch/examples/qualcomm/oss_scripts/llama/runner/token_generator.h>
+#include <executorch/extension/llm/runner/irunner.h>
 #include <executorch/extension/llm/runner/stats.h>
 #include <executorch/extension/module/module.h>
 #include <pytorch/tokenizers/tokenizer.h>
@@ -40,7 +41,7 @@ enum KvBitWidth {
 };

 template <typename T>
-class Runner {
+class Runner : public executorch::extension::llm::IRunner {
  public:
  explicit Runner(
      std::unique_ptr<executorch::extension::Module> module,
@@ -50,25 +51,36 @@ class Runner {
      const std::string& performance_output_path,
      const std::string& dump_logits_path,
      const float temperature = 0.8f,
-      const int eval_mode = EvalMode::kKVCached,
+      const int eval_mode = EvalMode::kHybrid,
      const std::string& kv_updater = "SmartMask",
      const int ngram = 0,
      const int window = 0,
      const int gcap = 0,
      std::unique_ptr<tokenizers::Tokenizer> tokenizer = nullptr);

-  bool is_loaded() const;
-  executorch::runtime::Error load();
+  bool is_loaded() const override;
+  executorch::runtime::Error load() override;
   // TODO: Support echo and warming
   executorch::runtime::Error generate(
+      const std::string& prompt,
+      const executorch::extension::llm::GenerationConfig& config,
+      std::function<void(const std::string&)> token_callback = {},
+      std::function<void(const executorch::llm::Stats&)> stats_callback = {})
+      override;
+  executorch::runtime::Error generate_from_pos(
+      const std::string& prompt,
+      int64_t start_pos,
+      const executorch::extension::llm::GenerationConfig& config,
+      std::function<void(const std::string&)> token_callback = {},
+      std::function<void(const executorch::llm::Stats&)> stats_callback = {})
+      override;
+  executorch::runtime::Error generate_from_prompt_or_file(
       const std::string& prompt,
       bool tokenized_prompt,
-      int32_t seq_len,
+      const executorch::extension::llm::GenerationConfig& config,
      std::function<void(const std::string&)> token_callback = {},
-      std::function<void(const executorch::llm::Stats&)> stats_callback = {},
-      bool echo = true,
-      bool warming = false);
-  void stop() {};
+      std::function<void(const executorch::llm::Stats&)> stats_callback = {});
+  void stop() override {};
   executorch::runtime::Result<DecoderModelVersion> get_decoder_model_version();

  private:
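For out-of-tree callers of the old interface, the migration is mechanical: seq_len and echo move into a GenerationConfig, and warming has no direct replacement yet (the TODO above still applies). A hedged before/after sketch follows; the function name, prompt, seq_len, and callback are placeholders, not part of this commit.

#include <executorch/examples/qualcomm/oss_scripts/llama/runner/runner.h>
#include <executorch/extension/llm/runner/irunner.h>
#include <cstdint>
#include <string>

// Sketch of migrating an old call site to the new signature.
void migrate_call(
    example::Runner<uint16_t>& runner,
    const std::string& prompt,
    int32_t seq_len) {
  // Old: runner.generate(prompt, /*tokenized_prompt=*/false, seq_len, cb);
  executorch::extension::llm::GenerationConfig config;
  config.seq_len = seq_len;  // was the explicit seq_len argument
  config.echo = true;        // was the trailing echo flag (default true)
  runner.generate_from_prompt_or_file(
      prompt,
      /*tokenized_prompt=*/false,
      config,
      [](const std::string& piece) { /* stream generated text */ });
}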

examples/qualcomm/oss_scripts/llama/targets.bzl

Lines changed: 1 addition & 0 deletions
@@ -29,6 +29,7 @@ def define_common_targets():
         exported_deps = [
             "//executorch/extension/module:module",
             "//executorch/extension/llm/sampler:sampler",
+            "//executorch/examples/models/llama/runner:runner",
             "//executorch/examples/models/llama/tokenizer:tiktoken",
             "//executorch/extension/evalue_util:print_evalue",
             "//executorch/backends/qualcomm/runtime:runtime",

extension/android/CMakeLists.txt

Lines changed: 29 additions & 0 deletions
@@ -179,6 +179,35 @@ if(EXECUTORCH_BUILD_LLAMA_JNI)
     ${CMAKE_CURRENT_BINARY_DIR}/../../examples/models/llama/runner
   )

+  target_sources(
+    executorch_jni PRIVATE
+    ${EXECUTORCH_ROOT}/extension/llm/runner/llm_runner_helper.cpp
+  )
+
+  target_include_directories(
+    executorch_jni PRIVATE
+    ${EXECUTORCH_ROOT}/extension/llm/runner
+  )
+
+  if(QNN_SDK_ROOT)
+    target_sources(
+      executorch_jni PRIVATE
+      ${EXECUTORCH_ROOT}/examples/qualcomm/oss_scripts/llama/runner/runner.cpp
+      ${EXECUTORCH_ROOT}/examples/qualcomm/oss_scripts/llama/runner/decoder_runner.cpp
+      ${EXECUTORCH_ROOT}/examples/qualcomm/oss_scripts/llama/runner/prompt_processor.cpp
+      ${EXECUTORCH_ROOT}/examples/qualcomm/oss_scripts/llama/runner/token_generator.cpp
+      ${EXECUTORCH_ROOT}/examples/qualcomm/oss_scripts/llama/runner/lhd_token_generator.cpp
+      ${EXECUTORCH_ROOT}/examples/qualcomm/oss_scripts/llama/runner/rpc_mem.cpp
+      ${EXECUTORCH_ROOT}/examples/qualcomm/oss_scripts/llama/runner/kv_manager.cpp
+    )
+
+    target_include_directories(
+      executorch_jni PRIVATE
+      ${EXECUTORCH_ROOT}/examples/qualcomm/oss_scripts/llama/runner
+    )
+    target_compile_definitions(executorch_jni PRIVATE EXECUTORCH_BUILD_QNN=1)
+  endif()
+
   if(NEURON_BUFFER_ALLOCATOR_LIB)
     target_sources(
       executorch_jni

extension/android/jni/jni_layer_llama.cpp

Lines changed: 22 additions & 0 deletions
@@ -29,6 +29,10 @@
 #include <fbjni/ByteBuffer.h>
 #include <fbjni/fbjni.h>

+#if defined(EXECUTORCH_BUILD_QNN)
+#include <executorch/examples/qualcomm/oss_scripts/llama/runner/runner.h>
+#endif
+
 #if defined(EXECUTORCH_BUILD_MEDIATEK)
 #include <executorch/examples/mediatek/executor_runner/mtk_llama_runner.h>
 #endif
@@ -124,6 +128,7 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
   constexpr static int MODEL_TYPE_CATEGORY_LLM = 1;
   constexpr static int MODEL_TYPE_CATEGORY_MULTIMODAL = 2;
   constexpr static int MODEL_TYPE_MEDIATEK_LLAMA = 3;
+  constexpr static int MODEL_TYPE_QNN_LLAMA = 4;

   static facebook::jni::local_ref<jhybriddata> initHybrid(
       facebook::jni::alias_ref<jclass>,
@@ -174,6 +179,22 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
           model_path->toStdString(),
           tokenizer_path->toStdString(),
           data_path_str);
+#if defined(EXECUTORCH_BUILD_QNN)
+    } else if (model_type_category == MODEL_TYPE_QNN_LLAMA) {
+      std::unique_ptr<executorch::extension::Module> module =
+          std::make_unique<executorch::extension::Module>(
+              FLAGS_model_path.c_str(),
+              executorch::extension::Module::LoadMode::MmapUseMlockIgnoreErrors);
+      std::string decoder_model = "llama3"; // use llama3 for now
+      runner_ = std::make_unique<example::Runner<uint16_t>>( // QNN runner
+          std::move(module),
+          decoder_model.c_str(),
+          model_path->toStdString().c_str(),
+          tokenizer_path->toStdString().c_str(),
+          data_path->toStdString().c_str(),
+          "");
+      model_type_category_ = MODEL_TYPE_CATEGORY_LLM;
+#endif
 #if defined(EXECUTORCH_BUILD_MEDIATEK)
     } else if (model_type_category == MODEL_TYPE_MEDIATEK_LLAMA) {
       runner_ = std::make_unique<MTKLlamaRunner>(
@@ -318,6 +339,7 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
         [callback](std::string result) { callback->onResult(result); },
         [callback](const llm::Stats& stats) { callback->onStats(stats); }));
   }
+  return static_cast<jint>(executorch::runtime::Error::InvalidArgument);
 }

 void stop() {
