1313#include < unordered_map>
1414#include < vector>
1515
16- #include < executorch/examples/mediatek/executor_runner/mtk_llama_runner.h>
1716#include < executorch/examples/models/llama/runner/runner.h>
1817#include < executorch/examples/models/llava/runner/llava_runner.h>
1918#include < executorch/extension/llm/runner/image.h>
19+ #include < executorch/extension/llm/runner/runner_interface.h>
2020#include < executorch/runtime/platform/log.h>
2121#include < executorch/runtime/platform/platform.h>
2222#include < executorch/runtime/platform/runtime.h>
2929#include < fbjni/ByteBuffer.h>
3030#include < fbjni/fbjni.h>
3131
32+ #if defined(EXECUTORCH_BUILD_MEDIATEK)
33+ #include < executorch/examples/mediatek/executor_runner/mtk_llama_runner.h>
34+ #endif
35+
3236namespace llm = ::executorch::extension::llm;
3337using ::executorch::runtime::Error;
3438
@@ -112,9 +116,8 @@ class ExecuTorchLlamaJni
112116 private:
113117 friend HybridBase;
114118 int model_type_category_;
115- std::unique_ptr<example::Runner > runner_;
119+ std::unique_ptr<llm::RunnerInterface > runner_;
116120 std::unique_ptr<llm::MultimodalRunner> multi_modal_runner_;
117- std::unique_ptr<MTKLlamaRunner> mtk_llama_runner_;
118121
119122 public:
120123 constexpr static auto kJavaDescriptor =
@@ -161,11 +164,15 @@ class ExecuTorchLlamaJni
161164 model_path->toStdString ().c_str (),
162165 tokenizer_path->toStdString ().c_str (),
163166 temperature);
167+ #if defined(EXECUTORCH_BUILD_MEDIATEK)
164168 } else if (model_type_category == MODEL_TYPE_MEDIATEK_LLAMA) {
165- mtk_llama_runner_ = std::make_unique<MTKLlamaRunner>(
169+ runner_ = std::make_unique<MTKLlamaRunner>(
166170 model_path->toStdString ().c_str (),
167171 tokenizer_path->toStdString ().c_str (),
168172 temperature);
173+ // Interpret the model type as LLM
174+ model_type_category_ = MODEL_TYPE_CATEGORY_LLM;
175+ #endif
169176 }
170177 }
171178
@@ -205,12 +212,6 @@ class ExecuTorchLlamaJni
205212 [callback](std::string result) { callback->onResult (result); },
206213 [callback](const llm::Stats& result) { callback->onStats (result); },
207214 echo);
208- } else if (model_type_category_ == MODEL_TYPE_MEDIATEK_LLAMA) {
209- mtk_llama_runner_->generate (
210- prompt->toStdString (),
211- seq_len,
212- [callback](std::string result) { callback->onResult (result); },
213- [callback](const Stats& result) { callback->onStats (result); });
214215 }
215216 return 0 ;
216217 }
@@ -300,8 +301,6 @@ class ExecuTorchLlamaJni
300301 multi_modal_runner_->stop ();
301302 } else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
302303 runner_->stop ();
303- } else if (model_type_category_ == MODEL_TYPE_MEDIATEK_LLAMA) {
304- mtk_llama_runner_->stop ();
305304 }
306305 }
307306
@@ -310,8 +309,6 @@ class ExecuTorchLlamaJni
310309 return static_cast <jint>(multi_modal_runner_->load ());
311310 } else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
312311 return static_cast <jint>(runner_->load ());
313- } else if (model_type_category_ == MODEL_TYPE_MEDIATEK_LLAMA) {
314- return static_cast <jint>(mtk_llama_runner_->load ());
315312 }
316313 return static_cast <jint>(Error::InvalidArgument);
317314 }
0 commit comments