1515
1616#include < executorch/examples/models/llama/runner/runner.h>
1717#include < executorch/examples/models/llava/runner/llava_runner.h>
18+ #include < executorch/examples/mediatek/executor_runner/mtk_llama_runner.h>
1819#include < executorch/extension/llm/runner/image.h>
1920#include < executorch/runtime/platform/log.h>
2021#include < executorch/runtime/platform/platform.h>
@@ -113,13 +114,15 @@ class ExecuTorchLlamaJni
113114 int model_type_category_;
114115 std::unique_ptr<example::Runner> runner_;
115116 std::unique_ptr<llm::MultimodalRunner> multi_modal_runner_;
117+ std::unique_ptr<MTKLlamaRunner> mtk_llama_runner_;
116118
117119 public:
118120 constexpr static auto kJavaDescriptor =
119121 " Lorg/pytorch/executorch/LlamaModule;" ;
120122
121123 constexpr static int MODEL_TYPE_CATEGORY_LLM = 1 ;
122124 constexpr static int MODEL_TYPE_CATEGORY_MULTIMODAL = 2 ;
125+ constexpr static int MODEL_TYPE_MEDIATEK_LLAMA = 3 ;
123126
124127 static facebook::jni::local_ref<jhybriddata> initHybrid (
125128 facebook::jni::alias_ref<jclass>,
@@ -158,6 +161,11 @@ class ExecuTorchLlamaJni
158161 model_path->toStdString ().c_str (),
159162 tokenizer_path->toStdString ().c_str (),
160163 temperature);
164+ } else if (model_type_category == MODEL_TYPE_MEDIATEK_LLAMA) {
165+ mtk_llama_runner_ = std::make_unique<MTKLlamaRunner>(
166+ model_path->toStdString ().c_str (),
167+ tokenizer_path->toStdString ().c_str (),
168+ temperature);
161169 }
162170 }
163171
@@ -197,6 +205,12 @@ class ExecuTorchLlamaJni
197205 [callback](std::string result) { callback->onResult (result); },
198206 [callback](const llm::Stats& result) { callback->onStats (result); },
199207 echo);
208+ } else if (model_type_category_ == MODEL_TYPE_MEDIATEK_LLAMA) {
209+ mtk_llama_runner_->generate (
210+ prompt->toStdString (),
211+ seq_len,
212+ [callback](std::string result) { callback->onResult (result); },
213+ [callback](const Stats& result) { callback->onStats (result); });
200214 }
201215 return 0 ;
202216 }
@@ -286,6 +300,8 @@ class ExecuTorchLlamaJni
286300 multi_modal_runner_->stop ();
287301 } else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
288302 runner_->stop ();
303+ } else if (model_type_category_ == MODEL_TYPE_MEDIATEK_LLAMA) {
304+ mtk_llama_runner_->stop ();
289305 }
290306 }
291307
@@ -294,6 +310,8 @@ class ExecuTorchLlamaJni
294310 return static_cast <jint>(multi_modal_runner_->load ());
295311 } else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
296312 return static_cast <jint>(runner_->load ());
313+ } else if (model_type_category_ == MODEL_TYPE_MEDIATEK_LLAMA) {
314+ return static_cast <jint>(mtk_llama_runner_->load ());
297315 }
298316 return static_cast <jint>(Error::InvalidArgument);
299317 }
0 commit comments