1313#include < unordered_map>
1414#include < vector>
1515
16- #include < executorch/examples/models/llava/runner/llava_runner.h>
1716#include < executorch/extension/llm/runner/image.h>
1817#include < executorch/extension/llm/runner/irunner.h>
1918#include < executorch/extension/llm/runner/llm_runner_helper.h>
@@ -122,7 +121,9 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
122121 float temperature_ = 0 .0f ;
123122 int model_type_category_;
124123 std::unique_ptr<llm::IRunner> runner_;
125- std::unique_ptr<example::LlavaRunner> multi_modal_runner_;
124+ std::unique_ptr<executorch::extension::llm::MultimodalRunner>
125+ multi_modal_runner_;
126+ std::vector<llm::MultimodalInput> prefill_inputs_;
126127
127128 public:
128129 constexpr static auto kJavaDescriptor =
@@ -168,10 +169,9 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
168169
169170 model_type_category_ = model_type_category;
170171 if (model_type_category == MODEL_TYPE_CATEGORY_MULTIMODAL) {
171- multi_modal_runner_ = std::make_unique<example::LlavaRunner> (
172+ multi_modal_runner_ = llm::create_multimodal_runner (
172173 model_path->toStdString ().c_str (),
173- tokenizer_path->toStdString ().c_str (),
174- temperature);
174+ llm::load_tokenizer (tokenizer_path->toStdString ()));
175175 } else if (model_type_category == MODEL_TYPE_CATEGORY_LLM) {
176176 std::optional<const std::string> data_path_str = data_path
177177 ? std::optional<const std::string>{data_path->toStdString ()}
@@ -217,6 +217,9 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
217217 facebook::jni::alias_ref<ExecuTorchLlmCallbackJni> callback,
218218 jboolean echo) {
219219 if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
220+ std::vector<llm::MultimodalInput> inputs = prefill_inputs_;
221+ prefill_inputs_.clear ();
222+ inputs.emplace_back (llm::MultimodalInput{prompt->toStdString ()});
220223 auto image_size = image->size ();
221224 std::vector<llm::Image> images;
222225 if (image_size != 0 ) {
@@ -227,15 +230,18 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
227230 image_data[i] = image_data_jint[i];
228231 }
229232 llm::Image image_runner{image_data, width, height, channels};
230- images. push_back ( image_runner);
233+ inputs. emplace_back (llm::MultimodalInput{ std::move ( image_runner)} );
231234 }
235+ executorch::extension::llm::GenerationConfig config{
236+ .echo = static_cast <bool >(echo),
237+ .seq_len = seq_len,
238+ .temperature = temperature_,
239+ };
232240 multi_modal_runner_->generate (
233- std::move (images),
234- prompt->toStdString (),
235- seq_len,
236- [callback](std::string result) { callback->onResult (result); },
237- [callback](const llm::Stats& result) { callback->onStats (result); },
238- echo);
241+ std::move (inputs),
242+ config,
243+ [callback](const std::string& result) { callback->onResult (result); },
244+ [callback](const llm::Stats& result) { callback->onStats (result); });
239245 } else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
240246 executorch::extension::llm::GenerationConfig config{
241247 .echo = static_cast <bool >(echo),
@@ -259,19 +265,10 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
259265 jlong start_pos,
260266 jint bos,
261267 jint eos) {
268+ prefill_inputs_.emplace_back (llm::MultimodalInput{prompt->toStdString ()});
262269 facebook::jni::local_ref<jlongArray> tuple_result =
263270 facebook::jni::make_long_array (2 );
264- if (model_type_category_ != MODEL_TYPE_CATEGORY_MULTIMODAL) {
265- tuple_result->pin ()[0 ] = static_cast <jint>(Error::NotSupported);
266- return tuple_result;
267- }
268-
269- auto && result = multi_modal_runner_->prefill_prompt (
270- prompt->toStdString (), start_pos, bos, eos);
271271 tuple_result->pin ()[0 ] = static_cast <jint>(Error::Ok);
272- if (result.ok ()) {
273- tuple_result->pin ()[1 ] = static_cast <jlong>(start_pos);
274- }
275272 return tuple_result;
276273 }
277274
@@ -285,16 +282,8 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
285282 jint height,
286283 jint channels,
287284 jlong start_pos) {
288- facebook::jni::local_ref<jlongArray> tuple_result =
289- facebook::jni::make_long_array (2 );
290-
291- if (model_type_category_ != MODEL_TYPE_CATEGORY_MULTIMODAL) {
292- tuple_result->pin ()[0 ] = static_cast <jint>(Error::NotSupported);
293- return tuple_result;
294- }
295-
296- auto image_size = image->size ();
297285 std::vector<llm::Image> images;
286+ auto image_size = image->size ();
298287 if (image_size != 0 ) {
299288 std::vector<jint> image_data_jint (image_size);
300289 std::vector<uint8_t > image_data (image_size);
@@ -303,13 +292,14 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
303292 image_data[i] = image_data_jint[i];
304293 }
305294 llm::Image image_runner{image_data, width, height, channels};
306- images.push_back (image_runner);
295+ prefill_inputs_.emplace_back (
296+ llm::MultimodalInput{std::move (image_runner)});
307297 }
308- // TODO(hsz): make start_pos a reference and update it here
309- jint result = static_cast <jint>(
310- multi_modal_runner_-> prefill_images (images, start_pos) );
311- tuple_result-> pin ()[ 0 ] = result;
312- tuple_result->pin ()[1 ] = static_cast <jlong>(start_pos );
298+
299+ facebook::jni::local_ref<jlongArray> tuple_result =
300+ facebook::jni::make_long_array ( 2 );
301+
302+ tuple_result->pin ()[0 ] = static_cast <jint>(Error::Ok );
313303 return tuple_result;
314304 }
315305
@@ -320,13 +310,15 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
320310 facebook::jni::alias_ref<ExecuTorchLlmCallbackJni> callback,
321311 jboolean echo) {
322312 if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
323- return static_cast <jint>(multi_modal_runner_->generate_from_pos (
324- prompt->toStdString (),
325- seq_len,
326- start_pos,
313+ std::vector<llm::MultimodalInput> inputs = prefill_inputs_;
314+ prefill_inputs_.clear ();
315+ inputs.emplace_back (llm::MultimodalInput{prompt->toStdString ()});
316+ return static_cast <jint>(multi_modal_runner_->generate (
317+ inputs,
318+ llm::GenerationConfig{
319+ .echo = static_cast <bool >(echo), .seq_len = seq_len},
327320 [callback](const std::string& result) { callback->onResult (result); },
328- [callback](const llm::Stats& stats) { callback->onStats (stats); },
329- echo));
321+ [callback](const llm::Stats& stats) { callback->onStats (stats); }));
330322 } else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
331323 executorch::extension::llm::GenerationConfig config{
332324 .echo = static_cast <bool >(echo),
0 commit comments