@@ -13,10 +13,12 @@
 #include <unordered_map>
 #include <vector>
 
-#include <executorch/examples/models/llama/runner/runner.h>
-#include <executorch/examples/models/llava/runner/llava_runner.h>
 #include <executorch/extension/llm/runner/image.h>
 #include <executorch/extension/llm/runner/irunner.h>
+#include <executorch/extension/llm/runner/llm_runner_helper.h>
+#include <executorch/extension/llm/runner/multimodal_input.h>
+#include <executorch/extension/llm/runner/multimodal_runner.h>
+#include <executorch/extension/llm/runner/text_llm_runner.h>
 #include <executorch/runtime/platform/log.h>
 #include <executorch/runtime/platform/platform.h>
 #include <executorch/runtime/platform/runtime.h>
@@ -119,7 +121,7 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
   float temperature_ = 0.0f;
   int model_type_category_;
   std::unique_ptr<llm::IRunner> runner_;
-  std::unique_ptr<example::LlavaRunner> multi_modal_runner_;
+  std::unique_ptr<executorch::extension::llm::MultimodalRunner> multi_modal_runner_;
 
  public:
   constexpr static auto kJavaDescriptor =
@@ -165,19 +167,16 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
 
     model_type_category_ = model_type_category;
     if (model_type_category == MODEL_TYPE_CATEGORY_MULTIMODAL) {
-      multi_modal_runner_ = std::make_unique<example::LlavaRunner>(
+      multi_modal_runner_ = llm::create_multimodal_runner(
           model_path->toStdString().c_str(),
-          tokenizer_path->toStdString().c_str(),
-          temperature);
+          llm::load_tokenizer(tokenizer_path->toStdString()));
     } else if (model_type_category == MODEL_TYPE_CATEGORY_LLM) {
       std::optional<const std::string> data_path_str = data_path
           ? std::optional<const std::string>{data_path->toStdString()}
           : std::nullopt;
-      // TODO(larryliu0820): Use the API in text_llm_runner.h to create the
-      // runner.
-      runner_ = example::create_llama_runner(
+      runner_ = executorch::extension::llm::create_text_llm_runner(
           model_path->toStdString(),
-          tokenizer_path->toStdString(),
+          llm::load_tokenizer(tokenizer_path->toStdString()),
           data_path_str);
 #if defined(EXECUTORCH_BUILD_QNN)
     } else if (model_type_category == MODEL_TYPE_QNN_LLAMA) {
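The hunk above swaps the example-specific `create_llama_runner` for the generic `create_text_llm_runner`, with the tokenizer now loaded separately via `llm::load_tokenizer`. A minimal standalone sketch of that creation flow, assuming only the signatures implied by the call sites above (the helper name `make_text_runner` is illustrative, not part of the library):

// Sketch only: signatures inferred from the call sites in the hunk above.
#include <memory>
#include <optional>
#include <string>

#include <executorch/extension/llm/runner/llm_runner_helper.h>
#include <executorch/extension/llm/runner/text_llm_runner.h>

namespace llm = executorch::extension::llm;

std::unique_ptr<llm::IRunner> make_text_runner(
    const std::string& model_path,
    const std::string& tokenizer_path,
    std::optional<const std::string> data_path = std::nullopt) {
  // load_tokenizer selects a tokenizer implementation from the file itself,
  // so the caller no longer passes a temperature or tokenizer type here.
  return llm::create_text_llm_runner(
      model_path, llm::load_tokenizer(tokenizer_path), data_path);
}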
@@ -260,17 +259,7 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
       jint eos) {
     facebook::jni::local_ref<jlongArray> tuple_result =
         facebook::jni::make_long_array(2);
-    if (model_type_category_ != MODEL_TYPE_CATEGORY_MULTIMODAL) {
-      tuple_result->pin()[0] = static_cast<jint>(Error::NotSupported);
-      return tuple_result;
-    }
-
-    auto&& result = multi_modal_runner_->prefill_prompt(
-        prompt->toStdString(), start_pos, bos, eos);
-    tuple_result->pin()[0] = static_cast<jint>(Error::Ok);
-    if (result.ok()) {
-      tuple_result->pin()[1] = static_cast<jlong>(start_pos);
-    }
+    tuple_result->pin()[0] = static_cast<jint>(Error::NotSupported);
     return tuple_result;
   }
 
@@ -287,28 +276,7 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
     facebook::jni::local_ref<jlongArray> tuple_result =
         facebook::jni::make_long_array(2);
 
-    if (model_type_category_ != MODEL_TYPE_CATEGORY_MULTIMODAL) {
-      tuple_result->pin()[0] = static_cast<jint>(Error::NotSupported);
-      return tuple_result;
-    }
-
-    auto image_size = image->size();
-    std::vector<llm::Image> images;
-    if (image_size != 0) {
-      std::vector<jint> image_data_jint(image_size);
-      std::vector<uint8_t> image_data(image_size);
-      image->getRegion(0, image_size, image_data_jint.data());
-      for (int i = 0; i < image_size; i++) {
-        image_data[i] = image_data_jint[i];
-      }
-      llm::Image image_runner{image_data, width, height, channels};
-      images.push_back(image_runner);
-    }
-    // TODO(hsz): make start_pos a reference and update it here
-    jint result = static_cast<jint>(
-        multi_modal_runner_->prefill_images(images, start_pos));
-    tuple_result->pin()[0] = result;
-    tuple_result->pin()[1] = static_cast<jlong>(start_pos);
+    tuple_result->pin()[0] = static_cast<jint>(Error::NotSupported);
     return tuple_result;
   }
 
@@ -319,13 +287,12 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
       facebook::jni::alias_ref<ExecuTorchLlmCallbackJni> callback,
       jboolean echo) {
     if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
-      return static_cast<jint>(multi_modal_runner_->generate_from_pos(
-          prompt->toStdString(),
-          seq_len,
-          start_pos,
+
+      return static_cast<jint>(multi_modal_runner_->generate(
+          std::vector<llm::MultimodalInput>{llm::MultimodalInput{prompt->toStdString()}},
+          llm::GenerationConfig{.echo = static_cast<bool>(echo), .seq_len = seq_len},
           [callback](const std::string& result) { callback->onResult(result); },
-          [callback](const llm::Stats& stats) { callback->onStats(stats); },
-          echo));
+          [callback](const llm::Stats& stats) { callback->onStats(stats); }));
     } else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
       executorch::extension::llm::GenerationConfig config{
           .echo = static_cast<bool>(echo),
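With the prefill entry points above now returning NotSupported, the multimodal path funnels everything through the single generate() call: inputs are wrapped in MultimodalInput values, and the former loose arguments (echo, seq_len) move into GenerationConfig. A hedged usage sketch based only on the call shape in this hunk (the free function run_text_prompt and the callback bodies are illustrative):

// Sketch only: call shape taken from the hunk above; names are illustrative.
#include <string>
#include <vector>

#include <executorch/extension/llm/runner/multimodal_input.h>
#include <executorch/extension/llm/runner/multimodal_runner.h>

namespace llm = executorch::extension::llm;

void run_text_prompt(llm::MultimodalRunner& runner, const std::string& prompt) {
  // A text-only request is a one-element vector; presumably image inputs
  // would be passed as additional MultimodalInput entries.
  std::vector<llm::MultimodalInput> inputs{llm::MultimodalInput{prompt}};
  llm::GenerationConfig config{.echo = true, .seq_len = 128};
  runner.generate(
      inputs,
      config,
      [](const std::string& token) { /* stream each decoded piece */ },
      [](const llm::Stats& stats) { /* timing/throughput stats */ });
}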