diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java
index b014ceb75d8..d8ee0ab7482 100644
--- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java
+++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java
@@ -177,7 +177,7 @@ public native int generate(
* @throws RuntimeException if the prefill failed
*/
public long prefillImages(int[] image, int width, int height, int channels, long startPos) {
- long[] nativeResult = prefillImagesNative(image, width, height, channels, startPos);
+ long[] nativeResult = addImageInputNative(image, width, height, channels, startPos);
if (nativeResult[0] != 0) {
throw new RuntimeException("Prefill failed with error code: " + nativeResult[0]);
}
@@ -185,7 +185,7 @@ public long prefillImages(int[] image, int width, int height, int channels, long
}
// returns a tuple of (status, updated startPos)
- private native long[] prefillImagesNative(
+ private native long[] addImageInputNative(
int[] image, int width, int height, int channels, long startPos);
/**
@@ -200,7 +200,7 @@ private native long[] prefillImagesNative(
* @throws RuntimeException if the prefill failed
*/
public long prefillPrompt(String prompt, long startPos, int bos, int eos) {
- long[] nativeResult = prefillPromptNative(prompt, startPos, bos, eos);
+ long[] nativeResult = addTextInputNative(prompt, startPos, bos, eos);
if (nativeResult[0] != 0) {
throw new RuntimeException("Prefill failed with error code: " + nativeResult[0]);
}
@@ -208,7 +208,10 @@ public long prefillPrompt(String prompt, long startPos, int bos, int eos) {
}
// returns a tuple of (status, updated startPos)
- private native long[] prefillPromptNative(String prompt, long startPos, int bos, int eos);
+ private native long[] addTextInputNative(String prompt, long startPos, int bos, int eos);
+
+ // returns the status code
+ private native int addAudioInputNative(int[] audio, int batch_size, int n_bins, int n_frames);
/**
* Generate tokens from the given prompt, starting from the given position.
@@ -217,6 +220,12 @@ public long prefillPrompt(String prompt, long startPos, int bos, int eos) {
* @param seqLen The total sequence length, including the prompt tokens and new tokens.
* @param startPos The starting position in KV cache of the input in the LLM.
* @param callback callback object to receive results.
+ * @param echo indicate whether to echo the input prompt or not.
+ *
/** Generate tokens from the given prompt, starting from the given position.
+ * @param prompt The text prompt to LLaVA.
+ * @param seqLen The total sequence length, including the prompt tokens and new tokens.
+ * @param startPos The starting position in KV cache of the input in the LLM.
+ * @param callback callback object to receive results.
* @param echo indicate whether to echo the input prompt or not.
* @return The error code.
*/
diff --git a/extension/android/jni/jni_layer_llama.cpp b/extension/android/jni/jni_layer_llama.cpp
index 886b25e4221..aa5f6052225 100644
--- a/extension/android/jni/jni_layer_llama.cpp
+++ b/extension/android/jni/jni_layer_llama.cpp
@@ -13,7 +13,6 @@
#include
#include
-#include
#include
#include
#include
@@ -122,7 +121,9 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass {
float temperature_ = 0.0f;
int model_type_category_;
std::unique_ptr<llm::IRunner> runner_;
- std::unique_ptr<llm::MultiModalRunner> multi_modal_runner_;
+ std::unique_ptr<executorch::extension::llm::MultimodalRunner>
+ multi_modal_runner_;
+ std::vector<llm::MultimodalInput> prefill_inputs_;
public:
constexpr static auto kJavaDescriptor =
@@ -168,10 +169,9 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass {
model_type_category_ = model_type_category;
if (model_type_category == MODEL_TYPE_CATEGORY_MULTIMODAL) {
- multi_modal_runner_ = std::make_unique<llm::MultiModalRunner>(
+ multi_modal_runner_ = llm::create_multimodal_runner(
model_path->toStdString().c_str(),
- tokenizer_path->toStdString().c_str(),
- temperature);
+ llm::load_tokenizer(tokenizer_path->toStdString()));
} else if (model_type_category == MODEL_TYPE_CATEGORY_LLM) {
std::optional<std::string> data_path_str = data_path
? std::optional<std::string>{data_path->toStdString()}
@@ -217,6 +217,9 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass {
facebook::jni::alias_ref<ExecuTorchLlmCallbackJni> callback,
jboolean echo) {
if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
+ std::vector<llm::MultimodalInput> inputs = prefill_inputs_;
+ prefill_inputs_.clear();
+ inputs.emplace_back(llm::MultimodalInput{prompt->toStdString()});
auto image_size = image->size();
std::vector<llm::Image> images;
if (image_size != 0) {
@@ -227,15 +230,18 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass {
image_data[i] = image_data_jint[i];
}
llm::Image image_runner{image_data, width, height, channels};
- images.push_back(image_runner);
+ inputs.emplace_back(llm::MultimodalInput{std::move(image_runner)});
}
+ executorch::extension::llm::GenerationConfig config{
+ .echo = static_cast<bool>(echo),
+ .seq_len = seq_len,
+ .temperature = temperature_,
+ };
multi_modal_runner_->generate(
- std::move(images),
- prompt->toStdString(),
- seq_len,
- [callback](std::string result) { callback->onResult(result); },
- [callback](const llm::Stats& result) { callback->onStats(result); },
- echo);
+ std::move(inputs),
+ config,
+ [callback](const std::string& result) { callback->onResult(result); },
+ [callback](const llm::Stats& result) { callback->onStats(result); });
} else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
executorch::extension::llm::GenerationConfig config{
.echo = static_cast<bool>(echo),
@@ -254,24 +260,15 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass {
// Returns a tuple of (error, start_pos)
// Contract is valid within an AAR (JNI + corresponding Java code)
// If the first element is not Error::Ok, the other element is undefined.
- facebook::jni::local_ref<jlongArray> prefill_prompt(
+ facebook::jni::local_ref<jlongArray> add_text_input(
facebook::jni::alias_ref<jstring> prompt,
jlong start_pos,
jint bos,
jint eos) {
+ prefill_inputs_.emplace_back(llm::MultimodalInput{prompt->toStdString()});
facebook::jni::local_ref<jlongArray> tuple_result =
facebook::jni::make_long_array(2);
- if (model_type_category_ != MODEL_TYPE_CATEGORY_MULTIMODAL) {
- tuple_result->pin()[0] = static_cast<jlong>(Error::NotSupported);
- return tuple_result;
- }
-
- auto&& result = multi_modal_runner_->prefill_prompt(
- prompt->toStdString(), start_pos, bos, eos);
tuple_result->pin()[0] = static_cast<jlong>(Error::Ok);
- if (result.ok()) {
- tuple_result->pin()[1] = static_cast<jlong>(start_pos);
- }
return tuple_result;
}
@@ -279,22 +276,14 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass {
// Contract is valid within an AAR (JNI + corresponding Java code)
// If the first element is not Error::Ok, the other element is undefined.
- facebook::jni::local_ref<jlongArray> prefill_images(
+ facebook::jni::local_ref<jlongArray> add_images_input(
facebook::jni::alias_ref<jintArray> image,
jint width,
jint height,
jint channels,
jlong start_pos) {
- facebook::jni::local_ref<jlongArray> tuple_result =
- facebook::jni::make_long_array(2);
-
- if (model_type_category_ != MODEL_TYPE_CATEGORY_MULTIMODAL) {
- tuple_result->pin()[0] = static_cast<jlong>(Error::NotSupported);
- return tuple_result;
- }
-
- auto image_size = image->size();
std::vector<llm::Image> images;
+ auto image_size = image->size();
if (image_size != 0) {
std::vector<jint> image_data_jint(image_size);
std::vector<uint8_t> image_data(image_size);
@@ -303,16 +292,39 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass {
image_data[i] = image_data_jint[i];
}
llm::Image image_runner{image_data, width, height, channels};
- images.push_back(image_runner);
+ prefill_inputs_.emplace_back(
+ llm::MultimodalInput{std::move(image_runner)});
}
- // TODO(hsz): make start_pos a reference and update it here
- jint result = static_cast<jint>(
- multi_modal_runner_->prefill_images(images, start_pos));
- tuple_result->pin()[0] = result;
- tuple_result->pin()[1] = static_cast<jlong>(start_pos);
+
+ facebook::jni::local_ref<jlongArray> tuple_result =
+ facebook::jni::make_long_array(2);
+
+ tuple_result->pin()[0] = static_cast<jlong>(Error::Ok);
return tuple_result;
}
+ // Returns the status code
+ jint add_audio_input(
+ facebook::jni::alias_ref<jintArray> audio,
+ jint batch_size,
+ jint n_bins,
+ jint n_frames) {
+ auto audio_size = audio->size();
+ if (audio_size != 0) {
+ std::vector<jint> audio_data_jint(audio_size);
+ std::vector<uint8_t> audio_data(audio_size);
+ audio->getRegion(0, audio_size, audio_data_jint.data());
+ for (int i = 0; i < audio_size; i++) {
+ audio_data[i] = audio_data_jint[i];
+ }
+ auto&& audio_input = llm::make_audio_input(
+ llm::Audio{audio_data, batch_size, n_bins, n_frames});
+ prefill_inputs_.emplace_back(audio_input);
+ }
+
+ return 0;
+ }
+
jint generate_from_pos(
facebook::jni::alias_ref<jstring> prompt,
jint seq_len,
@@ -320,13 +332,15 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass {
facebook::jni::alias_ref<ExecuTorchLlmCallbackJni> callback,
jboolean echo) {
if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
- return static_cast<jint>(multi_modal_runner_->generate_from_pos(
- prompt->toStdString(),
- seq_len,
- start_pos,
+ std::vector<llm::MultimodalInput> inputs = prefill_inputs_;
+ prefill_inputs_.clear();
+ inputs.emplace_back(llm::MultimodalInput{prompt->toStdString()});
+ return static_cast<jint>(multi_modal_runner_->generate(
+ inputs,
+ llm::GenerationConfig{
.echo = static_cast<bool>(echo), .seq_len = seq_len},
[callback](const std::string& result) { callback->onResult(result); },
- [callback](const llm::Stats& stats) { callback->onStats(stats); },
- echo));
+ [callback](const llm::Stats& stats) { callback->onStats(stats); }));
} else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
executorch::extension::llm::GenerationConfig config{
.echo = static_cast<bool>(echo),
@@ -367,9 +381,11 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass {
makeNativeMethod("stop", ExecuTorchLlmJni::stop),
makeNativeMethod("load", ExecuTorchLlmJni::load),
makeNativeMethod(
- "prefillImagesNative", ExecuTorchLlmJni::prefill_images),
+ "addImageInputNative", ExecuTorchLlmJni::add_images_input),
+ makeNativeMethod(
+ "addTextInputNative", ExecuTorchLlmJni::add_text_input),
makeNativeMethod(
- "prefillPromptNative", ExecuTorchLlmJni::prefill_prompt),
+ "addAudioInputNative", ExecuTorchLlmJni::add_audio_input),
makeNativeMethod(
"generateFromPos", ExecuTorchLlmJni::generate_from_pos),
});