1717
1818#include < executorch/examples/models/llama2/runner/runner.h>
1919#include < executorch/examples/models/llava/runner/llava_runner.h>
20+ #include < executorch/examples/mediatek/executor_runner/mtk_llama_runner.h>
2021#include < executorch/extension/llm/runner/image.h>
2122#include < executorch/runtime/platform/log.h>
2223#include < executorch/runtime/platform/platform.h>
@@ -69,13 +70,15 @@ class ExecuTorchLlamaJni
6970 int model_type_category_;
7071 std::unique_ptr<example::Runner> runner_;
7172 std::unique_ptr<llm::MultimodalRunner> multi_modal_runner_;
73+ std::unique_ptr<MTKLlamaRunner> mtk_llama_runner_;
7274
7375 public:
7476 constexpr static auto kJavaDescriptor =
7577 " Lorg/pytorch/executorch/LlamaModule;" ;
7678
7779 constexpr static int MODEL_TYPE_CATEGORY_LLM = 1 ;
7880 constexpr static int MODEL_TYPE_CATEGORY_MULTIMODAL = 2 ;
81+ constexpr static int MODEL_TYPE_MEDIATEK_LLAMA = 3 ;
7982
8083 static facebook::jni::local_ref<jhybriddata> initHybrid (
8184 facebook::jni::alias_ref<jclass>,
@@ -114,6 +117,11 @@ class ExecuTorchLlamaJni
114117 model_path->toStdString ().c_str (),
115118 tokenizer_path->toStdString ().c_str (),
116119 temperature);
120+ } else if (model_type_category == MODEL_TYPE_MEDIATEK_LLAMA) {
121+ mtk_llama_runner_ = std::make_unique<MTKLlamaRunner>(
122+ model_path->toStdString ().c_str (),
123+ tokenizer_path->toStdString ().c_str (),
124+ temperature);
117125 }
118126 }
119127
@@ -153,6 +161,12 @@ class ExecuTorchLlamaJni
153161 [callback](std::string result) { callback->onResult (result); },
154162 [callback](const llm::Stats& result) { callback->onStats (result); },
155163 echo);
164+ } else if (model_type_category_ == MODEL_TYPE_MEDIATEK_LLAMA) {
165+ mtk_llama_runner_->generate (
166+ prompt->toStdString (),
167+ seq_len,
168+ [callback](std::string result) { callback->onResult (result); },
169+ [callback](const Stats& result) { callback->onStats (result); });
156170 }
157171 return 0 ;
158172 }
@@ -242,6 +256,8 @@ class ExecuTorchLlamaJni
242256 multi_modal_runner_->stop ();
243257 } else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
244258 runner_->stop ();
259+ } else if (model_type_category_ == MODEL_TYPE_MEDIATEK_LLAMA) {
260+ mtk_llama_runner_->stop ();
245261 }
246262 }
247263
@@ -250,6 +266,8 @@ class ExecuTorchLlamaJni
250266 return static_cast <jint>(multi_modal_runner_->load ());
251267 } else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
252268 return static_cast <jint>(runner_->load ());
269+ } else if (model_type_category_ == MODEL_TYPE_MEDIATEK_LLAMA) {
270+ return static_cast <jint>(mtk_llama_runner_->load ());
253271 }
254272 return static_cast <jint>(Error::InvalidArgument);
255273 }
0 commit comments