extension/android/jni/jni_layer_llama.cpp (40 changes: 33 additions, 7 deletions)
@@ -168,13 +168,13 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
   }
 #endif
 
+    std::vector<std::string> data_files_vector;
     model_type_category_ = model_type_category;
     if (model_type_category == MODEL_TYPE_CATEGORY_MULTIMODAL) {
       multi_modal_runner_ = llm::create_multimodal_runner(
           model_path->toStdString().c_str(),
           llm::load_tokenizer(tokenizer_path->toStdString()));
     } else if (model_type_category == MODEL_TYPE_CATEGORY_LLM) {
-      std::vector<std::string> data_files_vector;
       if (data_files != nullptr) {
         // Convert Java List<String> to C++ std::vector<string>
         auto list_class = facebook::jni::findClassStatic("java/util/List");
@@ -202,12 +202,38 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
           data_files_vector,
           executorch::extension::Module::LoadMode::MmapUseMlockIgnoreErrors);
       std::string decoder_model = "llama3"; // use llama3 for now
[Inline review comment] @cccclai (Contributor, Author), Oct 20, 2025:
Replace the model name here accordingly; the available paths are:

    if (decoder_model_version == "llama2") {
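The quoted line is the shape of the dispatch the comment refers to: the runner compares the decoder name string against each supported version. A minimal sketch of that pattern, assuming plain string comparison (the helper name configure_decoder and the branch bodies are illustrative, not the repository's code):

#include <stdexcept>
#include <string>

// Hypothetical helper, not from the repository: branch on the decoder model
// name the JNI layer passes down. Only "llama2" and "llama3" appear in the
// surrounding code; the per-version bodies are placeholders.
void configure_decoder(const std::string& decoder_model_version) {
  if (decoder_model_version == "llama2") {
    // llama2-style prompt template / stop tokens would be configured here.
  } else if (decoder_model_version == "llama3") {
    // llama3-style prompt template / stop tokens would be configured here.
  } else {
    throw std::runtime_error(
        "Unsupported decoder model: " + decoder_model_version);
  }
}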

-      runner_ = std::make_unique<example::Runner<uint16_t>>( // QNN runner
-          std::move(module),
-          decoder_model.c_str(),
-          model_path->toStdString().c_str(),
-          tokenizer_path->toStdString().c_str(),
-          "");
+      example::KvBitWidth kv_bitwidth = example::KvBitWidth::kWidth8;
+      if (module->method_names()->count("get_kv_io_bit_width") > 0) {
+        kv_bitwidth = static_cast<example::KvBitWidth>(
+            module->get("get_kv_io_bit_width").get().toScalar().to<int64_t>());
+      }
+
+      if (kv_bitwidth == example::KvBitWidth::kWidth8) {
+        runner_ = std::make_unique<example::Runner<uint8_t>>(
+            std::move(module),
+            decoder_model.c_str(), // const std::string&
+            model_path->toStdString().c_str(), // const std::string&
+            tokenizer_path->toStdString().c_str(), // const std::string&
+            "", // performance_output_path
+            "", // dump_logits_path
+            temperature_ // temperature
+        );
+      } else if (kv_bitwidth == example::KvBitWidth::kWidth16) {
+        runner_ = std::make_unique<example::Runner<uint16_t>>(
+            std::move(module),
+            decoder_model.c_str(), // const std::string&
+            model_path->toStdString().c_str(), // const std::string&
+            tokenizer_path->toStdString().c_str(), // const std::string&
+            "", // performance_output_path
+            "", // dump_logits_path
+            temperature_ // temperature
+        );
+      } else {
+        ET_CHECK_MSG(
+            false,
+            "Unsupported kv bitwidth: %ld",
+            static_cast<int64_t>(kv_bitwidth));
+      }
       model_type_category_ = MODEL_TYPE_CATEGORY_LLM;
 #endif
 #if defined(EXECUTORCH_BUILD_MEDIATEK)
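A note on the mechanism the new hunk introduces: the KV-cache element width is read at runtime from an optional get_kv_io_bit_width method in the exported program, and the matching example::Runner<T> instantiation is chosen from it. Since the template parameter is fixed at compile time, both instantiations are built in and the model's scalar picks between them, which is why the constructor call is repeated per branch. A self-contained sketch of that dispatch pattern (the IRunner/Runner stand-ins and the enum's numeric values are assumptions, not the repository's definitions):

#include <cstdint>
#include <memory>
#include <stdexcept>

// Stand-in for the runner interface that runner_ stores.
struct IRunner {
  virtual ~IRunner() = default;
};

// Stand-in for example::Runner<T>: KV-cache buffers would use element
// type KvT (uint8_t for 8-bit I/O, uint16_t for 16-bit I/O).
template <typename KvT>
struct Runner : IRunner {};

// Mirrors example::KvBitWidth; the numeric values are assumptions.
enum class KvBitWidth : std::int64_t { kWidth8 = 8, kWidth16 = 16 };

std::unique_ptr<IRunner> make_runner(KvBitWidth kv_bitwidth) {
  switch (kv_bitwidth) {
    case KvBitWidth::kWidth8:
      return std::make_unique<Runner<std::uint8_t>>();
    case KvBitWidth::kWidth16:
      return std::make_unique<Runner<std::uint16_t>>();
    default:
      throw std::runtime_error("Unsupported kv bitwidth");
  }
}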