diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java index b014ceb75d8..d8ee0ab7482 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java @@ -177,7 +177,7 @@ public native int generate( * @throws RuntimeException if the prefill failed */ public long prefillImages(int[] image, int width, int height, int channels, long startPos) { - long[] nativeResult = prefillImagesNative(image, width, height, channels, startPos); + long[] nativeResult = addImageInputNative(image, width, height, channels, startPos); if (nativeResult[0] != 0) { throw new RuntimeException("Prefill failed with error code: " + nativeResult[0]); } @@ -185,7 +185,7 @@ public long prefillImages(int[] image, int width, int height, int channels, long } // returns a tuple of (status, updated startPos) - private native long[] prefillImagesNative( + private native long[] addImageInputNative( int[] image, int width, int height, int channels, long startPos); /** @@ -200,7 +200,7 @@ private native long[] prefillImagesNative( * @throws RuntimeException if the prefill failed */ public long prefillPrompt(String prompt, long startPos, int bos, int eos) { - long[] nativeResult = prefillPromptNative(prompt, startPos, bos, eos); + long[] nativeResult = addTextInputNative(prompt, startPos, bos, eos); if (nativeResult[0] != 0) { throw new RuntimeException("Prefill failed with error code: " + nativeResult[0]); } @@ -208,7 +208,10 @@ public long prefillPrompt(String prompt, long startPos, int bos, int eos) { } // returns a tuple of (status, updated startPos) - private native long[] prefillPromptNative(String prompt, long startPos, int bos, int eos); + private native long[] addTextInputNative(String prompt, long startPos, int bos, int eos); + + // returns the status code + private native int addAudioInputNative(int[] audio, int batch_size, int n_bins, int n_frames); /** * Generate tokens from the given prompt, starting from the given position. @@ -217,6 +220,12 @@ public long prefillPrompt(String prompt, long startPos, int bos, int eos) { * @param seqLen The total sequence length, including the prompt tokens and new tokens. * @param startPos The starting position in KV cache of the input in the LLM. * @param callback callback object to receive results. + * @param echo indicate whether to echo the + *

/** Generate tokens from the given prompt, starting from the given position. + * @param prompt The text prompt to LLaVA. + * @param seqLen The total sequence length, including the prompt tokens and new tokens. + * @param startPos The starting position in KV cache of the input in the LLM. + * @param callback callback object to receive results. * @param echo indicate whether to echo the input prompt or not. * @return The error code. */ diff --git a/extension/android/jni/jni_layer_llama.cpp b/extension/android/jni/jni_layer_llama.cpp index 886b25e4221..aa5f6052225 100644 --- a/extension/android/jni/jni_layer_llama.cpp +++ b/extension/android/jni/jni_layer_llama.cpp @@ -13,7 +13,6 @@ #include #include -#include #include #include #include @@ -122,7 +121,9 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass { float temperature_ = 0.0f; int model_type_category_; std::unique_ptr runner_; - std::unique_ptr multi_modal_runner_; + std::unique_ptr + multi_modal_runner_; + std::vector prefill_inputs_; public: constexpr static auto kJavaDescriptor = @@ -168,10 +169,9 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass { model_type_category_ = model_type_category; if (model_type_category == MODEL_TYPE_CATEGORY_MULTIMODAL) { - multi_modal_runner_ = std::make_unique( + multi_modal_runner_ = llm::create_multimodal_runner( model_path->toStdString().c_str(), - tokenizer_path->toStdString().c_str(), - temperature); + llm::load_tokenizer(tokenizer_path->toStdString())); } else if (model_type_category == MODEL_TYPE_CATEGORY_LLM) { std::optional data_path_str = data_path ? std::optional{data_path->toStdString()} @@ -217,6 +217,9 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass { facebook::jni::alias_ref callback, jboolean echo) { if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) { + std::vector inputs = prefill_inputs_; + prefill_inputs_.clear(); + inputs.emplace_back(llm::MultimodalInput{prompt->toStdString()}); auto image_size = image->size(); std::vector images; if (image_size != 0) { @@ -227,15 +230,18 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass { image_data[i] = image_data_jint[i]; } llm::Image image_runner{image_data, width, height, channels}; - images.push_back(image_runner); + inputs.emplace_back(llm::MultimodalInput{std::move(image_runner)}); } + executorch::extension::llm::GenerationConfig config{ + .echo = static_cast(echo), + .seq_len = seq_len, + .temperature = temperature_, + }; multi_modal_runner_->generate( - std::move(images), - prompt->toStdString(), - seq_len, - [callback](std::string result) { callback->onResult(result); }, - [callback](const llm::Stats& result) { callback->onStats(result); }, - echo); + std::move(inputs), + config, + [callback](const std::string& result) { callback->onResult(result); }, + [callback](const llm::Stats& result) { callback->onStats(result); }); } else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) { executorch::extension::llm::GenerationConfig config{ .echo = static_cast(echo), @@ -254,24 +260,15 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass { // Returns a tuple of (error, start_pos) // Contract is valid within an AAR (JNI + corresponding Java code) // If the first element is not Error::Ok, the other element is undefined. - facebook::jni::local_ref prefill_prompt( + facebook::jni::local_ref add_text_input( facebook::jni::alias_ref prompt, jlong start_pos, jint bos, jint eos) { + prefill_inputs_.emplace_back(llm::MultimodalInput{prompt->toStdString()}); facebook::jni::local_ref tuple_result = facebook::jni::make_long_array(2); - if (model_type_category_ != MODEL_TYPE_CATEGORY_MULTIMODAL) { - tuple_result->pin()[0] = static_cast(Error::NotSupported); - return tuple_result; - } - - auto&& result = multi_modal_runner_->prefill_prompt( - prompt->toStdString(), start_pos, bos, eos); tuple_result->pin()[0] = static_cast(Error::Ok); - if (result.ok()) { - tuple_result->pin()[1] = static_cast(start_pos); - } return tuple_result; } @@ -279,22 +276,14 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass { // Contract is valid within an AAR (JNI + corresponding Java code) // If the first element is not Error::Ok, the other element is undefined. - facebook::jni::local_ref prefill_images( + facebook::jni::local_ref add_images_input( facebook::jni::alias_ref image, jint width, jint height, jint channels, jlong start_pos) { - facebook::jni::local_ref tuple_result = - facebook::jni::make_long_array(2); - - if (model_type_category_ != MODEL_TYPE_CATEGORY_MULTIMODAL) { - tuple_result->pin()[0] = static_cast(Error::NotSupported); - return tuple_result; - } - - auto image_size = image->size(); std::vector images; + auto image_size = image->size(); if (image_size != 0) { std::vector image_data_jint(image_size); std::vector image_data(image_size); @@ -303,16 +292,39 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass { image_data[i] = image_data_jint[i]; } llm::Image image_runner{image_data, width, height, channels}; - images.push_back(image_runner); + prefill_inputs_.emplace_back( + llm::MultimodalInput{std::move(image_runner)}); } - // TODO(hsz): make start_pos a reference and update it here - jint result = static_cast( - multi_modal_runner_->prefill_images(images, start_pos)); - tuple_result->pin()[0] = result; - tuple_result->pin()[1] = static_cast(start_pos); + + facebook::jni::local_ref tuple_result = + facebook::jni::make_long_array(2); + + tuple_result->pin()[0] = static_cast(Error::Ok); return tuple_result; } + // Returns the status code + jint add_audio_input( + facebook::jni::alias_ref audio, + jint batch_size, + jint n_bins, + jint n_frames) { + auto audio_size = audio->size(); + if (audio_size != 0) { + std::vector audio_data_jint(audio_size); + std::vector audio_data(audio_size); + audio->getRegion(0, audio_size, audio_data_jint.data()); + for (int i = 0; i < audio_size; i++) { + audio_data[i] = audio_data_jint[i]; + } + auto&& audio_input = llm::make_audio_input( + llm::Audio{audio_data, batch_size, n_bins, n_frames}); + prefill_inputs_.emplace_back(audio_input); + } + + return 0; + } + jint generate_from_pos( facebook::jni::alias_ref prompt, jint seq_len, @@ -320,13 +332,15 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass { facebook::jni::alias_ref callback, jboolean echo) { if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) { - return static_cast(multi_modal_runner_->generate_from_pos( - prompt->toStdString(), - seq_len, - start_pos, + std::vector inputs = prefill_inputs_; + prefill_inputs_.clear(); + inputs.emplace_back(llm::MultimodalInput{prompt->toStdString()}); + return static_cast(multi_modal_runner_->generate( + inputs, + llm::GenerationConfig{ + .echo = static_cast(echo), .seq_len = seq_len}, [callback](const std::string& result) { callback->onResult(result); }, - [callback](const llm::Stats& stats) { callback->onStats(stats); }, - echo)); + [callback](const llm::Stats& stats) { callback->onStats(stats); })); } else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) { executorch::extension::llm::GenerationConfig config{ .echo = static_cast(echo), @@ -367,9 +381,11 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass { makeNativeMethod("stop", ExecuTorchLlmJni::stop), makeNativeMethod("load", ExecuTorchLlmJni::load), makeNativeMethod( - "prefillImagesNative", ExecuTorchLlmJni::prefill_images), + "addImageInputNative", ExecuTorchLlmJni::add_images_input), + makeNativeMethod( + "addTextInputNative", ExecuTorchLlmJni::add_text_input), makeNativeMethod( - "prefillPromptNative", ExecuTorchLlmJni::prefill_prompt), + "addAudioInputNative", ExecuTorchLlmJni::add_audio_input), makeNativeMethod( "generateFromPos", ExecuTorchLlmJni::generate_from_pos), });