Skip to content

Commit 392c157

Browse files
committed
Use prefill API
1 parent 3a4ffce commit 392c157

File tree

1 file changed

+20
-12
lines changed

1 file changed

+20
-12
lines changed

extension/android/jni/jni_layer_llama.cpp

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -245,16 +245,23 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
245245

246246
// Returns status_code
247247
// Contract is valid within an AAR (JNI + corresponding Java code)
248-
jint append_text_input(facebook::jni::alias_ref<jstring> prompt) {
249-
prefill_inputs_.emplace_back(llm::MultimodalInput{prompt->toStdString()});
250-
return 0;
248+
jint prefill_text_input(facebook::jni::alias_ref<jstring> prompt) {
249+
if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
250+
prefill_inputs_.emplace_back(llm::MultimodalInput{prompt->toStdString()});
251+
} else if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
252+
multi_modal_runner_->prefill(llm::MultimodalInput{prompt->toStdString()});
253+
return 0;
254+
}
251255
}
252256

253-
jint append_images_input(
257+
jint prefill_images_input(
254258
facebook::jni::alias_ref<jintArray> image,
255259
jint width,
256260
jint height,
257261
jint channels) {
262+
if (model_type_category_ != MODEL_TYPE_CATEGORY_MULTIMODAL) {
263+
return Error::InvalidArgument;
264+
}
258265
if (image == nullptr) {
259266
return Error::InvalidArgument;
260267
}
@@ -271,18 +278,20 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
271278
image_data[i] = image_data_jint[i];
272279
}
273280
llm::Image image_runner{std::move(image_data), width, height, channels};
274-
prefill_inputs_.emplace_back(
275-
llm::MultimodalInput{std::move(image_runner)});
281+
multi_modal_runner_->prefill(llm::MultimodalInput{std::move(image_runner)});
276282
}
277283

278284
return 0;
279285
}
280286

281-
jint append_audio_input(
287+
jint prefill_audio_input(
282288
facebook::jni::alias_ref<jintArray> audio,
283289
jint batch_size,
284290
jint n_channels,
285291
jint n_samples) {
292+
if (model_type_category_ != MODEL_TYPE_CATEGORY_MULTIMODAL) {
293+
return Error::InvalidArgument;
294+
}
286295
if (audio == nullptr) {
287296
return Error::InvalidArgument;
288297
}
@@ -295,8 +304,7 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
295304
audio_data[i] = audio_data_jint[i];
296305
}
297306
llm::RawAudio audio_input{audio_data, batch_size, n_channels, n_samples};
298-
prefill_inputs_.emplace_back(
299-
llm::MultimodalInput{std::move(audio_input)});
307+
multi_modal_runner_->prefill(llm::MultimodalInput{std::move(audio_input)});
300308
}
301309
return 0;
302310
}
@@ -334,11 +342,11 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
334342
makeNativeMethod("stop", ExecuTorchLlmJni::stop),
335343
makeNativeMethod("load", ExecuTorchLlmJni::load),
336344
makeNativeMethod(
337-
"appendImagesInput", ExecuTorchLlmJni::append_images_input),
345+
"appendImagesInput", ExecuTorchLlmJni::prefill_images_input),
338346
makeNativeMethod(
339-
"appendTextInput", ExecuTorchLlmJni::append_text_input),
347+
"appendTextInput", ExecuTorchLlmJni::prefill_text_input),
340348
makeNativeMethod(
341-
"appendAudioInput", ExecuTorchLlmJni::append_audio_input),
349+
"appendAudioInput", ExecuTorchLlmJni::prefill_audio_input),
342350
makeNativeMethod("resetContext", ExecuTorchLlmJni::reset_context),
343351
});
344352
}

0 commit comments

Comments
 (0)