Skip to content

Commit 826d59d

Browse files
cmodi-metakirklandsign
authored andcommitted
Enable JNI with MTK Llama Runner core functions
1 parent 8bf19d1 commit 826d59d

File tree

1 file changed

+18
-0
lines changed

1 file changed

+18
-0
lines changed

extension/android/jni/jni_layer_llama.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
#include <executorch/examples/models/llama/runner/runner.h>
1717
#include <executorch/examples/models/llava/runner/llava_runner.h>
18+
#include <executorch/examples/mediatek/executor_runner/mtk_llama_runner.h>
1819
#include <executorch/extension/llm/runner/image.h>
1920
#include <executorch/runtime/platform/log.h>
2021
#include <executorch/runtime/platform/platform.h>
@@ -113,13 +114,15 @@ class ExecuTorchLlamaJni
113114
int model_type_category_;
114115
std::unique_ptr<example::Runner> runner_;
115116
std::unique_ptr<llm::MultimodalRunner> multi_modal_runner_;
117+
std::unique_ptr<MTKLlamaRunner> mtk_llama_runner_;
116118

117119
public:
118120
constexpr static auto kJavaDescriptor =
119121
"Lorg/pytorch/executorch/LlamaModule;";
120122

121123
constexpr static int MODEL_TYPE_CATEGORY_LLM = 1;
122124
constexpr static int MODEL_TYPE_CATEGORY_MULTIMODAL = 2;
125+
constexpr static int MODEL_TYPE_MEDIATEK_LLAMA = 3;
123126

124127
static facebook::jni::local_ref<jhybriddata> initHybrid(
125128
facebook::jni::alias_ref<jclass>,
@@ -158,6 +161,11 @@ class ExecuTorchLlamaJni
158161
model_path->toStdString().c_str(),
159162
tokenizer_path->toStdString().c_str(),
160163
temperature);
164+
} else if (model_type_category == MODEL_TYPE_MEDIATEK_LLAMA) {
165+
mtk_llama_runner_ = std::make_unique<MTKLlamaRunner>(
166+
model_path->toStdString().c_str(),
167+
tokenizer_path->toStdString().c_str(),
168+
temperature);
161169
}
162170
}
163171

@@ -197,6 +205,12 @@ class ExecuTorchLlamaJni
197205
[callback](std::string result) { callback->onResult(result); },
198206
[callback](const llm::Stats& result) { callback->onStats(result); },
199207
echo);
208+
} else if (model_type_category_ == MODEL_TYPE_MEDIATEK_LLAMA) {
209+
mtk_llama_runner_->generate(
210+
prompt->toStdString(),
211+
seq_len,
212+
[callback](std::string result) { callback->onResult(result); },
213+
[callback](const Stats& result) { callback->onStats(result); });
200214
}
201215
return 0;
202216
}
@@ -286,6 +300,8 @@ class ExecuTorchLlamaJni
286300
multi_modal_runner_->stop();
287301
} else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
288302
runner_->stop();
303+
} else if (model_type_category_ == MODEL_TYPE_MEDIATEK_LLAMA) {
304+
mtk_llama_runner_->stop();
289305
}
290306
}
291307

@@ -294,6 +310,8 @@ class ExecuTorchLlamaJni
294310
return static_cast<jint>(multi_modal_runner_->load());
295311
} else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
296312
return static_cast<jint>(runner_->load());
313+
} else if (model_type_category_ == MODEL_TYPE_MEDIATEK_LLAMA) {
314+
return static_cast<jint>(mtk_llama_runner_->load());
297315
}
298316
return static_cast<jint>(Error::InvalidArgument);
299317
}

0 commit comments

Comments
 (0)