diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ModelUtils.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ModelUtils.java
index 32ed33cd302..cf7ab1756ce 100644
--- a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ModelUtils.java
+++ b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ModelUtils.java
@@ -21,6 +21,9 @@ public class ModelUtils {
   // MediaTek
   static final int MEDIATEK_TEXT_MODEL = 3;
 
+  // QNN static llama
+  static final int QNN_TEXT_MODEL = 4;
+
   public static int getModelCategory(ModelType modelType, BackendType backendType) {
     if (backendType.equals(BackendType.XNNPACK)) {
       switch (modelType) {
@@ -35,6 +38,8 @@ public static int getModelCategory(ModelType modelType, BackendType backendType)
       }
     } else if (backendType.equals(BackendType.MEDIATEK)) {
       return MEDIATEK_TEXT_MODEL;
+    } else if (backendType.equals(BackendType.QUALCOMM)) {
+      return QNN_TEXT_MODEL;
     }
 
     return TEXT_MODEL; // default
diff --git a/examples/models/llama/runner/runner.h b/examples/models/llama/runner/runner.h
index 09a166b0109..f07cd4e8ee8 100644
--- a/examples/models/llama/runner/runner.h
+++ b/examples/models/llama/runner/runner.h
@@ -18,6 +18,7 @@
 #include
 #include
+#include
 #include
 #include
 #include
@@ -33,6 +34,7 @@ std::unique_ptr create_llama_runner(
     float temperature = -1.0f);
 
 std::unique_ptr load_llama_tokenizer(
-    const std::string& tokenizer_path);
+    const std::string& tokenizer_path,
+    Version version = Version::Default);
 
 } // namespace example
diff --git a/examples/qualcomm/oss_scripts/llama/CMakeLists.txt b/examples/qualcomm/oss_scripts/llama/CMakeLists.txt
index bf83a456bca..78a7e2905e6 100644
--- a/examples/qualcomm/oss_scripts/llama/CMakeLists.txt
+++ b/examples/qualcomm/oss_scripts/llama/CMakeLists.txt
@@ -42,6 +42,8 @@ list(
   ${CMAKE_CURRENT_LIST_DIR}/runner/rpc_mem.h
   ${CMAKE_CURRENT_LIST_DIR}/runner/kv_manager.cpp
   ${CMAKE_CURRENT_LIST_DIR}/runner/kv_manager.h
+  ${EXECUTORCH_SOURCE_DIR}/examples/models/llama/runner/runner.cpp
+  ${EXECUTORCH_SOURCE_DIR}/examples/models/llama/runner/runner.h
 )
 
 list(APPEND _llama_runner__srcs)
diff --git a/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp b/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp
index 751271cf613..c0ad838f597 100644
--- a/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp
+++ b/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp
@@ -16,6 +16,7 @@
 #include
 #include
+#include
 #include
 #include
 #include
@@ -61,7 +62,7 @@ DEFINE_int32(
     "Total number of tokens to generate (prompt + output).");
 DEFINE_int32(
     eval_mode,
-    0,
+    1,
     "0: TokenGenerator(kv) / 1: HybridMode (prefill+kv) / 2: Lookahead Decoding");
 DEFINE_string(
     kv_updater,
@@ -172,13 +173,17 @@ void start_runner(
       buf.push_back(c);
     }
   };
-
+  executorch::extension::llm::GenerationConfig config{
+      true,
+      -1,
+      false,
+      FLAGS_seq_len,
+      static_cast(FLAGS_temperature),
+      0,
+      0};
   if (use_tokenized_prompt) {
-    runner.generate(
-        FLAGS_tokenized_prompt.c_str(),
-        use_tokenized_prompt,
-        FLAGS_seq_len,
-        callback);
+    runner.generate_from_prompt_or_file(
+        FLAGS_tokenized_prompt.c_str(), use_tokenized_prompt, config, callback);
   } else {
     // generate tokens & store inference output
     for (int i = 0; i < FLAGS_num_iters; i++) {
@@ -186,11 +191,8 @@ void start_runner(
       std::string formatted_prompt;
       formatted_prompt = get_formatted_prompt(
           prompt, FLAGS_system_prompt, decoder_model_version.get());
-      runner.generate(
-          formatted_prompt.c_str(),
-          use_tokenized_prompt,
-          FLAGS_seq_len,
-          callback);
+      runner.generate_from_prompt_or_file(
+          formatted_prompt.c_str(), use_tokenized_prompt, config, callback);
     }
   }
 }
diff --git a/examples/qualcomm/oss_scripts/llama/runner/runner.cpp b/examples/qualcomm/oss_scripts/llama/runner/runner.cpp
index fc38129c1d1..a0de66f6f69 100644
--- a/examples/qualcomm/oss_scripts/llama/runner/runner.cpp
+++ b/examples/qualcomm/oss_scripts/llama/runner/runner.cpp
@@ -9,6 +9,7 @@
 // A llama 3.2 runner that includes preprocessing and post processing
 // logic. The module takes in a string as input and emits a string as output.
 
+#include
 #include
 #include
 #include
@@ -58,7 +59,7 @@ void print_performance_report(
     outfile << num_tok;
     outfile.close();
   } else {
-    ET_CHECK_MSG(false, "Error saving the inference speed file");
+    ET_LOG(Error, "Error saving the inference speed file");
   }
 }
 
@@ -83,13 +84,6 @@ void save_logits(
 
 } // namespace
 
-std::unique_ptr<::tokenizers::Tokenizer> load_llama_tokenizer(
-    const std::string& tokenizer_path,
-    Version version) {
-  auto special_tokens = get_special_tokens(version);
-  return llm::load_tokenizer(tokenizer_path, std::move(special_tokens));
-}
-
 template
 Runner::Runner(
     std::unique_ptr module,
@@ -181,7 +175,8 @@ Error Runner::load() {
     eos_ids->insert(tokenizer_->encode("<|eot|>", 0, 0).get()[0]);
     eos_ids->insert(tokenizer_->encode("<|end_of_text|>", 0, 0).get()[0]);
   } else {
-    tokenizer_ = load_llama_tokenizer(tokenizer_path_, Version::Default);
+    tokenizer_ =
+        example::load_llama_tokenizer(tokenizer_path_, Version::Default);
     if (tokenizer_ == nullptr) {
       ET_LOG(
           Error, "Failed to load tokenizer with %s", tokenizer_path_.c_str());
@@ -323,13 +318,32 @@
 
 template
 Error Runner::generate(
+    const std::string& prompt,
+    const llm::GenerationConfig& config,
+    std::function token_callback,
+    std::function stats_callback) {
+  return generate_from_pos(prompt, 0, config, token_callback, stats_callback);
+}
+
+template
+Error Runner::generate_from_pos(
+    const std::string& prompt,
+    int64_t start_pos,
+    const llm::GenerationConfig& config,
+    std::function token_callback,
+    std::function stats_callback) {
+  // TODO: currently only support start_pos == 0
+  return generate_from_prompt_or_file(
+      prompt, false, config, token_callback, stats_callback);
+}
+
+template
+Error Runner::generate_from_prompt_or_file(
     const std::string& prompt,
     bool tokenized_prompt,
-    int32_t seq_len,
+    const llm::GenerationConfig& config,
     std::function token_callback,
-    std::function stats_callback,
-    bool echo,
-    bool warming) {
+    std::function stats_callback) {
   ET_CHECK_MSG(!prompt.empty(), "prompt cannot be null");
   if (!is_loaded()) {
     stats_.model_load_start_ms = time_in_ms();
@@ -338,6 +352,7 @@ Error Runner::generate(
   }
   stats_.inference_start_ms = time_in_ms();
 
+  int32_t seq_len = config.seq_len;
   seq_len = (seq_len > 0 && seq_len <= context_len_) ? seq_len : context_len_;
 
   int32_t n_bos = (cur_pos_ == 0) ? 1 : 0;
@@ -376,7 +391,7 @@
       "sequence length exceeded - please increase the seq_len value");
 
   // Prompt Processor first
-  if (token_callback) {
+  if (token_callback && config.echo) {
     token_callback(prompt);
   }
   bool dump_logits = dump_logits_path_.empty() ? false : true;
diff --git a/examples/qualcomm/oss_scripts/llama/runner/runner.h b/examples/qualcomm/oss_scripts/llama/runner/runner.h
index 14f415f7fc6..a4a8bb2efcb 100644
--- a/examples/qualcomm/oss_scripts/llama/runner/runner.h
+++ b/examples/qualcomm/oss_scripts/llama/runner/runner.h
@@ -21,6 +21,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -41,7 +42,7 @@ enum KvBitWidth {
 };
 
 template
-class Runner {
+class Runner : public executorch::extension::llm::IRunner {
  public:
   explicit Runner(
       std::unique_ptr module,
@@ -51,25 +52,36 @@ class Runner {
       const std::string& performance_output_path,
       const std::string& dump_logits_path,
       const float temperature = 0.8f,
-      const int eval_mode = EvalMode::kKVCached,
+      const int eval_mode = EvalMode::kHybrid,
       const std::string& kv_updater = "SmartMask",
       const int ngram = 0,
       const int window = 0,
       const int gcap = 0,
       std::unique_ptr tokenizer = nullptr);
 
-  bool is_loaded() const;
-  executorch::runtime::Error load();
+  bool is_loaded() const override;
+  executorch::runtime::Error load() override;
   // TODO: Support echo and warming
   executorch::runtime::Error generate(
+      const std::string& prompt,
+      const executorch::extension::llm::GenerationConfig& config,
+      std::function token_callback = {},
+      std::function stats_callback = {})
+      override;
+  executorch::runtime::Error generate_from_pos(
+      const std::string& prompt,
+      int64_t start_pos,
+      const executorch::extension::llm::GenerationConfig& config,
+      std::function token_callback = {},
+      std::function stats_callback = {})
+      override;
+  executorch::runtime::Error generate_from_prompt_or_file(
       const std::string& prompt,
       bool tokenized_prompt,
-      int32_t seq_len,
+      const executorch::extension::llm::GenerationConfig& config,
       std::function token_callback = {},
-      std::function stats_callback = {},
-      bool echo = true,
-      bool warming = false);
-  void stop() {};
+      std::function stats_callback = {});
+  void stop() override {};
   executorch::runtime::Result get_decoder_model_version();
 
  private:
diff --git a/examples/qualcomm/oss_scripts/llama/targets.bzl b/examples/qualcomm/oss_scripts/llama/targets.bzl
index b70c8fd2f33..062edf7594c 100644
--- a/examples/qualcomm/oss_scripts/llama/targets.bzl
+++ b/examples/qualcomm/oss_scripts/llama/targets.bzl
@@ -29,6 +29,7 @@ def define_common_targets():
         exported_deps = [
             "//executorch/extension/module:module",
             "//executorch/extension/llm/sampler:sampler",
+            "//executorch/examples/models/llama/runner:runner",
            "//executorch/examples/models/llama/tokenizer:tiktoken",
            "//executorch/extension/evalue_util:print_evalue",
            "//executorch/backends/qualcomm/runtime:runtime",
diff --git a/extension/android/CMakeLists.txt b/extension/android/CMakeLists.txt
index c1fb1125c3e..38d30854525 100644
--- a/extension/android/CMakeLists.txt
+++ b/extension/android/CMakeLists.txt
@@ -179,6 +179,35 @@ if(EXECUTORCH_BUILD_LLAMA_JNI)
     ${CMAKE_CURRENT_BINARY_DIR}/../../examples/models/llama/runner
   )
 
+  target_sources(
+    executorch_jni
+    PRIVATE ${EXECUTORCH_ROOT}/extension/llm/runner/llm_runner_helper.cpp
+  )
+
+  target_include_directories(
+    executorch_jni PRIVATE ${EXECUTORCH_ROOT}/extension/llm/runner
+  )
+
+  if(QNN_SDK_ROOT)
+    target_sources(
+      executorch_jni
+      PRIVATE
+        ${EXECUTORCH_ROOT}/examples/qualcomm/oss_scripts/llama/runner/runner.cpp
+        ${EXECUTORCH_ROOT}/examples/qualcomm/oss_scripts/llama/runner/decoder_runner.cpp
+        ${EXECUTORCH_ROOT}/examples/qualcomm/oss_scripts/llama/runner/prompt_processor.cpp
+        ${EXECUTORCH_ROOT}/examples/qualcomm/oss_scripts/llama/runner/token_generator.cpp
+        ${EXECUTORCH_ROOT}/examples/qualcomm/oss_scripts/llama/runner/lhd_token_generator.cpp
+        ${EXECUTORCH_ROOT}/examples/qualcomm/oss_scripts/llama/runner/rpc_mem.cpp
+        ${EXECUTORCH_ROOT}/examples/qualcomm/oss_scripts/llama/runner/kv_manager.cpp
+    )
+
+    target_include_directories(
+      executorch_jni
+      PRIVATE ${EXECUTORCH_ROOT}/examples/qualcomm/oss_scripts/llama/runner
+    )
+    target_compile_definitions(executorch_jni PRIVATE EXECUTORCH_BUILD_QNN=1)
+  endif()
+
   if(NEURON_BUFFER_ALLOCATOR_LIB)
     target_sources(
       executorch_jni
diff --git a/extension/android/jni/jni_layer_llama.cpp b/extension/android/jni/jni_layer_llama.cpp
index 48bc62141a2..a27b8194530 100644
--- a/extension/android/jni/jni_layer_llama.cpp
+++ b/extension/android/jni/jni_layer_llama.cpp
@@ -15,6 +15,7 @@
 #include
 #include
+#include
 #include
 #include
 #include
@@ -29,6 +30,10 @@
 #include
 #include
 
+#if defined(EXECUTORCH_BUILD_QNN)
+#include
+#endif
+
 #if defined(EXECUTORCH_BUILD_MEDIATEK)
 #include
 #endif
@@ -124,6 +129,7 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass {
   constexpr static int MODEL_TYPE_CATEGORY_LLM = 1;
   constexpr static int MODEL_TYPE_CATEGORY_MULTIMODAL = 2;
   constexpr static int MODEL_TYPE_MEDIATEK_LLAMA = 3;
+  constexpr static int MODEL_TYPE_QNN_LLAMA = 4;
 
   static facebook::jni::local_ref initHybrid(
       facebook::jni::alias_ref,
@@ -174,6 +180,22 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass {
           model_path->toStdString(),
           tokenizer_path->toStdString(),
           data_path_str);
+#if defined(EXECUTORCH_BUILD_QNN)
+    } else if (model_type_category == MODEL_TYPE_QNN_LLAMA) {
+      std::unique_ptr module = std::make_unique<
+          executorch::extension::Module>(
+          model_path->toStdString().c_str(),
+          executorch::extension::Module::LoadMode::MmapUseMlockIgnoreErrors);
+      std::string decoder_model = "llama3"; // use llama3 for now
+      runner_ = std::make_unique>( // QNN runner
+          std::move(module),
+          decoder_model.c_str(),
+          model_path->toStdString().c_str(),
+          tokenizer_path->toStdString().c_str(),
+          data_path->toStdString().c_str(),
+          "");
+      model_type_category_ = MODEL_TYPE_CATEGORY_LLM;
+#endif
 #if defined(EXECUTORCH_BUILD_MEDIATEK)
     } else if (model_type_category == MODEL_TYPE_MEDIATEK_LLAMA) {
       runner_ = std::make_unique(
@@ -318,6 +340,7 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass {
           [callback](std::string result) { callback->onResult(result); },
           [callback](const llm::Stats& stats) { callback->onStats(stats); }));
     }
+    return static_cast(executorch::runtime::Error::InvalidArgument);
   }
 
   void stop() {
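Taken together, these changes let the QNN llama runner be driven through the shared `executorch::extension::llm::IRunner` interface, with generation options carried in a `GenerationConfig` instead of loose arguments. Below is a minimal sketch of what a caller against that interface might look like. It is not part of the diff: the `irunner.h` include path and the positional meaning of the `GenerationConfig` fields (echo, max_new_tokens, warming, seq_len, temperature, num_bos, num_eos) are assumptions inferred from the aggregate initialization in `qnn_llama_runner.cpp`, and `run_once` is a hypothetical helper.

```cpp
// Sketch only: drives any IRunner-derived runner (e.g. the QNN Runner above)
// through the common interface. Header path and field order are assumptions.
#include <iostream>
#include <string>

#include <executorch/extension/llm/runner/irunner.h>

using executorch::extension::llm::GenerationConfig;
using executorch::extension::llm::IRunner;
using executorch::runtime::Error;

Error run_once(IRunner& runner, const std::string& prompt) {
  // Load the model and tokenizer up front, mirroring what generate() does lazily.
  if (!runner.is_loaded()) {
    Error err = runner.load();
    if (err != Error::Ok) {
      return err;
    }
  }

  // Positional aggregate mirroring qnn_llama_runner.cpp:
  // {echo, max_new_tokens, warming, seq_len, temperature, num_bos, num_eos}.
  GenerationConfig config{true, -1, false, 128, 0.8f, 0, 0};

  // Stream each decoded piece to stdout; stats are ignored in this sketch.
  return runner.generate(
      prompt,
      config,
      [](const std::string& piece) { std::cout << piece << std::flush; },
      [](const executorch::extension::llm::Stats&) { /* timing info */ });
}
```

On the Android side, the same code path is reached by returning `ModelUtils.QNN_TEXT_MODEL` (4) as the model category, which makes the JNI layer instantiate the QNN runner when `EXECUTORCH_BUILD_QNN` is defined.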