@@ -245,16 +245,23 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
245
245
246
246
// Returns status_code
247
247
// Contract is valid within an AAR (JNI + corresponding Java code)
248
- jint append_text_input (facebook::jni::alias_ref<jstring> prompt) {
249
- prefill_inputs_.emplace_back (llm::MultimodalInput{prompt->toStdString ()});
250
- return 0 ;
248
+ jint prefill_text_input (facebook::jni::alias_ref<jstring> prompt) {
249
+ if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
250
+ prefill_inputs_.emplace_back (llm::MultimodalInput{prompt->toStdString ()});
251
+ } else if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
252
+ multi_modal_runner_->prefill (llm::MultimodalInput{prompt->toStdString ()});
253
+ return 0 ;
254
+ }
251
255
}
252
256
253
- jint append_images_input (
257
+ jint prefill_images_input (
254
258
facebook::jni::alias_ref<jintArray> image,
255
259
jint width,
256
260
jint height,
257
261
jint channels) {
262
+ if (model_type_category_ != MODEL_TYPE_CATEGORY_MULTIMODAL) {
263
+ return Error::InvalidArgument;
264
+ }
258
265
if (image == nullptr ) {
259
266
return Error::InvalidArgument;
260
267
}
@@ -271,18 +278,20 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
271
278
image_data[i] = image_data_jint[i];
272
279
}
273
280
llm::Image image_runner{std::move (image_data), width, height, channels};
274
- prefill_inputs_.emplace_back (
275
- llm::MultimodalInput{std::move (image_runner)});
281
+ multi_modal_runner_->prefill (llm::MultimodalInput{std::move (image_runner)});
276
282
}
277
283
278
284
return 0 ;
279
285
}
280
286
281
- jint append_audio_input (
287
+ jint prefill_audio_input (
282
288
facebook::jni::alias_ref<jintArray> audio,
283
289
jint batch_size,
284
290
jint n_channels,
285
291
jint n_samples) {
292
+ if (model_type_category_ != MODEL_TYPE_CATEGORY_MULTIMODAL) {
293
+ return Error::InvalidArgument;
294
+ }
286
295
if (audio == nullptr ) {
287
296
return Error::InvalidArgument;
288
297
}
@@ -295,8 +304,7 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
295
304
audio_data[i] = audio_data_jint[i];
296
305
}
297
306
llm::RawAudio audio_input{audio_data, batch_size, n_channels, n_samples};
298
- prefill_inputs_.emplace_back (
299
- llm::MultimodalInput{std::move (audio_input)});
307
+ multi_modal_runner_->prefill (llm::MultimodalInput{std::move (audio_input)});
300
308
}
301
309
return 0 ;
302
310
}
@@ -334,11 +342,11 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
334
342
makeNativeMethod (" stop" , ExecuTorchLlmJni::stop),
335
343
makeNativeMethod (" load" , ExecuTorchLlmJni::load),
336
344
makeNativeMethod (
337
- " appendImagesInput" , ExecuTorchLlmJni::append_images_input ),
345
+ " appendImagesInput" , ExecuTorchLlmJni::prefill_images_input ),
338
346
makeNativeMethod (
339
- " appendTextInput" , ExecuTorchLlmJni::append_text_input ),
347
+ " appendTextInput" , ExecuTorchLlmJni::prefill_text_input ),
340
348
makeNativeMethod (
341
- " appendAudioInput" , ExecuTorchLlmJni::append_audio_input ),
349
+ " appendAudioInput" , ExecuTorchLlmJni::prefill_audio_input ),
342
350
makeNativeMethod (" resetContext" , ExecuTorchLlmJni::reset_context),
343
351
});
344
352
}
0 commit comments