1313#include  < unordered_map> 
1414#include  < vector> 
1515
16- #include  < executorch/examples/mediatek/executor_runner/mtk_llama_runner.h> 
1716#include  < executorch/examples/models/llama/runner/runner.h> 
1817#include  < executorch/examples/models/llava/runner/llava_runner.h> 
1918#include  < executorch/extension/llm/runner/image.h> 
19+ #include  < executorch/extension/llm/runner/runner_interface.h> 
2020#include  < executorch/runtime/platform/log.h> 
2121#include  < executorch/runtime/platform/platform.h> 
2222#include  < executorch/runtime/platform/runtime.h> 
2929#include  < fbjni/ByteBuffer.h> 
3030#include  < fbjni/fbjni.h> 
3131
32+ #if  defined(EXECUTORCH_BUILD_MEDIATEK)
33+ #include  < executorch/examples/mediatek/executor_runner/mtk_llama_runner.h> 
34+ #endif 
35+ 
3236namespace  llm  =  ::executorch::extension::llm;
3337using  ::executorch::runtime::Error;
3438
@@ -112,9 +116,8 @@ class ExecuTorchLlamaJni
112116 private: 
113117  friend  HybridBase;
114118  int  model_type_category_;
115-   std::unique_ptr<example::Runner > runner_;
119+   std::unique_ptr<llm::RunnerInterface > runner_;
116120  std::unique_ptr<llm::MultimodalRunner> multi_modal_runner_;
117-   std::unique_ptr<MTKLlamaRunner> mtk_llama_runner_;
118121
119122 public: 
120123  constexpr  static  auto  kJavaDescriptor  =
@@ -161,11 +164,15 @@ class ExecuTorchLlamaJni
161164          model_path->toStdString ().c_str (),
162165          tokenizer_path->toStdString ().c_str (),
163166          temperature);
167+ #if  defined(EXECUTORCH_BUILD_MEDIATEK)
164168    } else  if  (model_type_category == MODEL_TYPE_MEDIATEK_LLAMA) {
165-       mtk_llama_runner_  = std::make_unique<MTKLlamaRunner>(
169+       runner_  = std::make_unique<MTKLlamaRunner>(
166170          model_path->toStdString ().c_str (),
167171          tokenizer_path->toStdString ().c_str (),
168172          temperature);
173+       //  Interpret the model type as LLM
174+       model_type_category_ = MODEL_TYPE_CATEGORY_LLM;
175+ #endif
169176    }
170177  }
171178
@@ -205,12 +212,6 @@ class ExecuTorchLlamaJni
205212          [callback](std::string result) { callback->onResult (result); },
206213          [callback](const  llm::Stats& result) { callback->onStats (result); },
207214          echo);
208-     } else  if  (model_type_category_ == MODEL_TYPE_MEDIATEK_LLAMA) {
209-       mtk_llama_runner_->generate (
210-           prompt->toStdString (),
211-           seq_len,
212-           [callback](std::string result) { callback->onResult (result); },
213-           [callback](const  Stats& result) { callback->onStats (result); });
214215    }
215216    return  0 ;
216217  }
@@ -300,8 +301,6 @@ class ExecuTorchLlamaJni
300301      multi_modal_runner_->stop ();
301302    } else  if  (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
302303      runner_->stop ();
303-     } else  if  (model_type_category_ == MODEL_TYPE_MEDIATEK_LLAMA) {
304-       mtk_llama_runner_->stop ();
305304    }
306305  }
307306
@@ -310,8 +309,6 @@ class ExecuTorchLlamaJni
310309      return  static_cast <jint>(multi_modal_runner_->load ());
311310    } else  if  (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
312311      return  static_cast <jint>(runner_->load ());
313-     } else  if  (model_type_category_ == MODEL_TYPE_MEDIATEK_LLAMA) {
314-       return  static_cast <jint>(mtk_llama_runner_->load ());
315312    }
316313    return  static_cast <jint>(Error::InvalidArgument);
317314  }
0 commit comments