diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java
index b014ceb75d8..4e90dfe2a0d 100644
--- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java
+++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java
@@ -213,6 +213,9 @@ public long prefillPrompt(String prompt, long startPos, int bos, int eos) {
   /**
    * Generate tokens from the given prompt, starting from the given position.
    *
+   * <p>This function is deprecated. Please use #generateFromPosWithPosUpdate(String prompt, int
+   * seqLen, long startPos, LlmCallback callback, boolean echo) instead.
+   *
    * @param prompt The text prompt to LLaVA.
    * @param seqLen The total sequence length, including the prompt tokens and new tokens.
    * @param startPos The starting position in KV cache of the input in the LLM.
@@ -220,7 +223,45 @@ public long prefillPrompt(String prompt, long startPos, int bos, int eos) {
    * @param echo indicate whether to echo the input prompt or not.
    * @return The error code.
    */
-  public native int generateFromPos(
+  @Deprecated
+  public int generateFromPos(
+      String prompt, int seqLen, long startPos, LlmCallback callback, boolean echo) {
+    long[] nativeResult = generateFromPosNative(prompt, seqLen, startPos, callback, echo);
+    return (int) nativeResult[0];
+  }
+
+  /**
+   * Generate tokens from the given prompt, starting from the given position.
+   *
+   * <p>This will return the updated startPos.
+   *
+   * @param prompt The text prompt to LLaVA.
+   * @param seqLen The total sequence length, including the prompt tokens and new tokens.
+   * @param startPos The starting position in KV cache of the input in the LLM.
+   * @param callback callback object to receive results.
+   * @param echo indicate whether to echo the input prompt or not.
+   * @return updated startPos
+   */
+  public long generateFromPosWithPosUpdate(
+      String prompt, int seqLen, long startPos, LlmCallback callback, boolean echo) {
+    long[] nativeResult = generateFromPosNative(prompt, seqLen, startPos, callback, echo);
+    if (nativeResult[0] != 0) {
+      throw new RuntimeException("Generate failed with error code: " + nativeResult[0]);
+    }
+    return nativeResult[1];
+  }
+
+  /**
+   * Generate tokens from the given prompt, starting from the given position.
+   *
+   * @param prompt The text prompt to LLaVA.
+   * @param seqLen The total sequence length, including the prompt tokens and new tokens.
+   * @param startPos The starting position in KV cache of the input in the LLM.
+   * @param callback callback object to receive results.
+   * @param echo indicate whether to echo the input prompt or not.
+   * @return a tuple of (status, updated startPos)
+   */
+  private native long[] generateFromPosNative(
       String prompt, int seqLen, long startPos, LlmCallback callback, boolean echo);
 
   /** Stop current generate() before it finishes. */
diff --git a/extension/android/jni/jni_layer_llama.cpp b/extension/android/jni/jni_layer_llama.cpp
index 257f7282c65..0b80244bfd5 100644
--- a/extension/android/jni/jni_layer_llama.cpp
+++ b/extension/android/jni/jni_layer_llama.cpp
@@ -291,32 +291,46 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
     return tuple_result;
   }
 
-  jint generate_from_pos(
+  facebook::jni::local_ref<jlongArray> generate_from_pos(
       facebook::jni::alias_ref<jstring> prompt,
       jint seq_len,
       jlong start_pos,
       facebook::jni::alias_ref<ExecuTorchLlmCallbackJni> callback,
       jboolean echo) {
+    facebook::jni::local_ref<jlongArray> tuple_result =
+        facebook::jni::make_long_array(2);
+
     if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
-      return static_cast<jint>(multi_modal_runner_->generate_from_pos(
-          prompt->toStdString(),
-          seq_len,
-          start_pos,
-          [callback](const std::string& result) { callback->onResult(result); },
-          [callback](const llm::Stats& stats) { callback->onStats(stats); },
-          echo));
+      tuple_result->pin()[0] =
+          static_cast<jlong>(multi_modal_runner_->generate_from_pos(
+              prompt->toStdString(),
+              seq_len,
+              start_pos,
+              [callback](const std::string& result) {
+                callback->onResult(result);
+              },
+              [callback](const llm::Stats& stats) { callback->onStats(stats); },
+              echo));
+      tuple_result->pin()[1] = static_cast<jlong>(start_pos);
+      return tuple_result;
     } else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
       executorch::extension::llm::GenerationConfig config{
          .echo = static_cast<bool>(echo),
          .seq_len = seq_len,
          .temperature = temperature_,
       };
-      return static_cast<jint>(runner_->generate_from_pos(
+      int start_pos_after_generate = start_pos;
+      tuple_result->pin()[0] = static_cast<jlong>(runner_->generate_from_pos(
           prompt->toStdString(),
           start_pos,
           config,
           [callback](std::string result) { callback->onResult(result); },
-          [callback](const llm::Stats& stats) { callback->onStats(stats); }));
+          [callback](const llm::Stats& stats) { callback->onStats(stats); },
+          [callback, &start_pos_after_generate](int updated_start_pos) {
+            start_pos_after_generate = updated_start_pos;
+          }));
+      tuple_result->pin()[1] = static_cast<jlong>(start_pos_after_generate);
+      return tuple_result;
     }
   }
 
@@ -348,7 +362,7 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
         makeNativeMethod(
             "prefillPromptNative", ExecuTorchLlmJni::prefill_prompt),
         makeNativeMethod(
-            "generateFromPos", ExecuTorchLlmJni::generate_from_pos),
+            "generateFromPosNative", ExecuTorchLlmJni::generate_from_pos),
     });
   }
 };
diff --git a/extension/llm/runner/irunner.h b/extension/llm/runner/irunner.h
index 4c2efc91203..50946fb0153 100644
--- a/extension/llm/runner/irunner.h
+++ b/extension/llm/runner/irunner.h
@@ -134,6 +134,7 @@ class ET_EXPERIMENTAL IRunner {
    * @param config Generation configuration parameters
    * @param token_callback Callback function called for each generated token
    * @param stats_callback Callback function for generation statistics
+   * @param updated_start_pos Callback function for the updated start_pos
    * @return Error::Ok if successful, an error otherwise
    */
   virtual runtime::Error generate_from_pos(
@@ -141,7 +142,8 @@ class ET_EXPERIMENTAL IRunner {
       int64_t start_pos,
       const GenerationConfig& config,
       std::function<void(const std::string&)> token_callback,
-      std::function<void(const Stats&)> stats_callback) = 0;
+      std::function<void(const Stats&)> stats_callback,
+      std::function<void(int64_t)> updated_start_pos) = 0;
   /**
    * Stop the generation process.
    */
diff --git a/extension/llm/runner/text_llm_runner.cpp b/extension/llm/runner/text_llm_runner.cpp
index 4f89121111d..83c9e5b83bd 100644
--- a/extension/llm/runner/text_llm_runner.cpp
+++ b/extension/llm/runner/text_llm_runner.cpp
@@ -91,7 +91,8 @@ Error TextLLMRunner::generate_from_pos(
     int64_t start_pos,
     const GenerationConfig& config,
     std::function<void(const std::string&)> token_callback,
-    std::function<void(const Stats&)> stats_callback) {
+    std::function<void(const Stats&)> stats_callback,
+    std::function<void(int64_t)> updated_start_pos) {
   // Prepare the inputs.
   // Use ones-initialized inputs.
   ET_CHECK_MSG(!prompt.empty(), "Prompt cannot be null");
@@ -230,6 +231,9 @@ Error TextLLMRunner::generate_from_pos(
   if (stats_callback) {
     stats_callback(*stats_);
   }
+  if (updated_start_pos) {
+    updated_start_pos(pos);
+  }
   return Error::Ok;
 }
 
@@ -238,7 +242,8 @@ Error TextLLMRunner::generate(
     const GenerationConfig& config,
     std::function<void(const std::string&)> token_callback,
     std::function<void(const Stats&)> stats_callback) {
-  return generate_from_pos(prompt, 0, config, token_callback, stats_callback);
+  return generate_from_pos(
+      prompt, 0, config, token_callback, stats_callback, {});
 }
 
 Error TextLLMRunner::warmup(const std::string& prompt, int32_t max_new_tokens) {
diff --git a/extension/llm/runner/text_llm_runner.h b/extension/llm/runner/text_llm_runner.h
index c35f143d2e0..69d3fc33fd7 100644
--- a/extension/llm/runner/text_llm_runner.h
+++ b/extension/llm/runner/text_llm_runner.h
@@ -119,7 +119,8 @@ class ET_EXPERIMENTAL TextLLMRunner : public IRunner {
       int64_t start_pos,
       const GenerationConfig& config,
       std::function<void(const std::string&)> token_callback = {},
-      std::function<void(const Stats&)> stats_callback = {}) override;
+      std::function<void(const Stats&)> stats_callback = {},
+      std::function<void(int64_t)> updated_start_pos = {}) override;
 
   /**
    * @brief Warms up the model with a sample prompt
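
Usage sketch (not part of the diff): how a caller might thread the KV-cache position through multiple turns with the new Java API. The model/tokenizer paths, temperature, seqLen value, and `callback` (an existing LlmCallback implementation) are illustrative assumptions, not taken from this patch.

    // Sketch only: assumes an already-implemented LlmCallback named `callback`
    // and placeholder paths; constructor/load usage follows the existing LlmModule API.
    LlmModule module =
        new LlmModule("/data/local/tmp/llm/model.pte", "/data/local/tmp/llm/tokenizer.bin", 0.8f);
    module.load();

    long startPos = 0;

    // New API: returns the updated KV-cache position; a non-zero native status
    // now surfaces as a RuntimeException instead of an error-code return value.
    startPos =
        module.generateFromPosWithPosUpdate("What is ExecuTorch?", 512, startPos, callback, false);

    // The next turn resumes from the position reached by the previous call.
    startPos =
        module.generateFromPosWithPosUpdate(
            "Summarize that in one sentence.", 512, startPos, callback, false);

    // The deprecated overload keeps the old error-code contract for existing callers.
    int status = module.generateFromPos("Thanks!", 512, startPos, callback, false);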