diff --git a/extension/android/jni/jni_layer_llama.cpp b/extension/android/jni/jni_layer_llama.cpp
index 25873a788b5..690da1b9e62 100644
--- a/extension/android/jni/jni_layer_llama.cpp
+++ b/extension/android/jni/jni_layer_llama.cpp
@@ -168,13 +168,13 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
     }
 #endif
 
+    std::vector<std::string> data_files_vector;
     model_type_category_ = model_type_category;
     if (model_type_category == MODEL_TYPE_CATEGORY_MULTIMODAL) {
       multi_modal_runner_ = llm::create_multimodal_runner(
           model_path->toStdString().c_str(),
           llm::load_tokenizer(tokenizer_path->toStdString()));
     } else if (model_type_category == MODEL_TYPE_CATEGORY_LLM) {
-      std::vector<std::string> data_files_vector;
       if (data_files != nullptr) {
         // Convert Java List<String> to C++ std::vector<std::string>
         auto list_class = facebook::jni::findClassStatic("java/util/List");
@@ -202,12 +202,38 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
               data_files_vector,
               executorch::extension::Module::LoadMode::MmapUseMlockIgnoreErrors);
       std::string decoder_model = "llama3"; // use llama3 for now
-      runner_ = std::make_unique<example::Runner<uint16_t>>( // QNN runner
-          std::move(module),
-          decoder_model.c_str(),
-          model_path->toStdString().c_str(),
-          tokenizer_path->toStdString().c_str(),
-          "");
+      example::KvBitWidth kv_bitwidth = example::KvBitWidth::kWidth8;
+      if (module->method_names()->count("get_kv_io_bit_width") > 0) {
+        kv_bitwidth = static_cast<example::KvBitWidth>(
+            module->get("get_kv_io_bit_width").get().toScalar().to<int64_t>());
+      }
+
+      if (kv_bitwidth == example::KvBitWidth::kWidth8) {
+        runner_ = std::make_unique<example::Runner<uint8_t>>(
+            std::move(module),
+            decoder_model.c_str(), // const std::string&
+            model_path->toStdString().c_str(), // const std::string&
+            tokenizer_path->toStdString().c_str(), // const std::string&
+            "", // performance_output_path
+            "", // dump_logits_path
+            temperature_ // temperature
+        );
+      } else if (kv_bitwidth == example::KvBitWidth::kWidth16) {
+        runner_ = std::make_unique<example::Runner<uint16_t>>(
+            std::move(module),
+            decoder_model.c_str(), // const std::string&
+            model_path->toStdString().c_str(), // const std::string&
+            tokenizer_path->toStdString().c_str(), // const std::string&
+            "", // performance_output_path
+            "", // dump_logits_path
+            temperature_ // temperature
+        );
+      } else {
+        ET_CHECK_MSG(
+            false,
+            "Unsupported kv bitwidth: %ld",
+            static_cast<long>(kv_bitwidth));
+      }
       model_type_category_ = MODEL_TYPE_CATEGORY_LLM;
 #endif
 #if defined(EXECUTORCH_BUILD_MEDIATEK)
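
Reviewer note (not part of the diff): a minimal, self-contained sketch of the bit-width lookup the new code performs. `example::KvBitWidth` is defined in the QNN llama runner sources; the enum shape shown below (values equal to the advertised bit widths, 8 and 16) is an assumption for illustration, as is the helper name `query_kv_bitwidth`. The `Module::method_names()` / `Module::get()` calls are the same ones the diff uses.

```cpp
#include <cstdint>

#include <executorch/extension/module/module.h>

namespace example {
// Assumed shape of the enum the diff dispatches on: the exported .pte is
// expected to expose a `get_kv_io_bit_width` metadata method returning 8 or
// 16, which the cast below maps directly onto the enumerators.
enum class KvBitWidth : int64_t {
  kWidth8 = 8,
  kWidth16 = 16,
};
} // namespace example

// Returns the KV-cache I/O bit width advertised by the model, falling back
// to 8-bit when the metadata method is absent (mirroring the diff's default).
example::KvBitWidth query_kv_bitwidth(executorch::extension::Module& module) {
  // method_names() returns a Result<std::unordered_set<std::string>>, so a
  // missing method is detected without executing anything.
  if (module.method_names()->count("get_kv_io_bit_width") > 0) {
    // get() runs the metadata method and yields its scalar output as an
    // EValue; toScalar().to<int64_t>() extracts the raw bit-width value.
    return static_cast<example::KvBitWidth>(
        module.get("get_kv_io_bit_width").get().toScalar().to<int64_t>());
  }
  return example::KvBitWidth::kWidth8;
}
```

Dispatching on metadata baked into the .pte keeps the JNI layer model-agnostic: the runner's KV element type (`uint8_t` vs. `uint16_t`) is chosen at load time from the model itself rather than hardcoded, and models without the method fall back to the 8-bit path.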