From 29808245e7e077259001f73b556aff559bac8cd8 Mon Sep 17 00:00:00 2001
From: Hansong Zhang
Date: Wed, 30 Jul 2025 22:34:39 -0700
Subject: [PATCH 1/3] Test new api

---
 .../executorch/extension/llm/LlmModule.java   | 43 ++++++++++++++++++-
 extension/android/jni/jni_layer_llama.cpp     | 34 ++++++++++-----
 extension/llm/runner/irunner.h                |  4 +-
 extension/llm/runner/text_llm_runner.cpp      |  9 +++-
 extension/llm/runner/text_llm_runner.h        |  3 +-
 5 files changed, 78 insertions(+), 15 deletions(-)

diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java
index b014ceb75d8..a8eb4530e42 100644
--- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java
+++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java
@@ -213,6 +213,9 @@ public long prefillPrompt(String prompt, long startPos, int bos, int eos) {
   /**
    * Generate tokens from the given prompt, starting from the given position.
    *
+   * <p>This function is deprecated. Please use
+   * {@link #generateFromPosWithPosUpdate(String, int, long, LlmCallback, boolean)} instead.
+   *
    * @param prompt The text prompt to LLaVA.
    * @param seqLen The total sequence length, including the prompt tokens and new tokens.
    * @param startPos The starting position in KV cache of the input in the LLM.
@@ -220,7 +223,45 @@ public long prefillPrompt(String prompt, long startPos, int bos, int eos) {
    * @param echo indicate whether to echo the input prompt or not.
    * @return The error code.
    */
-  public native int generateFromPos(
+  @Deprecated
+  public int generateFromPos(
+      String prompt, int seqLen, long startPos, LlmCallback callback, boolean echo) {
+    long[] nativeResult = generateFromPosNative(prompt, seqLen, startPos, callback, echo);
+    return nativeResult[0];
+  }
+
+  /**
+   * Generate tokens from the given prompt, starting from the given position.
+   *
+   * <p>This will return the updated startPos.
+   *
+   * @param prompt The text prompt to LLaVA.
+   * @param seqLen The total sequence length, including the prompt tokens and new tokens.
+   * @param startPos The starting position in KV cache of the input in the LLM.
+   * @param callback callback object to receive results.
+   * @param echo indicate whether to echo the input prompt or not.
+   * @return updated startPos
+   */
+  public long generateFromPosWithPosUpdate(
+      String prompt, int seqLen, long startPos, LlmCallback callback, boolean echo) {
+    long[] nativeResult = generateFromPosNative(prompt, seqLen, startPos, callback, echo);
+    if (nativeResult[0] != 0) {
+      throw new RuntimeException("Generate failed with error code: " + nativeResult[0]);
+    }
+    return nativeResult[1];
+  }
+
+  /**
+   * Generate tokens from the given prompt, starting from the given position.
+   *
+   * @param prompt The text prompt to LLaVA.
+   * @param seqLen The total sequence length, including the prompt tokens and new tokens.
+   * @param startPos The starting position in KV cache of the input in the LLM.
+   * @param callback callback object to receive results.
+   * @param echo indicate whether to echo the input prompt or not.
+   * @return a tuple of (status, updated startPos)
+   */
+  private native long[] generateFromPosNative(
       String prompt, int seqLen, long startPos, LlmCallback callback, boolean echo);
 
   /** Stop current generate() before it finishes. */
diff --git a/extension/android/jni/jni_layer_llama.cpp b/extension/android/jni/jni_layer_llama.cpp
index 257f7282c65..8ebba60341d 100644
--- a/extension/android/jni/jni_layer_llama.cpp
+++ b/extension/android/jni/jni_layer_llama.cpp
@@ -291,32 +291,46 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
     return tuple_result;
   }
 
-  jint generate_from_pos(
+  facebook::jni::local_ref<jlongArray> generate_from_pos(
       facebook::jni::alias_ref<jstring> prompt,
       jint seq_len,
       jlong start_pos,
       facebook::jni::alias_ref<ExecuTorchLlmCallbackJni> callback,
       jboolean echo) {
+    facebook::jni::local_ref<jlongArray> tuple_result =
+        facebook::jni::make_long_array(2);
+
     if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
-      return static_cast<jint>(multi_modal_runner_->generate_from_pos(
-          prompt->toStdString(),
-          seq_len,
-          start_pos,
-          [callback](const std::string& result) { callback->onResult(result); },
-          [callback](const llm::Stats& stats) { callback->onStats(stats); },
-          echo));
+      tuple_result->pin()[0] =
+          static_cast<jlong>(multi_modal_runner_->generate_from_pos(
+              prompt->toStdString(),
+              seq_len,
+              start_pos,
+              [callback](const std::string& result) {
+                callback->onResult(result);
+              },
+              [callback](const llm::Stats& stats) { callback->onStats(stats); },
+              echo));
+      tuple_result->pin()[1] = static_cast<jlong>(start_pos);
+      return tuple_result;
     } else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
       executorch::extension::llm::GenerationConfig config{
           .echo = static_cast<bool>(echo),
           .seq_len = seq_len,
           .temperature = temperature_,
       };
-      return static_cast<jint>(runner_->generate_from_pos(
+      int start_pos_after_generate = start_pos;
+      tuple_result->pin()[0] = static_cast<jlong>(runner_->generate_from_pos(
           prompt->toStdString(),
           start_pos,
           config,
           [callback](std::string result) { callback->onResult(result); },
-          [callback](const llm::Stats& stats) { callback->onStats(stats); }));
+          [callback](const llm::Stats& stats) { callback->onStats(stats); },
+          [callback, &start_pos_after_generate](int updated_start_pos) {
+            start_pos_after_generate = updated_start_pos;
+          }));
+      tuple_result->pin()[1] = static_cast<jlong>(start_pos_after_generate);
+      return tuple_result;
     }
   }
 
diff --git a/extension/llm/runner/irunner.h b/extension/llm/runner/irunner.h
index 4c2efc91203..50946fb0153 100644
--- a/extension/llm/runner/irunner.h
+++ b/extension/llm/runner/irunner.h
@@ -134,6 +134,7 @@ class ET_EXPERIMENTAL IRunner {
    * @param config Generation configuration parameters
    * @param token_callback Callback function called for each generated token
    * @param stats_callback Callback function for generation statistics
+   * @param updated_start_pos Callback function for the updated start_pos
    * @return Error::Ok if successful, an error otherwise
    */
   virtual runtime::Error generate_from_pos(
@@ -141,7 +142,8 @@ class ET_EXPERIMENTAL IRunner {
       const std::string& prompt,
       int64_t start_pos,
       const GenerationConfig& config,
       std::function<void(const std::string&)> token_callback,
-      std::function<void(const Stats&)> stats_callback) = 0;
+      std::function<void(const Stats&)> stats_callback,
+      std::function<void(int64_t)> updated_start_pos) = 0;
   /**
    * Stop the generation process.
    */
diff --git a/extension/llm/runner/text_llm_runner.cpp b/extension/llm/runner/text_llm_runner.cpp
index 4f89121111d..83c9e5b83bd 100644
--- a/extension/llm/runner/text_llm_runner.cpp
+++ b/extension/llm/runner/text_llm_runner.cpp
@@ -91,7 +91,8 @@ Error TextLLMRunner::generate_from_pos(
     int64_t start_pos,
     const GenerationConfig& config,
     std::function<void(const std::string&)> token_callback,
-    std::function<void(const Stats&)> stats_callback) {
+    std::function<void(const Stats&)> stats_callback,
+    std::function<void(int64_t)> updated_start_pos) {
   // Prepare the inputs.
   // Use ones-initialized inputs.
   ET_CHECK_MSG(!prompt.empty(), "Prompt cannot be null");
@@ -230,6 +231,9 @@ Error TextLLMRunner::generate_from_pos(
   if (stats_callback) {
     stats_callback(*stats_);
   }
+  if (updated_start_pos) {
+    updated_start_pos(pos);
+  }
   return Error::Ok;
 }
 
@@ -238,7 +242,8 @@ Error TextLLMRunner::generate(
     const GenerationConfig& config,
     std::function<void(const std::string&)> token_callback,
     std::function<void(const Stats&)> stats_callback) {
-  return generate_from_pos(prompt, 0, config, token_callback, stats_callback);
+  return generate_from_pos(
+      prompt, 0, config, token_callback, stats_callback, {});
 }
 
 Error TextLLMRunner::warmup(const std::string& prompt, int32_t max_new_tokens) {
diff --git a/extension/llm/runner/text_llm_runner.h b/extension/llm/runner/text_llm_runner.h
index c35f143d2e0..69d3fc33fd7 100644
--- a/extension/llm/runner/text_llm_runner.h
+++ b/extension/llm/runner/text_llm_runner.h
@@ -119,7 +119,8 @@ class ET_EXPERIMENTAL TextLLMRunner : public IRunner {
       int64_t start_pos,
       const GenerationConfig& config,
       std::function<void(const std::string&)> token_callback = {},
-      std::function<void(const Stats&)> stats_callback = {}) override;
+      std::function<void(const Stats&)> stats_callback = {},
+      std::function<void(int64_t)> updated_start_pos = {}) override;
 
   /**
    * @brief Warms up the model with a sample prompt
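For review context, a minimal caller-side sketch of the Java API that patch 1 introduces (not part of the patches themselves): generateFromPosWithPosUpdate returns the updated KV-cache position and throws a RuntimeException on a non-zero native status, while the deprecated generateFromPos keeps returning only the error code. The prompts, sequence lengths, helper class name, and the assumption of an already-loaded LlmModule plus an existing LlmCallback implementation are illustrative only.

import org.pytorch.executorch.extension.llm.LlmCallback;
import org.pytorch.executorch.extension.llm.LlmModule;

// Illustrative helper (not part of the patches): carry the KV-cache position across turns.
public final class MultiTurnSketch {
  private MultiTurnSketch() {}

  // 'module' is assumed to be already loaded; 'callback' is any LlmCallback implementation.
  public static long runTwoTurns(LlmModule module, LlmCallback callback) {
    long startPos = 0;

    // Turn 1: generation starts at position 0 of the KV cache.
    startPos =
        module.generateFromPosWithPosUpdate(
            "What is ExecuTorch?", /* seqLen */ 128, startPos, callback, /* echo */ false);

    // Turn 2: continue from the updated position instead of re-prefilling turn 1.
    startPos =
        module.generateFromPosWithPosUpdate(
            "Summarize that in one sentence.", /* seqLen */ 256, startPos, callback, false);

    // The deprecated method still reports only a status code.
    int status = module.generateFromPos("Thanks!", 64, startPos, callback, false);
    if (status != 0) {
      throw new IllegalStateException("generateFromPos failed with error code " + status);
    }
    return startPos;
  }
}

Keeping the long[] tuple private to generateFromPosNative means callers never see the (status, position) encoding directly; they only deal with the returned position or the thrown exception.
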
From dd656b08253122781eaa26199ef78ff874a31b33 Mon Sep 17 00:00:00 2001
From: Hansong Zhang
Date: Wed, 30 Jul 2025 23:01:50 -0700
Subject: [PATCH 2/3] update

---
 .../java/org/pytorch/executorch/extension/llm/LlmModule.java | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java
index a8eb4530e42..4e90dfe2a0d 100644
--- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java
+++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java
@@ -227,7 +227,7 @@ public long prefillPrompt(String prompt, long startPos, int bos, int eos) {
   public int generateFromPos(
       String prompt, int seqLen, long startPos, LlmCallback callback, boolean echo) {
     long[] nativeResult = generateFromPosNative(prompt, seqLen, startPos, callback, echo);
-    return nativeResult[0];
+    return (int) nativeResult[0];
   }
 
   /**

From fdf91adcd9cd19b2d32a4343f5974ee26a443294 Mon Sep 17 00:00:00 2001
From: Hansong Zhang
Date: Wed, 30 Jul 2025 23:06:36 -0700
Subject: [PATCH 3/3] update

---
 extension/android/jni/jni_layer_llama.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/extension/android/jni/jni_layer_llama.cpp b/extension/android/jni/jni_layer_llama.cpp
index 8ebba60341d..0b80244bfd5 100644
--- a/extension/android/jni/jni_layer_llama.cpp
+++ b/extension/android/jni/jni_layer_llama.cpp
@@ -362,7 +362,7 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
         makeNativeMethod(
             "prefillPromptNative", ExecuTorchLlmJni::prefill_prompt),
         makeNativeMethod(
-            "generateFromPos", ExecuTorchLlmJni::generate_from_pos),
+            "generateFromPosNative", ExecuTorchLlmJni::generate_from_pos),
     });
   }
 };
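A related compatibility note on the series: callers migrating from the deprecated generateFromPos to generateFromPosWithPosUpdate also change how failures surface, since patch 1 converts a non-zero native status into a RuntimeException instead of a return code. A small adapter sketch (class and method names are illustrative only, not part of the patches) that restores status-code-style handling on top of the new API:

import org.pytorch.executorch.extension.llm.LlmCallback;
import org.pytorch.executorch.extension.llm.LlmModule;

// Illustrative adapter (not part of the patches): keep status-code handling with the new API.
final class GenerateCompat {
  private GenerateCompat() {}

  /**
   * Returns the updated start position, or -1 if generation failed (the new API surfaces a
   * non-zero native status as a RuntimeException).
   */
  static long generateOrNegative(
      LlmModule module,
      String prompt,
      int seqLen,
      long startPos,
      LlmCallback callback,
      boolean echo) {
    try {
      return module.generateFromPosWithPosUpdate(prompt, seqLen, startPos, callback, echo);
    } catch (RuntimeException e) {
      System.err.println("generation failed: " + e.getMessage());
      return -1;
    }
  }
}

Whether to prefer the exception or a sentinel value is a caller-side choice; the native layer reports the same status either way.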