diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java
index b014ceb75d8..4e90dfe2a0d 100644
--- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java
+++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java
@@ -213,6 +213,9 @@ public long prefillPrompt(String prompt, long startPos, int bos, int eos) {
/**
* Generate tokens from the given prompt, starting from the given position.
*
+ * <p>This function is deprecated. Please use #generateFromPosWithPosUpdate(String prompt, int
+ * seqLen, long startPos, LlmCallback callback, boolean echo) instead.
+ *
* @param prompt The text prompt to LLaVA.
* @param seqLen The total sequence length, including the prompt tokens and new tokens.
* @param startPos The starting position in KV cache of the input in the LLM.
@@ -220,7 +223,45 @@ public long prefillPrompt(String prompt, long startPos, int bos, int eos) {
* @param echo indicate whether to echo the input prompt or not.
* @return The error code.
*/
- public native int generateFromPos(
+ @Deprecated
+ public int generateFromPos(
+ String prompt, int seqLen, long startPos, LlmCallback callback, boolean echo) {
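+ // nativeResult packs {status code, updated startPos}; this deprecated method surfaces only the status.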
+ long[] nativeResult = generateFromPosNative(prompt, seqLen, startPos, callback, echo);
+ return (int) nativeResult[0];
+ }
+
+ /**
+ * Generate tokens from the given prompt, starting from the given position.
+ *
+ * <p>This will return the updated startPos.
+ *
+ * @param prompt The text prompt to LLaVA.
+ * @param seqLen The total sequence length, including the prompt tokens and new tokens.
+ * @param startPos The starting position in KV cache of the input in the LLM.
+ * @param callback callback object to receive results.
+ * @param echo indicate whether to echo the input prompt or not.
+ * @return updated startPos
+ */
+ public long generateFromPosWithPosUpdate(
+ String prompt, int seqLen, long startPos, LlmCallback callback, boolean echo) {
+ long[] nativeResult = generateFromPosNative(prompt, seqLen, startPos, callback, echo);
+ if (nativeResult[0] != 0) {
+ throw new RuntimeException("Generate failed with error code: " + nativeResult[0]);
+ }
+ return nativeResult[1];
+ }
+
+ /**
+ * Generate tokens from the given prompt, starting from the given position.
+ *
+ * @param prompt The text prompt to LLaVA.
+ * @param seqLen The total sequence length, including the prompt tokens and new tokens.
+ * @param startPos The starting position in KV cache of the input in the LLM.
+ * @param callback callback object to receive results.
+ * @param echo indicate whether to echo the input prompt or not.
+ * @return a tuple of (status, updated startPos)
+ */
+ private native long[] generateFromPosNative(
String prompt, int seqLen, long startPos, LlmCallback callback, boolean echo);
/** Stop current generate() before it finishes. */
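A minimal usage sketch of the new Java API for reviewers (module, callback, the prompts, and the sequence lengths below are placeholders, not part of this change): the caller threads the returned position through successive calls instead of tracking the KV-cache position by hand.

    long startPos = 0;
    // First turn: the native layer reports back the start position after generation.
    startPos = module.generateFromPosWithPosUpdate("Hello", 128, startPos, callback, true);
    // Later turns resume from the updated position; a non-zero native status throws a RuntimeException.
    startPos = module.generateFromPosWithPosUpdate("Tell me more", 256, startPos, callback, true);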
diff --git a/extension/android/jni/jni_layer_llama.cpp b/extension/android/jni/jni_layer_llama.cpp
index 257f7282c65..0b80244bfd5 100644
--- a/extension/android/jni/jni_layer_llama.cpp
+++ b/extension/android/jni/jni_layer_llama.cpp
@@ -291,32 +291,46 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
return tuple_result;
}
- jint generate_from_pos(
+ facebook::jni::local_ref<jlongArray> generate_from_pos(
facebook::jni::alias_ref<jstring> prompt,
jint seq_len,
jlong start_pos,
facebook::jni::alias_ref<ExecuTorchLlmCallbackJni> callback,
jboolean echo) {
+ facebook::jni::local_ref<jlongArray> tuple_result =
+ facebook::jni::make_long_array(2);
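+ // tuple_result[0] carries the status code, tuple_result[1] the updated start_pos.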
+
if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
- return static_cast<jint>(multi_modal_runner_->generate_from_pos(
- prompt->toStdString(),
- seq_len,
- start_pos,
- [callback](const std::string& result) { callback->onResult(result); },
- [callback](const llm::Stats& stats) { callback->onStats(stats); },
- echo));
+ tuple_result->pin()[0] =
+ static_cast<jint>(multi_modal_runner_->generate_from_pos(
+ prompt->toStdString(),
+ seq_len,
+ start_pos,
+ [callback](const std::string& result) {
+ callback->onResult(result);
+ },
+ [callback](const llm::Stats& stats) { callback->onStats(stats); },
+ echo));
+ tuple_result->pin()[1] = static_cast<jlong>(start_pos);
+ return tuple_result;
} else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
executorch::extension::llm::GenerationConfig config{
.echo = static_cast<bool>(echo),
.seq_len = seq_len,
.temperature = temperature_,
};
- return static_cast<jint>(runner_->generate_from_pos(
+ int start_pos_after_generate = start_pos;
+ tuple_result->pin()[0] = static_cast<jint>(runner_->generate_from_pos(
prompt->toStdString(),
start_pos,
config,
[callback](std::string result) { callback->onResult(result); },
- [callback](const llm::Stats& stats) { callback->onStats(stats); }));
+ [callback](const llm::Stats& stats) { callback->onStats(stats); },
+ [callback, &start_pos_after_generate](int updated_start_pos) {
+ start_pos_after_generate = updated_start_pos;
+ }));
+ tuple_result->pin()[1] = static_cast<jlong>(start_pos_after_generate);
+ return tuple_result;
}
}
@@ -348,7 +362,7 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
makeNativeMethod(
"prefillPromptNative", ExecuTorchLlmJni::prefill_prompt),
makeNativeMethod(
- "generateFromPos", ExecuTorchLlmJni::generate_from_pos),
+ "generateFromPosNative", ExecuTorchLlmJni::generate_from_pos),
});
}
};
diff --git a/extension/llm/runner/irunner.h b/extension/llm/runner/irunner.h
index 4c2efc91203..50946fb0153 100644
--- a/extension/llm/runner/irunner.h
+++ b/extension/llm/runner/irunner.h
@@ -134,6 +134,7 @@ class ET_EXPERIMENTAL IRunner {
* @param config Generation configuration parameters
* @param token_callback Callback function called for each generated token
* @param stats_callback Callback function for generation statistics
+ * @param updated_start_pos Callback function for the updated start_pos
* @return Error::Ok if successful, an error otherwise
*/
virtual runtime::Error generate_from_pos(
@@ -141,7 +142,8 @@ class ET_EXPERIMENTAL IRunner {
int64_t start_pos,
const GenerationConfig& config,
std::function<void(const std::string&)> token_callback,
- std::function<void(const Stats&)> stats_callback) = 0;
+ std::function<void(const Stats&)> stats_callback,
+ std::function<void(int64_t)> updated_start_pos) = 0;
/**
* Stop the generation process.
*/
diff --git a/extension/llm/runner/text_llm_runner.cpp b/extension/llm/runner/text_llm_runner.cpp
index 4f89121111d..83c9e5b83bd 100644
--- a/extension/llm/runner/text_llm_runner.cpp
+++ b/extension/llm/runner/text_llm_runner.cpp
@@ -91,7 +91,8 @@ Error TextLLMRunner::generate_from_pos(
int64_t start_pos,
const GenerationConfig& config,
std::function<void(const std::string&)> token_callback,
- std::function<void(const Stats&)> stats_callback) {
+ std::function<void(const Stats&)> stats_callback,
+ std::function<void(int64_t)> updated_start_pos) {
// Prepare the inputs.
// Use ones-initialized inputs.
ET_CHECK_MSG(!prompt.empty(), "Prompt cannot be null");
@@ -230,6 +231,9 @@ Error TextLLMRunner::generate_from_pos(
if (stats_callback) {
stats_callback(*stats_);
}
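+ // Report the final position so callers can continue generation from it.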
+ if (updated_start_pos) {
+ updated_start_pos(pos);
+ }
return Error::Ok;
}
@@ -238,7 +242,8 @@ Error TextLLMRunner::generate(
const GenerationConfig& config,
std::function<void(const std::string&)> token_callback,
std::function<void(const Stats&)> stats_callback) {
- return generate_from_pos(prompt, 0, config, token_callback, stats_callback);
+ return generate_from_pos(
+ prompt, 0, config, token_callback, stats_callback, {});
}
Error TextLLMRunner::warmup(const std::string& prompt, int32_t max_new_tokens) {
diff --git a/extension/llm/runner/text_llm_runner.h b/extension/llm/runner/text_llm_runner.h
index c35f143d2e0..69d3fc33fd7 100644
--- a/extension/llm/runner/text_llm_runner.h
+++ b/extension/llm/runner/text_llm_runner.h
@@ -119,7 +119,8 @@ class ET_EXPERIMENTAL TextLLMRunner : public IRunner {
int64_t start_pos,
const GenerationConfig& config,
std::function<void(const std::string&)> token_callback = {},
- std::function<void(const Stats&)> stats_callback = {}) override;
+ std::function<void(const Stats&)> stats_callback = {},
+ std::function<void(int64_t)> updated_start_pos = {}) override;
/**
* @brief Warms up the model with a sample prompt