@@ -121,7 +121,9 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
   float temperature_ = 0.0f;
   int model_type_category_;
   std::unique_ptr<llm::IRunner> runner_;
-  std::unique_ptr<executorch::extension::llm::MultimodalRunner> multi_modal_runner_;
+  std::unique_ptr<executorch::extension::llm::MultimodalRunner>
+      multi_modal_runner_;
+  std::vector<llm::MultimodalInput> prefill_inputs_;

  public:
   constexpr static auto kJavaDescriptor =
@@ -215,6 +217,9 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
       facebook::jni::alias_ref<ExecuTorchLlmCallbackJni> callback,
       jboolean echo) {
     if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
+      std::vector<llm::MultimodalInput> inputs = prefill_inputs_;
+      prefill_inputs_.clear();
+      inputs.emplace_back(llm::MultimodalInput{prompt->toStdString()});
       auto image_size = image->size();
       std::vector<llm::Image> images;
       if (image_size != 0) {
@@ -225,15 +230,18 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
           image_data[i] = image_data_jint[i];
         }
         llm::Image image_runner{image_data, width, height, channels};
-        images.push_back(image_runner);
+        inputs.emplace_back(llm::MultimodalInput{std::move(image_runner)});
       }
+      executorch::extension::llm::GenerationConfig config{
+          .echo = static_cast<bool>(echo),
+          .seq_len = seq_len,
+          .temperature = temperature_,
+      };
       multi_modal_runner_->generate(
-          std::move(images),
-          prompt->toStdString(),
-          seq_len,
-          [callback](std::string result) { callback->onResult(result); },
-          [callback](const llm::Stats& result) { callback->onStats(result); },
-          echo);
+          std::move(inputs),
+          config,
+          [callback](const std::string& result) { callback->onResult(result); },
+          [callback](const llm::Stats& result) { callback->onStats(result); });
     } else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
       executorch::extension::llm::GenerationConfig config{
           .echo = static_cast<bool>(echo),
@@ -257,9 +265,10 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
       jlong start_pos,
       jint bos,
       jint eos) {
+    prefill_inputs_.emplace_back(llm::MultimodalInput{prompt->toStdString()});
     facebook::jni::local_ref<jlongArray> tuple_result =
         facebook::jni::make_long_array(2);
-    tuple_result->pin()[0] = static_cast<jint>(Error::NotSupported);
+    tuple_result->pin()[0] = static_cast<jint>(Error::Ok);
     return tuple_result;
   }

@@ -273,10 +282,24 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
       jint height,
       jint channels,
       jlong start_pos) {
+    std::vector<llm::Image> images;
+    auto image_size = image->size();
+    if (image_size != 0) {
+      std::vector<jint> image_data_jint(image_size);
+      std::vector<uint8_t> image_data(image_size);
+      image->getRegion(0, image_size, image_data_jint.data());
+      for (int i = 0; i < image_size; i++) {
+        image_data[i] = image_data_jint[i];
+      }
+      llm::Image image_runner{image_data, width, height, channels};
+      prefill_inputs_.emplace_back(
+          llm::MultimodalInput{std::move(image_runner)});
+    }
+
     facebook::jni::local_ref<jlongArray> tuple_result =
         facebook::jni::make_long_array(2);

-    tuple_result->pin()[0] = static_cast<jint>(Error::NotSupported);
+    tuple_result->pin()[0] = static_cast<jint>(Error::Ok);
     return tuple_result;
   }

@@ -287,10 +310,13 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
       facebook::jni::alias_ref<ExecuTorchLlmCallbackJni> callback,
       jboolean echo) {
     if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
-
+      std::vector<llm::MultimodalInput> inputs = prefill_inputs_;
+      prefill_inputs_.clear();
+      inputs.emplace_back(llm::MultimodalInput{prompt->toStdString()});
       return static_cast<jint>(multi_modal_runner_->generate(
-          std::vector<llm::MultimodalInput>{llm::MultimodalInput{prompt->toStdString()}},
-          llm::GenerationConfig{.echo = static_cast<bool>(echo), .seq_len = seq_len},
+          inputs,
+          llm::GenerationConfig{
+              .echo = static_cast<bool>(echo), .seq_len = seq_len},
           [callback](const std::string& result) { callback->onResult(result); },
           [callback](const llm::Stats& stats) { callback->onStats(stats); }));
     } else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
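
Net effect of the diff above: prefill_prompt()/prefill_images() no longer return Error::NotSupported but queue llm::MultimodalInput items in the new prefill_inputs_ member, and both generate() paths drain that queue, append the final prompt, and pass the ordered inputs plus a GenerationConfig to MultimodalRunner::generate(). Below is a minimal stand-alone sketch of that deferred-prefill pattern; Image, MultimodalInput, and LlmBridge here are simplified stand-ins for illustration, not the real executorch types or API.

#include <cstdint>
#include <functional>
#include <iostream>
#include <string>
#include <utility>
#include <variant>
#include <vector>

// Simplified stand-in for executorch's llm::Image.
struct Image {
  std::vector<uint8_t> data;
  int width = 0;
  int height = 0;
  int channels = 0;
};

// Simplified stand-in for llm::MultimodalInput (text or image).
using MultimodalInput = std::variant<std::string, Image>;

class LlmBridge {
 public:
  // prefill_* calls only queue inputs; nothing runs until generate().
  void prefill_prompt(const std::string& prompt) {
    prefill_inputs_.emplace_back(prompt);
  }

  void prefill_images(Image image) {
    prefill_inputs_.emplace_back(std::move(image));
  }

  // generate() drains the queue, appends the final prompt, and hands the
  // ordered sequence to the "runner" (here just a callback reporting size).
  void generate(const std::string& prompt,
                const std::function<void(const std::string&)>& on_result) {
    std::vector<MultimodalInput> inputs = prefill_inputs_;
    prefill_inputs_.clear();
    inputs.emplace_back(prompt);
    on_result("consumed " + std::to_string(inputs.size()) + " inputs");
  }

 private:
  std::vector<MultimodalInput> prefill_inputs_;
};

int main() {
  LlmBridge bridge;
  bridge.prefill_images(Image{{0, 0, 0}, 1, 1, 3});
  bridge.prefill_prompt("Describe the attached image.");
  bridge.generate("What color is it?",
                  [](const std::string& r) { std::cout << r << "\n"; });
  return 0;
}

As in the diff, the pending queue is copied and cleared up front, so each generate() call consumes exactly the inputs prefilled since the previous call.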