From efe81eebd5671b355cbf1d6cd8967641481e15c4 Mon Sep 17 00:00:00 2001
From: Hansong Zhang
Date: Mon, 8 Sep 2025 21:05:55 -0700
Subject: [PATCH 01/14] LlmModule prefill refactor

---
 .../executorch/extension/llm/LlmModule.java   | 46 +++++++++++++------
 extension/android/jni/jni_layer_llama.cpp     | 31 +++++--------
 2 files changed, 43 insertions(+), 34 deletions(-)

diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java
index b014ceb75d8..7c35dbf2989 100644
--- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java
+++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java
@@ -173,20 +173,23 @@ public native int generate(
    * @param height Input image height
    * @param channels Input image number of channels
    * @param startPos The starting position in KV cache of the input in the LLM.
-   * @return The updated starting position in KV cache of the input in the LLM.
+   * @return 0, as the updated starting position in KV cache of the input in the LLM is no longer
+   *     exposed to the user.
    * @throws RuntimeException if the prefill failed
    */
+  @Deprecated
   public long prefillImages(int[] image, int width, int height, int channels, long startPos) {
-    long[] nativeResult = prefillImagesNative(image, width, height, channels, startPos);
-    if (nativeResult[0] != 0) {
-      throw new RuntimeException("Prefill failed with error code: " + nativeResult[0]);
+    if (startPos == 0) {
+      resetContext();
     }
-    return nativeResult[1];
+    int nativeResult = prefillImagesNative(image, width, height, channels);
+    if (nativeResult != 0) {
+      throw new RuntimeException("Prefill failed with error code: " + nativeResult);
+    }
+    return 0;
   }
 
-  // returns a tuple of (status, updated startPos)
-  private native long[] prefillImagesNative(
-      int[] image, int width, int height, int channels, long startPos);
+  private native int prefillImagesNative(int[] image, int width, int height, int channels);
 
   /**
    * Prefill an LLaVA Module with the given text input.
@@ -196,23 +199,30 @@ private native long[] prefillImagesNative(
    * reference and will be updated inside this function.
    * @param bos The number of BOS (begin of sequence) token.
    * @param eos The number of EOS (end of sequence) token.
-   * @return The updated starting position in KV cache of the input in the LLM.
+   * @return 0, as the updated starting position in KV cache of the input in the LLM is no longer
+   *     exposed to the user.
    * @throws RuntimeException if the prefill failed
    */
+  @Deprecated
   public long prefillPrompt(String prompt, long startPos, int bos, int eos) {
-    long[] nativeResult = prefillPromptNative(prompt, startPos, bos, eos);
-    if (nativeResult[0] != 0) {
-      throw new RuntimeException("Prefill failed with error code: " + nativeResult[0]);
+    if (startPos == 0) {
+      resetContext();
     }
-    return nativeResult[1];
+    int nativeResult = prefillPromptNative(prompt, bos, eos);
+    if (nativeResult != 0) {
+      throw new RuntimeException("Prefill failed with error code: " + nativeResult);
+    }
+    return 0;
   }
 
   // returns a tuple of (status, updated startPos)
-  private native long[] prefillPromptNative(String prompt, long startPos, int bos, int eos);
+  private native int prefillPromptNative(String prompt, int bos, int eos);
 
   /**
    * Generate tokens from the given prompt, starting from the given position.
    *
+   * <p>This is a deprecated API. Please use {@link #generate(String, int, LlmCallback, boolean)}
+   *
    * @param prompt The text prompt to LLaVA.
    * @param seqLen The total sequence length, including the prompt tokens and new tokens.
    * @param startPos The starting position in KV cache of the input in the LLM.
@@ -220,9 +230,17 @@ public long prefillPrompt(String prompt, long startPos, int bos, int eos) {
    * @param echo indicate whether to echo the input prompt or not.
    * @return The error code.
    */
+  @Deprecated
   public native int generateFromPos(
       String prompt, int seqLen, long startPos, LlmCallback callback, boolean echo);
 
+  /**
+   * Reset the context of the LLM. This will clear the KV cache and reset the state of the LLM.
+   *
+   * <p>The startPos will be reset to 0.
+   */
+  public native void resetContext();
+
   /** Stop current generate() before it finishes. */
   @DoNotStrip
   public native void stop();
diff --git a/extension/android/jni/jni_layer_llama.cpp b/extension/android/jni/jni_layer_llama.cpp
index 0c3550f151a..b4a9320e20b 100644
--- a/extension/android/jni/jni_layer_llama.cpp
+++ b/extension/android/jni/jni_layer_llama.cpp
@@ -260,28 +260,19 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
   // Returns a tuple of (error, start_pos)
   // Contract is valid within an AAR (JNI + corresponding Java code)
   // If the first element is not Error::Ok, the other element is undefined.
-  facebook::jni::local_ref<jlongArray> prefill_prompt(
+  jint prefill_prompt(
       facebook::jni::alias_ref<jstring> prompt,
-      jlong start_pos,
       jint bos,
       jint eos) {
     prefill_inputs_.emplace_back(llm::MultimodalInput{prompt->toStdString()});
-    facebook::jni::local_ref<jlongArray> tuple_result =
-        facebook::jni::make_long_array(2);
-    tuple_result->pin()[0] = static_cast<jlong>(Error::Ok);
-    return tuple_result;
+    return 0;
   }
 
-  // Returns a tuple of (error, start_pos)
-  // Contract is valid within an AAR (JNI + corresponding Java code)
-  // If the first element is not Error::Ok, the other element is undefined.
-
-  facebook::jni::local_ref<jlongArray> prefill_images(
+  jint prefill_images(
       facebook::jni::alias_ref<jintArray> image,
       jint width,
      jint height,
-      jint channels,
-      jlong start_pos) {
+      jint channels) {
     std::vector<llm::Image> images;
     auto image_size = image->size();
     if (image_size != 0) {
@@ -296,11 +287,7 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
         llm::MultimodalInput{std::move(image_runner)});
     }
 
-    facebook::jni::local_ref<jlongArray> tuple_result =
-        facebook::jni::make_long_array(2);
-
-    tuple_result->pin()[0] = static_cast<jlong>(Error::Ok);
-    return tuple_result;
+    return 0;
   }
 
   jint generate_from_pos(
@@ -325,9 +312,8 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
         .seq_len = seq_len,
         .temperature = temperature_,
     };
-    return static_cast<jint>(runner_->generate_from_pos(
+    return static_cast<jint>(runner_->generate(
         prompt->toStdString(),
-        start_pos,
         config,
         [callback](std::string result) { callback->onResult(result); },
         [callback](const llm::Stats& stats) { callback->onStats(stats); }));
@@ -343,6 +329,10 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
     }
   }
 
+  void reset_context() {
+    runner_->reset();
+  }
+
   jint load() {
     if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
       return static_cast<jint>(multi_modal_runner_->load());
@@ -364,6 +354,7 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
             "prefillPromptNative", ExecuTorchLlmJni::prefill_prompt),
         makeNativeMethod(
             "generateFromPos", ExecuTorchLlmJni::generate_from_pos),
+        makeNativeMethod("resetContext", ExecuTorchLlmJni::reset_context),
       });
   }
 };
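[Editorial note, not part of the series: a hedged Java sketch of the caller-facing contract after PATCH 01. The LlmModule constructor arguments and file paths below are assumptions for illustration only and do not appear in this diff; what the patch itself establishes is that the deprecated prefill methods reset the context when startPos == 0 and always return 0, because the KV-cache position is now tracked natively.]

    // Hedged sketch; constructor shape and paths are illustrative assumptions.
    LlmModule module =
        new LlmModule("/data/local/tmp/model.pte", "/data/local/tmp/tokenizer.bin", 0.8f);
    module.load();

    // startPos == 0 triggers resetContext() inside the deprecated shim; the
    // returned value is always 0 and should no longer be threaded through calls.
    long pos = module.prefillPrompt(
        "You are a helpful assistant.", /*startPos=*/0, /*bos=*/1, /*eos=*/0);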
From 015a6ab78f7f9c80e1d3970b83b583d6de9f8557 Mon Sep 17 00:00:00 2001
From: Hansong Zhang
Date: Mon, 8 Sep 2025 21:11:35 -0700
Subject: [PATCH 02/14] Doing some rename

---
 .../org/pytorch/executorch/extension/llm/LlmModule.java  | 8 ++++----
 extension/android/jni/jni_layer_llama.cpp                | 8 ++++----
 2 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java
index 7c35dbf2989..e4be53f65cd 100644
--- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java
+++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java
@@ -182,14 +182,14 @@ public long prefillImages(int[] image, int width, int height, int channels, long
     if (startPos == 0) {
       resetContext();
     }
-    int nativeResult = prefillImagesNative(image, width, height, channels);
+    int nativeResult = appendImagesInput(image, width, height, channels);
     if (nativeResult != 0) {
       throw new RuntimeException("Prefill failed with error code: " + nativeResult);
     }
     return 0;
   }
 
-  private native int prefillImagesNative(int[] image, int width, int height, int channels);
+  private native int appendImagesInput(int[] image, int width, int height, int channels);
 
   /**
    * Prefill an LLaVA Module with the given text input.
@@ -208,7 +208,7 @@ public long prefillPrompt(String prompt, long startPos, int bos, int eos) {
     if (startPos == 0) {
       resetContext();
     }
-    int nativeResult = prefillPromptNative(prompt, bos, eos);
+    int nativeResult = appendTextInput(prompt, bos, eos);
     if (nativeResult != 0) {
       throw new RuntimeException("Prefill failed with error code: " + nativeResult);
     }
@@ -216,7 +216,7 @@ public long prefillPrompt(String prompt, long startPos, int bos, int eos) {
   }
 
   // returns a tuple of (status, updated startPos)
-  private native int prefillPromptNative(String prompt, int bos, int eos);
+  private native int appendTextInput(String prompt, int bos, int eos);
 
   /**
    * Generate tokens from the given prompt, starting from the given position.
diff --git a/extension/android/jni/jni_layer_llama.cpp b/extension/android/jni/jni_layer_llama.cpp
index b4a9320e20b..85d97ed2797 100644
--- a/extension/android/jni/jni_layer_llama.cpp
+++ b/extension/android/jni/jni_layer_llama.cpp
@@ -260,7 +260,7 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
   // Returns a tuple of (error, start_pos)
   // Contract is valid within an AAR (JNI + corresponding Java code)
   // If the first element is not Error::Ok, the other element is undefined.
-  jint prefill_prompt(
+  jint append_text_input(
       facebook::jni::alias_ref<jstring> prompt,
       jint bos,
       jint eos) {
@@ -268,7 +268,7 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
     return 0;
   }
 
-  jint prefill_images(
+  jint append_images_input(
       facebook::jni::alias_ref<jintArray> image,
       jint width,
       jint height,
@@ -349,9 +349,9 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
       makeNativeMethod("stop", ExecuTorchLlmJni::stop),
       makeNativeMethod("load", ExecuTorchLlmJni::load),
       makeNativeMethod(
-          "prefillImagesNative", ExecuTorchLlmJni::prefill_images),
+          "appendImagesInput", ExecuTorchLlmJni::append_images_input),
       makeNativeMethod(
-          "prefillPromptNative", ExecuTorchLlmJni::prefill_prompt),
+          "appendTextInput", ExecuTorchLlmJni::append_text_input),
       makeNativeMethod(
           "generateFromPos", ExecuTorchLlmJni::generate_from_pos),
       makeNativeMethod("resetContext", ExecuTorchLlmJni::reset_context),
From 16b6d1cb564a3f5cf1fbb06c9ede5bcae19e2117 Mon Sep 17 00:00:00 2001
From: Hansong Zhang
Date: Mon, 8 Sep 2025 21:14:59 -0700
Subject: [PATCH 03/14] Java layer no longer needs a separate generateFromPos

---
 .../executorch/extension/llm/LlmModule.java   |  6 ++--
 extension/android/jni/jni_layer_llama.cpp     | 33 -------------------
 2 files changed, 4 insertions(+), 35 deletions(-)

diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java
index e4be53f65cd..ec2f38bb7d3 100644
--- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java
+++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java
@@ -231,8 +231,10 @@ public long prefillPrompt(String prompt, long startPos, int bos, int eos) {
    * @return The error code.
    */
   @Deprecated
-  public native int generateFromPos(
-      String prompt, int seqLen, long startPos, LlmCallback callback, boolean echo);
+  public int generateFromPos(
+      String prompt, int seqLen, long startPos, LlmCallback callback, boolean echo) {
+    return generate(prompt, seqLen, callback, echo);
+  }
 
   /**
    * Reset the context of the LLM. This will clear the KV cache and reset the state of the LLM.
diff --git a/extension/android/jni/jni_layer_llama.cpp b/extension/android/jni/jni_layer_llama.cpp
index 85d97ed2797..7cb827bf827 100644
--- a/extension/android/jni/jni_layer_llama.cpp
+++ b/extension/android/jni/jni_layer_llama.cpp
@@ -290,37 +290,6 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
     return 0;
   }
 
-  jint generate_from_pos(
-      facebook::jni::alias_ref<jstring> prompt,
-      jint seq_len,
-      jlong start_pos,
-      facebook::jni::alias_ref<ExecuTorchLlmCallbackJni> callback,
-      jboolean echo) {
-    if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
-      std::vector<llm::MultimodalInput> inputs = prefill_inputs_;
-      prefill_inputs_.clear();
-      inputs.emplace_back(llm::MultimodalInput{prompt->toStdString()});
-      return static_cast<jint>(multi_modal_runner_->generate(
-          inputs,
-          llm::GenerationConfig{
-              .echo = static_cast<bool>(echo), .seq_len = seq_len},
-          [callback](const std::string& result) { callback->onResult(result); },
-          [callback](const llm::Stats& stats) { callback->onStats(stats); }));
-    } else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
-      executorch::extension::llm::GenerationConfig config{
-          .echo = static_cast<bool>(echo),
-          .seq_len = seq_len,
-          .temperature = temperature_,
-      };
-      return static_cast<jint>(runner_->generate(
-          prompt->toStdString(),
-          config,
-          [callback](std::string result) { callback->onResult(result); },
-          [callback](const llm::Stats& stats) { callback->onStats(stats); }));
-    }
-    return static_cast<jint>(executorch::runtime::Error::InvalidArgument);
-  }
-
   void stop() {
     if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
       multi_modal_runner_->stop();
@@ -352,8 +321,6 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
           "appendImagesInput", ExecuTorchLlmJni::append_images_input),
       makeNativeMethod(
           "appendTextInput", ExecuTorchLlmJni::append_text_input),
-      makeNativeMethod(
-          "generateFromPos", ExecuTorchLlmJni::generate_from_pos),
       makeNativeMethod("resetContext", ExecuTorchLlmJni::reset_context),
     });
   }
From e08df474e53c31e0bf7f2f04cd7daed053630075 Mon Sep 17 00:00:00 2001
From: Hansong Zhang
Date: Mon, 8 Sep 2025 21:17:12 -0700
Subject: [PATCH 04/14] Remove generateFromPos API

---
 .../executorchllamademo/MainActivity.java     |  3 +--
 .../executorch/extension/llm/LlmModule.java   | 18 ------------------
 2 files changed, 1 insertion(+), 20 deletions(-)

diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java
index b26031d89a6..fb7cc01206c 100644
--- a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java
+++ b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java
@@ -778,10 +778,9 @@ public void run() {
                         mCurrentSettingsFields.getModelType(),
                         mCurrentSettingsFields.getBackendType())
                     == ModelUtils.VISION_MODEL) {
-              mModule.generateFromPos(
+              mModule.generate(
                   finalPrompt,
                   ModelUtils.VISION_MODEL_SEQ_LEN,
-                  startPos,
                   MainActivity.this,
                   false);
             } else if (mCurrentSettingsFields.getModelType() == ModelType.LLAMA_GUARD_3) {
diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java
index ec2f38bb7d3..6599cb4c15d 100644
--- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java
+++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java
@@ -218,24 +218,6 @@ public long prefillPrompt(String prompt, long startPos, int bos, int eos) {
   // returns a tuple of (status, updated startPos)
   private native int appendTextInput(String prompt, int bos, int eos);
 
-  /**
-   * Generate tokens from the given prompt, starting from the given position.
-   *
-   * <p>This is a deprecated API. Please use {@link #generate(String, int, LlmCallback, boolean)}
-   *
-   * @param prompt The text prompt to LLaVA.
-   * @param seqLen The total sequence length, including the prompt tokens and new tokens.
-   * @param startPos The starting position in KV cache of the input in the LLM.
-   * @param callback callback object to receive results.
-   * @param echo indicate whether to echo the input prompt or not.
-   * @return The error code.
-   */
-  @Deprecated
-  public int generateFromPos(
-      String prompt, int seqLen, long startPos, LlmCallback callback, boolean echo) {
-    return generate(prompt, seqLen, callback, echo);
-  }
-
   /**
    * Reset the context of the LLM. This will clear the KV cache and reset the state of the LLM.
    *
From e465fa25741991c09b808f6ade78a6c9b058bc10 Mon Sep 17 00:00:00 2001
From: Hansong Zhang
Date: Tue, 9 Sep 2025 16:58:13 -0700
Subject: [PATCH 05/14] Add audio input type

---
 extension/android/jni/jni_layer_llama.cpp | 28 +++++++++++++++++++++++
 1 file changed, 28 insertions(+)

diff --git a/extension/android/jni/jni_layer_llama.cpp b/extension/android/jni/jni_layer_llama.cpp
index 7cb827bf827..69cef453cf1 100644
--- a/extension/android/jni/jni_layer_llama.cpp
+++ b/extension/android/jni/jni_layer_llama.cpp
@@ -273,6 +273,9 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
       jint width,
       jint height,
       jint channels) {
+    if (image == nullptr) {
+      return Error::InvalidArgument;
+    }
     std::vector<llm::Image> images;
     auto image_size = image->size();
     if (image_size != 0) {
@@ -290,6 +293,29 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
     return 0;
   }
 
+  jint append_audio_input(
+      facebook::jni::alias_ref<jintArray> audio,
+      jint batch_size,
+      jint n_channels,
+      jint n_samples) {
+    if (audio == nullptr) {
+      return Error::InvalidArgument;
+    }
+    auto audio_size = audio->size();
+    std::vector<uint8_t> audio_data(audio_size);
+    if (audio_size != 0) {
+      std::vector<jint> audio_data_jint(audio_size);
+      audio->getRegion(0, audio_size, audio_data_jint.data());
+      for (int i = 0; i < audio_size; i++) {
+        audio_data[i] = audio_data_jint[i];
+      }
+      llm::RawAudio audio_input{audio_data, batch_size, n_channels, n_samples};
+      prefill_inputs_.emplace_back(
+          llm::MultimodalInput{std::move(audio_input)});
+    }
+    return 0;
+  }
+
   void stop() {
     if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
       multi_modal_runner_->stop();
@@ -347,6 +373,8 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
           "appendImagesInput", ExecuTorchLlmJni::append_images_input),
       makeNativeMethod(
           "appendTextInput", ExecuTorchLlmJni::append_text_input),
+      makeNativeMethod(
+          "appendAudioInput", ExecuTorchLlmJni::append_audio_input),
       makeNativeMethod("resetContext", ExecuTorchLlmJni::reset_context),
     });
From d62be5a52191e1e49560b42fdbc38b0eabe42c51 Mon Sep 17 00:00:00 2001
From: Hansong Zhang
Date: Tue, 9 Sep 2025 17:06:07 -0700
Subject: [PATCH 06/14] make private method now

---
 .../java/org/pytorch/executorch/extension/llm/LlmModule.java | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java
index 6599cb4c15d..000b8c8555f 100644
--- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java
+++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java
@@ -215,9 +215,12 @@ public long prefillPrompt(String prompt, long startPos, int bos, int eos) {
     return 0;
   }
 
-  // returns a tuple of (status, updated startPos)
+  // returns status
   private native int appendTextInput(String prompt, int bos, int eos);
 
+  // returns status
+  private native int appendAudioInput(int[] data, int batchSize, int nChannels, int nSamples);
+
   /**
    * Reset the context of the LLM. This will clear the KV cache and reset the state of the LLM.
    *
From 392c15705bd725743c5b73471b1af9347cd1d370 Mon Sep 17 00:00:00 2001
From: Hansong Zhang
Date: Mon, 22 Sep 2025 15:05:08 -0700
Subject: [PATCH 07/14] Use prefill API

---
 extension/android/jni/jni_layer_llama.cpp | 32 ++++++++++++++---------
 1 file changed, 20 insertions(+), 12 deletions(-)

diff --git a/extension/android/jni/jni_layer_llama.cpp b/extension/android/jni/jni_layer_llama.cpp
index 2f6d80faadd..4faae2d4ae7 100644
--- a/extension/android/jni/jni_layer_llama.cpp
+++ b/extension/android/jni/jni_layer_llama.cpp
@@ -245,16 +245,23 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
 
   // Returns status_code
   // Contract is valid within an AAR (JNI + corresponding Java code)
-  jint append_text_input(facebook::jni::alias_ref<jstring> prompt) {
-    prefill_inputs_.emplace_back(llm::MultimodalInput{prompt->toStdString()});
-    return 0;
+  jint prefill_text_input(facebook::jni::alias_ref<jstring> prompt) {
+    if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
+      prefill_inputs_.emplace_back(llm::MultimodalInput{prompt->toStdString()});
+    } else if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
+      multi_modal_runner_->prefill(llm::MultimodalInput{prompt->toStdString()});
+      return 0;
+    }
   }
 
-  jint append_images_input(
+  jint prefill_images_input(
      facebook::jni::alias_ref<jintArray> image,
       jint width,
       jint height,
       jint channels) {
+    if (model_type_category_ != MODEL_TYPE_CATEGORY_MULTIMODAL) {
+      return Error::InvalidArgument;
+    }
     if (image == nullptr) {
       return Error::InvalidArgument;
     }
@@ -271,18 +278,20 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
       image_data[i] = image_data_jint[i];
     }
     llm::Image image_runner{std::move(image_data), width, height, channels};
-    prefill_inputs_.emplace_back(
-        llm::MultimodalInput{std::move(image_runner)});
+    multi_modal_runner_->prefill(llm::MultimodalInput{std::move(image_runner)});
     }
 
     return 0;
   }
 
-  jint append_audio_input(
+  jint prefill_audio_input(
       facebook::jni::alias_ref<jintArray> audio,
       jint batch_size,
       jint n_channels,
       jint n_samples) {
+    if (model_type_category_ != MODEL_TYPE_CATEGORY_MULTIMODAL) {
+      return Error::InvalidArgument;
+    }
     if (audio == nullptr) {
       return Error::InvalidArgument;
     }
@@ -295,8 +304,7 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
       audio_data[i] = audio_data_jint[i];
     }
     llm::RawAudio audio_input{audio_data, batch_size, n_channels, n_samples};
-    prefill_inputs_.emplace_back(
-        llm::MultimodalInput{std::move(audio_input)});
+    multi_modal_runner_->prefill(llm::MultimodalInput{std::move(audio_input)});
     }
     return 0;
   }
@@ -334,11 +342,11 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
       makeNativeMethod("stop", ExecuTorchLlmJni::stop),
       makeNativeMethod("load", ExecuTorchLlmJni::load),
       makeNativeMethod(
-          "appendImagesInput", ExecuTorchLlmJni::append_images_input),
+          "appendImagesInput", ExecuTorchLlmJni::prefill_images_input),
       makeNativeMethod(
-          "appendTextInput", ExecuTorchLlmJni::append_text_input),
+          "appendTextInput", ExecuTorchLlmJni::prefill_text_input),
       makeNativeMethod(
-          "appendAudioInput", ExecuTorchLlmJni::append_audio_input),
+          "appendAudioInput", ExecuTorchLlmJni::prefill_audio_input),
       makeNativeMethod("resetContext", ExecuTorchLlmJni::reset_context),
     });
   }
From beb17840b98515fa24b9176993b632cf1ad5e195 Mon Sep 17 00:00:00 2001
From: Hansong Zhang
Date: Mon, 22 Sep 2025 15:19:28 -0700
Subject: [PATCH 08/14] Add a prefill() method for text llm runner

---
 extension/llm/runner/text_llm_runner.cpp | 22 ++++++++++++++++++++++
 extension/llm/runner/text_llm_runner.h   | 11 +++++++++++
 2 files changed, 33 insertions(+)

diff --git a/extension/llm/runner/text_llm_runner.cpp b/extension/llm/runner/text_llm_runner.cpp
index 333716ac831..ec9c6c5242f 100644
--- a/extension/llm/runner/text_llm_runner.cpp
+++ b/extension/llm/runner/text_llm_runner.cpp
@@ -217,6 +217,28 @@ Error TextLLMRunner::generate(
   return Error::Ok;
 }
 
+Error TextLLMRunner::prefill(
+    const std::string& prompt,
+    const GenerationConfig& config) {
+  if (!is_loaded()) {
+    ET_CHECK_OK_OR_RETURN_ERROR(load());
+  }
+
+  ::tokenizers::Result<std::vector<uint64_t>> encode_res = tokenizer_->encode(
+      prompt,
+      /*bos=*/config.num_bos,
+      /*eos=*/config.num_eos);
+
+  ET_CHECK_TK_OK_OR_RETURN_ERROR(
+      encode_res.error(), "Failed to encode prompt %s", prompt.c_str());
+
+  // encode the (string) prompt into tokens sequence
+  std::vector<uint64_t> prompt_tokens = encode_res.get();
+  auto prefill_res = text_prefiller_->prefill(prompt_tokens, pos_);
+  ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error());
+  return Error::Ok;
+}
+
 Error TextLLMRunner::warmup(const std::string& prompt, int32_t max_new_tokens) {
   // Create a GenerationConfig for warmup
   GenerationConfig config{
diff --git a/extension/llm/runner/text_llm_runner.h b/extension/llm/runner/text_llm_runner.h
index 9dd99d82d59..98fcef94f96 100644
--- a/extension/llm/runner/text_llm_runner.h
+++ b/extension/llm/runner/text_llm_runner.h
@@ -101,6 +101,17 @@ class ET_EXPERIMENTAL TextLLMRunner : public IRunner {
       std::function<void(const std::string&)> token_callback = {},
       std::function<void(const Stats&)> stats_callback = {}) override;
 
+  /**
+   * Prefill text inputs, for example to reload chat history.
+   * @param prompt Text prompt to prefill.
+   * @param config Configuration parameters for text generation (e.g.,
+   * max_new_tokens, temperature)
+   * @return The error code. KV cache position is tracked internally in pos_.
+   */
+  ::executorch::runtime::Error prefill(
+      const std::string& prompt,
+      const GenerationConfig& config);
+
   /**
    * @brief Warms up the model with a sample prompt
   *
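[Editorial note, not part of the series: a hedged C++ sketch of driving the TextLLMRunner::prefill() API added in PATCH 08. The helper function and its name are illustrative assumptions, not from the patches; runner construction is elided. The config argument is passed explicitly here because the default argument only arrives in PATCH 09.]

    #include <string>
    #include <vector>

    #include <executorch/extension/llm/runner/text_llm_runner.h>

    // Reload saved chat history into the KV cache before generating.
    // The position advances inside the runner (pos_); callers no longer
    // track or thread a start position through the API.
    executorch::runtime::Error reload_history(
        executorch::extension::llm::TextLLMRunner& runner,
        const std::vector<std::string>& turns) {
      for (const auto& turn : turns) {
        auto err = runner.prefill(turn, /*config=*/{});
        if (err != executorch::runtime::Error::Ok) {
          return err; // propagate the error code, as the JNI layer does
        }
      }
      return executorch::runtime::Error::Ok;
    }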
From 1cd35825b9d8a632ceaad17cccb66dfc0d90fd92 Mon Sep 17 00:00:00 2001
From: Hansong Zhang
Date: Mon, 22 Sep 2025 15:53:12 -0700
Subject: [PATCH 09/14] Android use new prefill API

---
 extension/android/jni/jni_layer_llama.cpp  | 24 ++++++++++++----------
 extension/llm/runner/irunner.h             | 11 ++++++++++
 extension/llm/runner/multimodal_runner.cpp |  2 +-
 extension/llm/runner/multimodal_runner.h   |  2 +-
 extension/llm/runner/text_llm_runner.h     |  6 +++---
 5 files changed, 29 insertions(+), 16 deletions(-)

diff --git a/extension/android/jni/jni_layer_llama.cpp b/extension/android/jni/jni_layer_llama.cpp
index 4faae2d4ae7..68b96cb64ff 100644
--- a/extension/android/jni/jni_layer_llama.cpp
+++ b/extension/android/jni/jni_layer_llama.cpp
@@ -123,7 +123,6 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
 
   std::unique_ptr<llm::IRunner> runner_;
   std::unique_ptr<llm::MultimodalRunner> multi_modal_runner_;
-  std::vector<llm::MultimodalInput> prefill_inputs_;
 
  public:
   constexpr static auto kJavaDescriptor =
@@ -213,8 +212,7 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
       facebook::jni::alias_ref<ExecuTorchLlmCallbackJni> callback,
       jboolean echo) {
     if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
-      std::vector<llm::MultimodalInput> inputs = prefill_inputs_;
-      prefill_inputs_.clear();
+      std::vector<llm::MultimodalInput> inputs;
       if (!prompt->toStdString().empty()) {
         inputs.emplace_back(llm::MultimodalInput{prompt->toStdString()});
       }
@@ -245,11 +243,13 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
   // Returns status_code
   // Contract is valid within an AAR (JNI + corresponding Java code)
   jint prefill_text_input(facebook::jni::alias_ref<jstring> prompt) {
     if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
-      prefill_inputs_.emplace_back(llm::MultimodalInput{prompt->toStdString()});
+      runner_->prefill(prompt->toStdString(), {});
+      return 0;
     } else if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
-      multi_modal_runner_->prefill(llm::MultimodalInput{prompt->toStdString()});
+      multi_modal_runner_->prefill(
+          {llm::MultimodalInput{prompt->toStdString()}});
       return 0;
     }
   }
@@ -260,10 +260,10 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
       jint width,
       jint height,
       jint channels) {
     if (model_type_category_ != MODEL_TYPE_CATEGORY_MULTIMODAL) {
-      return Error::InvalidArgument;
+      return static_cast<jint>(Error::InvalidArgument);
     }
     if (image == nullptr) {
-      return Error::InvalidArgument;
+      return static_cast<jint>(Error::InvalidArgument);
     }
     std::vector<llm::Image> images;
@@ -278,8 +278,9 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
       image_data[i] = image_data_jint[i];
     }
     llm::Image image_runner{std::move(image_data), width, height, channels};
-    multi_modal_runner_->prefill(llm::MultimodalInput{std::move(image_runner)});
+    multi_modal_runner_->prefill(
+        {llm::MultimodalInput{std::move(image_runner)}});
     }
 
     return 0;
   }
@@ -292,10 +293,10 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
       jint batch_size,
       jint n_channels,
       jint n_samples) {
     if (model_type_category_ != MODEL_TYPE_CATEGORY_MULTIMODAL) {
-      return Error::InvalidArgument;
+      return static_cast<jint>(Error::InvalidArgument);
     }
     if (audio == nullptr) {
-      return Error::InvalidArgument;
+      return static_cast<jint>(Error::InvalidArgument);
     }
     auto audio_size = audio->size();
@@ -305,7 +306,8 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
       audio_data[i] = audio_data_jint[i];
     }
     llm::RawAudio audio_input{audio_data, batch_size, n_channels, n_samples};
-    multi_modal_runner_->prefill(llm::MultimodalInput{std::move(audio_input)});
+    multi_modal_runner_->prefill(
+        {llm::MultimodalInput{std::move(audio_input)}});
     }
     return 0;
   }
diff --git a/extension/llm/runner/irunner.h b/extension/llm/runner/irunner.h
index ef93f32319c..6699234be4d 100644
--- a/extension/llm/runner/irunner.h
+++ b/extension/llm/runner/irunner.h
@@ -125,6 +125,17 @@ class ET_EXPERIMENTAL IRunner {
       std::function<void(const std::string&)> token_callback,
       std::function<void(const Stats&)> stats_callback) = 0;
 
+  /**
+   * Prefill text inputs, for example to reload chat history.
+   * @param prompt Text prompt to prefill.
+   * @param config Configuration parameters (if non-zero, num_bos and num_eos
+   *     are used)
+   * @return The error code. KV cache position is tracked internally in pos_.
+   */
+  virtual ::executorch::runtime::Error prefill(
+      const std::string& prompt,
+      const GenerationConfig& config = {}) = 0;
+
   /**
    * Stop the generation process.
   */
diff --git a/extension/llm/runner/multimodal_runner.cpp b/extension/llm/runner/multimodal_runner.cpp
index a5de59cbe98..8b7e4e315d8 100644
--- a/extension/llm/runner/multimodal_runner.cpp
+++ b/extension/llm/runner/multimodal_runner.cpp
@@ -62,7 +62,7 @@ Error MultimodalRunner::load() {
     ET_LOG(Info, format, __VA_ARGS__); \
   }
 
-Error MultimodalRunner::prefill(std::vector<MultimodalInput>& inputs) {
+Error MultimodalRunner::prefill(const std::vector<MultimodalInput>& inputs) {
   if (!is_loaded()) {
     ET_CHECK_OK_OR_RETURN_ERROR(load());
   }
diff --git a/extension/llm/runner/multimodal_runner.h b/extension/llm/runner/multimodal_runner.h
index 4a824fd4d9c..caf3c296038 100644
--- a/extension/llm/runner/multimodal_runner.h
+++ b/extension/llm/runner/multimodal_runner.h
@@ -126,7 +126,7 @@ class ET_EXPERIMENTAL MultimodalRunner {
    * @return The error code. KV cache position is tracked internally in pos_.
    */
   virtual ::executorch::runtime::Error prefill(
-      std::vector<MultimodalInput>& inputs);
+      const std::vector<MultimodalInput>& inputs);
 
   inline void stop() {
     text_token_generator_->stop();
diff --git a/extension/llm/runner/text_llm_runner.h b/extension/llm/runner/text_llm_runner.h
index 98fcef94f96..865b8a3bd53 100644
--- a/extension/llm/runner/text_llm_runner.h
+++ b/extension/llm/runner/text_llm_runner.h
@@ -104,13 +104,13 @@ class ET_EXPERIMENTAL TextLLMRunner : public IRunner {
   /**
    * Prefill text inputs, for example to reload chat history.
    * @param prompt Text prompt to prefill.
-   * @param config Configuration parameters for text generation (e.g.,
-   * max_new_tokens, temperature)
+   * @param config Configuration parameters (if non-zero, num_bos and num_eos
+   *     are used)
    * @return The error code. KV cache position is tracked internally in pos_.
   */
   ::executorch::runtime::Error prefill(
       const std::string& prompt,
-      const GenerationConfig& config);
+      const GenerationConfig& config = {}) override;
 
   /**
    * @brief Warms up the model with a sample prompt

From 148bd912325d270923590f386917c09fa03d8663 Mon Sep 17 00:00:00 2001
From: Hansong Zhang
Date: Mon, 22 Sep 2025 19:06:00 -0700
Subject: [PATCH 10/14] QNN override

---
 examples/qualcomm/oss_scripts/llama/runner/runner.cpp | 7 +++++++
 examples/qualcomm/oss_scripts/llama/runner/runner.h   | 5 +++++
 2 files changed, 12 insertions(+)

diff --git a/examples/qualcomm/oss_scripts/llama/runner/runner.cpp b/examples/qualcomm/oss_scripts/llama/runner/runner.cpp
index fc4ff006a90..01da9ad058d 100644
--- a/examples/qualcomm/oss_scripts/llama/runner/runner.cpp
+++ b/examples/qualcomm/oss_scripts/llama/runner/runner.cpp
@@ -464,6 +464,13 @@ Error Runner<T>::generate_from_prompt_or_file(
   return Error::Ok;
 }
 
+template <typename T>
+::executorch::runtime::Error prefill(
+    const std::string& prompt,
+    const GenerationConfig& config = {}) {
+  return ::Error::NotImplemented;
+}
+
 template <typename T>
 Result<DecoderModelVersion> Runner<T>::get_decoder_model_version() {
   if (!is_loaded()) {
diff --git a/examples/qualcomm/oss_scripts/llama/runner/runner.h b/examples/qualcomm/oss_scripts/llama/runner/runner.h
index 9f290d79c75..41b1db1d2b7 100644
--- a/examples/qualcomm/oss_scripts/llama/runner/runner.h
+++ b/examples/qualcomm/oss_scripts/llama/runner/runner.h
@@ -79,6 +79,11 @@ class Runner : public executorch::extension::llm::IRunner {
       const executorch::extension::llm::GenerationConfig& config,
       std::function<void(const std::string&)> token_callback = {},
      std::function<void(const executorch::extension::llm::Stats&)> stats_callback = {});
+
+  executorch::runtime::Error prefill(
+      const std::string& prompt,
+      const GenerationConfig& config = {}) override;
+  void stop() override {};
   void stop() override {};
   void reset() override {};
   executorch::runtime::Result<DecoderModelVersion> get_decoder_model_version();

From 0a192126e0283c722ea5a895655b1d231b6a0d0c Mon Sep 17 00:00:00 2001
From: Hansong Zhang
Date: Mon, 22 Sep 2025 19:31:09 -0700
Subject: [PATCH 11/14] fix

---
 examples/qualcomm/oss_scripts/llama/runner/runner.cpp | 2 +-
 examples/qualcomm/oss_scripts/llama/runner/runner.h   | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/examples/qualcomm/oss_scripts/llama/runner/runner.cpp b/examples/qualcomm/oss_scripts/llama/runner/runner.cpp
index 01da9ad058d..818f3c2309e 100644
--- a/examples/qualcomm/oss_scripts/llama/runner/runner.cpp
+++ b/examples/qualcomm/oss_scripts/llama/runner/runner.cpp
@@ -467,7 +467,7 @@ Error Runner<T>::generate_from_prompt_or_file(
 template <typename T>
 ::executorch::runtime::Error prefill(
     const std::string& prompt,
-    const GenerationConfig& config = {}) {
+    const executorch::extension::llm::GenerationConfig& config = {}) {
   return ::Error::NotImplemented;
 }
diff --git a/examples/qualcomm/oss_scripts/llama/runner/runner.h b/examples/qualcomm/oss_scripts/llama/runner/runner.h
index 41b1db1d2b7..143b3824c11 100644
--- a/examples/qualcomm/oss_scripts/llama/runner/runner.h
+++ b/examples/qualcomm/oss_scripts/llama/runner/runner.h
@@ -82,7 +82,7 @@ class Runner : public executorch::extension::llm::IRunner {
 
   executorch::runtime::Error prefill(
       const std::string& prompt,
-      const GenerationConfig& config = {}) override;
+      const executorch::extension::llm::GenerationConfig& config = {}) override;
   void stop() override {};
   void stop() override {};
   void reset() override {};
From 659007de1da393710baea55d36f531a97dd54794 Mon Sep 17 00:00:00 2001
From: Hansong Zhang
Date: Mon, 22 Sep 2025 20:14:31 -0700
Subject: [PATCH 12/14] fix

---
 examples/qualcomm/oss_scripts/llama/runner/runner.h | 1 -
 1 file changed, 1 deletion(-)

diff --git a/examples/qualcomm/oss_scripts/llama/runner/runner.h b/examples/qualcomm/oss_scripts/llama/runner/runner.h
index 143b3824c11..41d1ae19bdc 100644
--- a/examples/qualcomm/oss_scripts/llama/runner/runner.h
+++ b/examples/qualcomm/oss_scripts/llama/runner/runner.h
@@ -84,7 +84,6 @@ class Runner : public executorch::extension::llm::IRunner {
       const std::string& prompt,
       const executorch::extension::llm::GenerationConfig& config = {}) override;
   void stop() override {};
-  void stop() override {};
   void reset() override {};
   executorch::runtime::Result<DecoderModelVersion> get_decoder_model_version();
 

From f4de63a79e80f5ddeb0f520c3b3c85021ac00afe Mon Sep 17 00:00:00 2001
From: Hansong Zhang
Date: Tue, 23 Sep 2025 09:26:49 -0700
Subject: [PATCH 13/14] fix

---
 examples/qualcomm/oss_scripts/llama/runner/runner.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/qualcomm/oss_scripts/llama/runner/runner.cpp b/examples/qualcomm/oss_scripts/llama/runner/runner.cpp
index 818f3c2309e..1c8b98e79a8 100644
--- a/examples/qualcomm/oss_scripts/llama/runner/runner.cpp
+++ b/examples/qualcomm/oss_scripts/llama/runner/runner.cpp
@@ -465,7 +465,7 @@ Error Runner<T>::generate_from_prompt_or_file(
 }
 
 template <typename T>
-::executorch::runtime::Error prefill(
+::executorch::runtime::Error Runner<T>::prefill(
     const std::string& prompt,
     const executorch::extension::llm::GenerationConfig& config = {}) {
   return ::Error::NotImplemented;

From 8745801664b4aa58e887828cb5c4b6de67ec54d5 Mon Sep 17 00:00:00 2001
From: Hansong Zhang
Date: Tue, 23 Sep 2025 09:39:35 -0700
Subject: [PATCH 14/14] fix qnn compile

---
 examples/qualcomm/oss_scripts/llama/runner/runner.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/qualcomm/oss_scripts/llama/runner/runner.cpp b/examples/qualcomm/oss_scripts/llama/runner/runner.cpp
index 1c8b98e79a8..14df30779aa 100644
--- a/examples/qualcomm/oss_scripts/llama/runner/runner.cpp
+++ b/examples/qualcomm/oss_scripts/llama/runner/runner.cpp
@@ -467,7 +467,7 @@ Error Runner<T>::generate_from_prompt_or_file(
 template <typename T>
 ::executorch::runtime::Error Runner<T>::prefill(
     const std::string& prompt,
-    const executorch::extension::llm::GenerationConfig& config = {}) {
+    const executorch::extension::llm::GenerationConfig& config) {
   return ::Error::NotImplemented;
 }
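
[Editorial note, not part of the series: a hedged end-to-end Java sketch of the module surface once all fourteen patches land. The constructor and the LlmCallback method signatures are assumptions inferred from the JNI calls above (onResult receives a String; onStats receives stringified stats); treat this as illustration, not the authoritative API.]

    // Assumed constructor; model and tokenizer paths are placeholders.
    LlmModule module =
        new LlmModule("/data/local/tmp/model.pte", "/data/local/tmp/tokenizer.bin", 0.8f);
    module.load();

    LlmCallback printer = new LlmCallback() {
      @Override
      public void onResult(String token) {
        System.out.print(token); // streamed generation results
      }

      @Override
      public void onStats(String stats) {
        // timing/stats payload; exact signature is an assumption
      }
    };

    // The deprecated shim still works: startPos == 0 resets the context natively.
    module.prefillPrompt("A chat between a user and an assistant.", 0, 1, 0);
    module.generate("USER: Tell me a joke. ASSISTANT:", 256, printer, false);

    // Start a fresh conversation: clears the KV cache; position returns to 0.
    module.resetContext();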