
Commit 991dd98

rohansjoshi authored and facebook-github-bot committed
QNN Llama Runner implement IRunner (#13171)
Summary: This PR makes the Runner for running Qualcomm Llama models implement the IRunner interface. Using this, it enables running static Llama models inside the LlamaDemo Android app. It also switches the default eval mode to hybrid everywhere.

Reviewed By: cccclai

Differential Revision: D79759817
1 parent 4ce7078 commit 991dd98

File tree: 10 files changed, +102 -37 lines changed


examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ModelUtils.java

Lines changed: 5 additions & 0 deletions
@@ -21,6 +21,9 @@ public class ModelUtils {
   // MediaTek
   static final int MEDIATEK_TEXT_MODEL = 3;
 
+  // QNN static llama
+  static final int QNN_TEXT_MODEL = 4;
+
   public static int getModelCategory(ModelType modelType, BackendType backendType) {
     if (backendType.equals(BackendType.XNNPACK)) {
       switch (modelType) {
@@ -35,6 +38,8 @@ public static int getModelCategory(ModelType modelType, BackendType backendType)
       }
     } else if (backendType.equals(BackendType.MEDIATEK)) {
       return MEDIATEK_TEXT_MODEL;
+    } else if (backendType.equals(BackendType.QUALCOMM)) {
+      return QNN_TEXT_MODEL;
     }
 
     return TEXT_MODEL; // default

examples/models/llama/runner/runner.h

Lines changed: 3 additions & 1 deletion
@@ -18,6 +18,7 @@
 #include <string>
 #include <unordered_map>
 
+#include <executorch/examples/models/llama/tokenizer/llama_tiktoken.h>
 #include <executorch/extension/llm/runner/irunner.h>
 #include <executorch/extension/llm/runner/text_llm_runner.h>
 #include <pytorch/tokenizers/tokenizer.h>
@@ -33,6 +34,7 @@ std::unique_ptr<llm::TextLLMRunner> create_llama_runner(
     float temperature = -1.0f);
 
 std::unique_ptr<tokenizers::Tokenizer> load_llama_tokenizer(
-    const std::string& tokenizer_path);
+    const std::string& tokenizer_path,
+    Version version = Version::Default);
 
 } // namespace example
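
For reference, a minimal sketch of the widened load_llama_tokenizer signature (placeholder paths; Version is the enum from llama_tiktoken.h and is assumed here to live in the example namespace alongside this declaration):

#include <executorch/examples/models/llama/runner/runner.h>

// Both calls are equivalent now that the second parameter defaults to
// Version::Default; the explicit form matches what the Qualcomm runner
// below passes.
void load_tokenizer_sketch() {
  auto tok_default = example::load_llama_tokenizer("/path/to/tokenizer.model");
  auto tok_explicit = example::load_llama_tokenizer(
      "/path/to/tokenizer.model", example::Version::Default);
  (void)tok_default;
  (void)tok_explicit;
}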

examples/qualcomm/oss_scripts/llama/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -65,6 +65,7 @@ target_link_libraries(
   extension_tensor
   gflags
   custom_ops
+  llama_runner
   quantized_ops_lib
   quantized_kernels
   tokenizers::tokenizers

examples/qualcomm/oss_scripts/llama/llama.py

Lines changed: 1 addition & 1 deletion
@@ -1038,7 +1038,7 @@ def _build_parser():
     parser.add_argument(
         "--model_mode",
         help="Export and inference kv mode, hybrid mode, or lookahead decoding mode",
-        default="kv",
+        default="hybrid",
         choices=["kv", "hybrid", "lookahead"],
         type=str,
     )

examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp

Lines changed: 14 additions & 12 deletions
@@ -16,6 +16,7 @@
 
 #include <executorch/backends/qualcomm/runtime/QnnExecuTorch.h>
 #include <executorch/examples/qualcomm/oss_scripts/llama/runner/runner.h>
+#include <executorch/extension/llm/runner/irunner.h>
 #include <executorch/runtime/platform/log.h>
 #include <gflags/gflags.h>
 #include <fstream>
@@ -61,7 +62,7 @@ DEFINE_int32(
     "Total number of tokens to generate (prompt + output).");
 DEFINE_int32(
     eval_mode,
-    0,
+    1,
     "0: TokenGenerator(kv) / 1: HybridMode (prefill+kv) / 2: Lookahead Decoding");
 DEFINE_string(
     kv_updater,
@@ -161,25 +162,26 @@ int main(int argc, char** argv) {
       buf.push_back(c);
     }
   };
-
+  executorch::extension::llm::GenerationConfig config{
+      true,
+      -1,
+      false,
+      FLAGS_seq_len,
+      static_cast<float>(FLAGS_temperature),
+      0,
+      0};
   if (use_tokenized_prompt) {
-    runner.generate(
-        FLAGS_tokenized_prompt.c_str(),
-        use_tokenized_prompt,
-        FLAGS_seq_len,
-        callback);
+    runner.generate_from_prompt_or_file(
+        FLAGS_tokenized_prompt.c_str(), use_tokenized_prompt, config, callback);
   } else {
     // generate tokens & store inference output
     for (int i = 0; i < FLAGS_num_iters; i++) {
       for (const auto& prompt : prompts) {
         std::string formatted_prompt;
         formatted_prompt = get_formatted_prompt(
             prompt, FLAGS_system_prompt, decoder_model_version.get());
-        runner.generate(
-            formatted_prompt.c_str(),
-            use_tokenized_prompt,
-            FLAGS_seq_len,
-            callback);
+        runner.generate_from_prompt_or_file(
+            formatted_prompt.c_str(), use_tokenized_prompt, config, callback);
       }
     }
   }
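
Note on the braced GenerationConfig above: the positional values appear to map to echo=true, max_new_tokens=-1, warming=false, seq_len=FLAGS_seq_len, temperature, num_bos=0, num_eos=0. A field-by-field sketch follows; the field names beyond echo and seq_len (which appear in the runner diff below) are assumptions and should be verified against extension/llm/runner/irunner.h:

// Equivalent, more explicit construction of the config passed to the runner.
executorch::extension::llm::GenerationConfig config;
config.echo = true;                                          // stream the prompt back to the callback
config.max_new_tokens = -1;                                  // no explicit cap; bounded by seq_len
config.warming = false;                                      // not a warm-up run
config.seq_len = FLAGS_seq_len;                              // total tokens (prompt + output)
config.temperature = static_cast<float>(FLAGS_temperature);  // sampling temperature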

examples/qualcomm/oss_scripts/llama/runner/runner.cpp

Lines changed: 27 additions & 14 deletions
@@ -9,6 +9,7 @@
 // A llama 3.2 runner that includes preprocessing and post processing
 // logic. The module takes in a string as input and emits a string as output.
 
+#include <executorch/examples/models/llama/runner/runner.h>
 #include <executorch/examples/models/llama/tokenizer/llama_tiktoken.h>
 #include <executorch/examples/qualcomm/oss_scripts/llama/runner/client_mem.h>
 #include <executorch/examples/qualcomm/oss_scripts/llama/runner/lhd_token_generator.h>
@@ -59,7 +60,7 @@ void print_performance_report(
     outfile << num_tok;
     outfile.close();
   } else {
-    ET_CHECK_MSG(false, "Error saving the inference speed file");
+    ET_LOG(Error, "Error saving the inference speed file");
   }
 }
 
@@ -84,13 +85,6 @@ void save_logits(
 
 } // namespace
 
-std::unique_ptr<::tokenizers::Tokenizer> load_llama_tokenizer(
-    const std::string& tokenizer_path,
-    Version version) {
-  auto special_tokens = get_special_tokens(version);
-  return llm::load_tokenizer(tokenizer_path, std::move(special_tokens));
-}
-
 Runner::Runner(
     const std::string& decoder_model_version,
     const std::string& model_path,
@@ -175,7 +169,8 @@ Error Runner::load() {
     eos_ids->insert(tokenizer_->encode("<|eot|>", 0, 0).get()[0]);
     eos_ids->insert(tokenizer_->encode("<|end_of_text|>", 0, 0).get()[0]);
   } else {
-    tokenizer_ = load_llama_tokenizer(tokenizer_path_, Version::Default);
+    tokenizer_ =
+        example::load_llama_tokenizer(tokenizer_path_, Version::Default);
     if (tokenizer_ == nullptr) {
       ET_LOG(
           Error, "Failed to load tokenizer with %s", tokenizer_path_.c_str());
@@ -313,13 +308,30 @@ Error Runner::load() {
 }
 
 Error Runner::generate(
+    const std::string& prompt,
+    const llm::GenerationConfig& config,
+    std::function<void(const std::string&)> token_callback,
+    std::function<void(const Stats&)> stats_callback) {
+  return generate_from_pos(prompt, 0, config, token_callback, stats_callback);
+}
+
+Error Runner::generate_from_pos(
+    const std::string& prompt,
+    int64_t start_pos,
+    const llm::GenerationConfig& config,
+    std::function<void(const std::string&)> token_callback,
+    std::function<void(const Stats&)> stats_callback) {
+  // TODO: currently only support start_pos == 0
+  return generate_from_prompt_or_file(
+      prompt, false, config, token_callback, stats_callback);
+}
+
+Error Runner::generate_from_prompt_or_file(
     const std::string& prompt,
     bool tokenized_prompt,
-    int32_t seq_len,
+    const llm::GenerationConfig& config,
     std::function<void(const std::string&)> token_callback,
-    std::function<void(const Stats&)> stats_callback,
-    bool echo,
-    bool warming) {
+    std::function<void(const Stats&)> stats_callback) {
   ET_CHECK_MSG(!prompt.empty(), "prompt cannot be null");
   if (!is_loaded()) {
     stats_.model_load_start_ms = time_in_ms();
@@ -328,6 +340,7 @@ Error Runner::generate(
   }
   stats_.inference_start_ms = time_in_ms();
 
+  int32_t seq_len = config.seq_len;
   seq_len = (seq_len > 0 && seq_len <= context_len_) ? seq_len : context_len_;
   int32_t n_bos = (cur_pos_ == 0) ? 1 : 0;
 
@@ -366,7 +379,7 @@
       "sequence length exceeded - please increase the seq_len value");
 
   // Prompt Processor first
-  if (token_callback) {
+  if (token_callback && config.echo) {
     token_callback(prompt);
   }
   bool dump_logits = dump_logits_path_.empty() ? false : true;
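
The new entry points simply forward to the existing implementation: generate() calls generate_from_pos() with start_pos = 0, which in turn calls generate_from_prompt_or_file() with tokenized_prompt = false. A minimal caller sketch, assuming an example::Runner constructed elsewhere with valid artifacts (seq_len and echo are the config fields visible in the hunk above; other fields keep their defaults):

#include <iostream>
#include <executorch/examples/qualcomm/oss_scripts/llama/runner/runner.h>

void run_once(example::Runner& runner, const std::string& prompt) {
  executorch::extension::llm::GenerationConfig config;
  config.seq_len = 128;  // clamped to context_len_ inside the runner

  // IRunner-style call: prompt, config, token callback, stats callback.
  auto err = runner.generate(
      prompt,
      config,
      [](const std::string& piece) { std::cout << piece; },
      [](const executorch::llm::Stats& stats) { (void)stats; });
  (void)err;
}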

examples/qualcomm/oss_scripts/llama/runner/runner.h

Lines changed: 21 additions & 9 deletions
@@ -21,6 +21,7 @@
 #include <executorch/examples/qualcomm/oss_scripts/llama/runner/kv_manager.h>
 #include <executorch/examples/qualcomm/oss_scripts/llama/runner/prompt_processor.h>
 #include <executorch/examples/qualcomm/oss_scripts/llama/runner/token_generator.h>
+#include <executorch/extension/llm/runner/irunner.h>
 #include <executorch/extension/llm/runner/stats.h>
 #include <executorch/extension/module/module.h>
 #include <pytorch/tokenizers/tokenizer.h>
@@ -32,7 +33,7 @@ enum DecoderModelVersion {
   kLlama3,
   kQwen2_5,
 };
-class Runner {
+class Runner : public executorch::extension::llm::IRunner {
  public:
   explicit Runner(
       const std::string& decoder_model,
@@ -41,25 +42,36 @@ class Runner {
       const std::string& performance_output_path,
       const std::string& dump_logits_path,
       const float temperature = 0.8f,
-      const int eval_mode = EvalMode::kKVCached,
+      const int eval_mode = EvalMode::kHybrid,
       const std::string& kv_updater = "SmartMask",
       const int ngram = 0,
      const int window = 0,
       const int gcap = 0,
       std::unique_ptr<tokenizers::Tokenizer> tokenizer = nullptr);
 
-  bool is_loaded() const;
-  executorch::runtime::Error load();
+  bool is_loaded() const override;
+  executorch::runtime::Error load() override;
   // TODO: Support echo and warming
   executorch::runtime::Error generate(
+      const std::string& prompt,
+      const executorch::extension::llm::GenerationConfig& config,
+      std::function<void(const std::string&)> token_callback = {},
+      std::function<void(const executorch::llm::Stats&)> stats_callback = {})
+      override;
+  executorch::runtime::Error generate_from_pos(
+      const std::string& prompt,
+      int64_t start_pos,
+      const executorch::extension::llm::GenerationConfig& config,
+      std::function<void(const std::string&)> token_callback = {},
+      std::function<void(const executorch::llm::Stats&)> stats_callback = {})
+      override;
+  executorch::runtime::Error generate_from_prompt_or_file(
       const std::string& prompt,
       bool tokenized_prompt,
-      int32_t seq_len,
+      const executorch::extension::llm::GenerationConfig& config,
       std::function<void(const std::string&)> token_callback = {},
-      std::function<void(const executorch::llm::Stats&)> stats_callback = {},
-      bool echo = true,
-      bool warming = false);
-  void stop() {};
+      std::function<void(const executorch::llm::Stats&)> stats_callback = {});
+  void stop() override {};
   executorch::runtime::Result<DecoderModelVersion> get_decoder_model_version();
 
  private:
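
Since Runner now derives from executorch::extension::llm::IRunner, call sites such as the JNI layer below can own it behind the interface and use only the virtual API. A sketch, assuming the Runner was constructed elsewhere with valid model and tokenizer paths:

#include <memory>
#include <executorch/examples/qualcomm/oss_scripts/llama/runner/runner.h>
#include <executorch/extension/llm/runner/irunner.h>

void hold_behind_interface(std::unique_ptr<example::Runner> runner) {
  // Upcast to the interface; virtual calls resolve to the QNN Runner overrides.
  std::unique_ptr<executorch::extension::llm::IRunner> iface = std::move(runner);
  if (!iface->is_loaded()) {
    (void)iface->load();
  }
  iface->stop();  // no-op override in this Runner
}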

examples/qualcomm/oss_scripts/llama/targets.bzl

Lines changed: 1 addition & 0 deletions
@@ -29,6 +29,7 @@ def define_common_targets():
         exported_deps = [
             "//executorch/extension/module:module",
             "//executorch/extension/llm/sampler:sampler",
+            "//executorch/examples/models/llama/runner:runner",
             "//executorch/examples/models/llama/tokenizer:tiktoken",
             "//executorch/extension/evalue_util:print_evalue",
             "//executorch/backends/qualcomm/runtime:runtime",

extension/android/CMakeLists.txt

Lines changed: 18 additions & 0 deletions
@@ -166,6 +166,24 @@ if(EXECUTORCH_BUILD_LLAMA_JNI)
     ${CMAKE_CURRENT_BINARY_DIR}/../../examples/models/llama/runner
   )
 
+  if(QNN_SDK_ROOT)
+    target_sources(
+      executorch_jni PRIVATE
+      ${EXECUTORCH_ROOT}/examples/qualcomm/oss_scripts/llama/runner/runner.cpp
+      ${EXECUTORCH_ROOT}/examples/qualcomm/oss_scripts/llama/runner/decoder_runner.cpp
+      ${EXECUTORCH_ROOT}/examples/qualcomm/oss_scripts/llama/runner/prompt_processor.cpp
+      ${EXECUTORCH_ROOT}/examples/qualcomm/oss_scripts/llama/runner/token_generator.cpp
+      ${EXECUTORCH_ROOT}/examples/qualcomm/oss_scripts/llama/runner/lhd_token_generator.cpp
+      ${EXECUTORCH_ROOT}/examples/qualcomm/oss_scripts/llama/runner/rpc_mem.cpp
+      ${EXECUTORCH_ROOT}/examples/qualcomm/oss_scripts/llama/runner/kv_manager.cpp
+    )
+
+    target_include_directories(
+      executorch_jni PRIVATE
+      ${EXECUTORCH_ROOT}/examples/qualcomm/oss_scripts/llama/runner
+    )
+  endif()
+
   if(NEURON_BUFFER_ALLOCATOR_LIB)
     target_sources(
       executorch_jni PRIVATE

extension/android/jni/jni_layer_llama.cpp

Lines changed: 11 additions & 0 deletions
@@ -15,6 +15,7 @@
 
 #include <executorch/examples/models/llama/runner/runner.h>
 #include <executorch/examples/models/llava/runner/llava_runner.h>
+#include <executorch/examples/qualcomm/oss_scripts/llama/runner/runner.h>
 #include <executorch/extension/llm/runner/image.h>
 #include <executorch/extension/llm/runner/irunner.h>
 #include <executorch/runtime/platform/log.h>
@@ -124,6 +125,7 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
   constexpr static int MODEL_TYPE_CATEGORY_LLM = 1;
   constexpr static int MODEL_TYPE_CATEGORY_MULTIMODAL = 2;
   constexpr static int MODEL_TYPE_MEDIATEK_LLAMA = 3;
+  constexpr static int MODEL_TYPE_QNN_LLAMA = 4;
 
   static facebook::jni::local_ref<jhybriddata> initHybrid(
       facebook::jni::alias_ref<jclass>,
@@ -174,6 +176,14 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
           model_path->toStdString(),
           tokenizer_path->toStdString(),
           data_path_str);
+    } else if (model_type_category == MODEL_TYPE_QNN_LLAMA) {
+      std::string decoder_model = "llama3"; // use llama3 for now
+      runner_ = std::make_unique<example::Runner>( // QNN runner
+          decoder_model.c_str(),
+          model_path->toStdString().c_str(),
+          tokenizer_path->toStdString().c_str(),
+          data_path->toStdString().c_str());
+      model_type_category_ = MODEL_TYPE_CATEGORY_LLM;
 #if defined(EXECUTORCH_BUILD_MEDIATEK)
     } else if (model_type_category == MODEL_TYPE_MEDIATEK_LLAMA) {
       runner_ = std::make_unique<MTKLlamaRunner>(
@@ -318,6 +328,7 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
           [callback](std::string result) { callback->onResult(result); },
           [callback](const llm::Stats& stats) { callback->onStats(stats); }));
     }
+    return static_cast<jint>(executorch::runtime::Error::InvalidArgument);
   }
 
   void stop() {