1515#include  < unordered_map> 
1616#include  < vector> 
1717
18- #include  < executorch/examples/mediatek/executor_runner/mtk_llama_runner.h> 
1918#include  < executorch/examples/models/llama/runner/runner.h> 
2019#include  < executorch/examples/models/llava/runner/llava_runner.h> 
2120#include  < executorch/extension/llm/runner/image.h> 
21+ #include  < executorch/extension/llm/runner/runner_interface.h> 
2222#include  < executorch/runtime/platform/log.h> 
2323#include  < executorch/runtime/platform/platform.h> 
2424#include  < executorch/runtime/platform/runtime.h> 
3131#include  < fbjni/ByteBuffer.h> 
3232#include  < fbjni/fbjni.h> 
3333
34+ #if  defined(EXECUTORCH_BUILD_MEDIATEK)
35+ #include  < executorch/examples/mediatek/executor_runner/mtk_llama_runner.h> 
36+ #endif 
37+ 
3438namespace  llm  =  ::executorch::extension::llm;
3539using  ::executorch::runtime::Error;
3640
@@ -68,9 +72,8 @@ class ExecuTorchLlamaJni
6872 private: 
6973  friend  HybridBase;
7074  int  model_type_category_;
71-   std::unique_ptr<example::Runner > runner_;
75+   std::unique_ptr<llm::RunnerInterface > runner_;
7276  std::unique_ptr<llm::MultimodalRunner> multi_modal_runner_;
73-   std::unique_ptr<MTKLlamaRunner> mtk_llama_runner_;
7477
7578 public: 
7679  constexpr  static  auto  kJavaDescriptor  =
@@ -117,11 +120,15 @@ class ExecuTorchLlamaJni
117120          model_path->toStdString ().c_str (),
118121          tokenizer_path->toStdString ().c_str (),
119122          temperature);
123+ #if  defined(EXECUTORCH_BUILD_MEDIATEK)
120124    } else  if  (model_type_category == MODEL_TYPE_MEDIATEK_LLAMA) {
121-       mtk_llama_runner_  = std::make_unique<MTKLlamaRunner>(
125+       runner_  = std::make_unique<MTKLlamaRunner>(
122126          model_path->toStdString ().c_str (),
123127          tokenizer_path->toStdString ().c_str (),
124128          temperature);
129+       //  Interpret the model type as LLM
130+       model_type_category_ = MODEL_TYPE_CATEGORY_LLM;
131+ #endif
125132    }
126133  }
127134
@@ -161,12 +168,6 @@ class ExecuTorchLlamaJni
161168          [callback](std::string result) { callback->onResult (result); },
162169          [callback](const  llm::Stats& result) { callback->onStats (result); },
163170          echo);
164-     } else  if  (model_type_category_ == MODEL_TYPE_MEDIATEK_LLAMA) {
165-       mtk_llama_runner_->generate (
166-           prompt->toStdString (),
167-           seq_len,
168-           [callback](std::string result) { callback->onResult (result); },
169-           [callback](const  Stats& result) { callback->onStats (result); });
170171    }
171172    return  0 ;
172173  }
@@ -256,8 +257,6 @@ class ExecuTorchLlamaJni
256257      multi_modal_runner_->stop ();
257258    } else  if  (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
258259      runner_->stop ();
259-     } else  if  (model_type_category_ == MODEL_TYPE_MEDIATEK_LLAMA) {
260-       mtk_llama_runner_->stop ();
261260    }
262261  }
263262
@@ -266,8 +265,6 @@ class ExecuTorchLlamaJni
266265      return  static_cast <jint>(multi_modal_runner_->load ());
267266    } else  if  (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
268267      return  static_cast <jint>(runner_->load ());
269-     } else  if  (model_type_category_ == MODEL_TYPE_MEDIATEK_LLAMA) {
270-       return  static_cast <jint>(mtk_llama_runner_->load ());
271268    }
272269    return  static_cast <jint>(Error::InvalidArgument);
273270  }
0 commit comments