@@ -213,14 +213,55 @@ public long prefillPrompt(String prompt, long startPos, int bos, int eos) {
/**
* Generate tokens from the given prompt, starting from the given position.
*
* <p>This function is deprecated. Please use {@link #generateFromPosWithPosUpdate(String, int, long, LlmCallback, boolean)} instead.
*
* @param prompt The text prompt to LLaVA.
* @param seqLen The total sequence length, including the prompt tokens and new tokens.
* @param startPos The starting position in KV cache of the input in the LLM.
* @param callback callback object to receive results.
* @param echo Whether to echo the input prompt or not.
* @return The error code.
*/
public native int generateFromPos(
@Deprecated
public int generateFromPos(
String prompt, int seqLen, long startPos, LlmCallback callback, boolean echo) {
long[] nativeResult = generateFromPosNative(prompt, seqLen, startPos, callback, echo);
return (int) nativeResult[0];
}

/**
* Generate tokens from the given prompt, starting from the given position.
*
* <p>This will return the updated startPos.
*
* @param prompt The text prompt to LLaVA.
* @param seqLen The total sequence length, including the prompt tokens and new tokens.
* @param startPos The starting position in KV cache of the input in the LLM.
* @param callback callback object to receive results.
* @param echo Whether to echo the input prompt or not.
* @return updated startPos
*/
public long generateFromPosWithPosUpdate(
String prompt, int seqLen, long startPos, LlmCallback callback, boolean echo) {
long[] nativeResult = generateFromPosNative(prompt, seqLen, startPos, callback, echo);
if (nativeResult[0] != 0) {
throw new RuntimeException("Generate failed with error code: " + nativeResult[0]);
}
return nativeResult[1];
}

/**
* Generate tokens from the given prompt, starting from the given position.
*
* @param prompt The text prompt to LLaVA.
* @param seqLen The total sequence length, including the prompt tokens and new tokens.
* @param startPos The starting position in KV cache of the input in the LLM.
* @param callback callback object to receive results.
* @param echo Whether to echo the input prompt or not.
* @return a tuple of (status, updated startPos)
*/
private native long[] generateFromPosNative(
String prompt, int seqLen, long startPos, LlmCallback callback, boolean echo);

/** Stop current generate() before it finishes. */
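For reviewers, a minimal caller-side sketch of the new API. It assumes the surrounding class is the existing `LlmModule` wrapper and that `module` and `callback` are an already-loaded module and an `LlmCallback` implementation created elsewhere; none of those names are introduced by this diff.

```java
// Sketch only: `module` and `callback` are assumed to exist and to match the
// signatures shown in this diff; the seqLen values are illustrative.
long startPos = 0;

// First turn: generate and capture the updated KV-cache position.
startPos = module.generateFromPosWithPosUpdate(
    "Hello!", /* seqLen= */ 128, startPos, callback, /* echo= */ false);

// Second turn: continue from where the previous turn left off, reusing the KV cache.
startPos = module.generateFromPosWithPosUpdate(
    "Tell me more.", /* seqLen= */ 256, startPos, callback, /* echo= */ false);
```

Returning the updated position is what lets callers chain turns like this without re-prefilling earlier context.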
36 changes: 25 additions & 11 deletions extension/android/jni/jni_layer_llama.cpp
@@ -291,32 +291,46 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
return tuple_result;
}

jint generate_from_pos(
facebook::jni::local_ref<jlongArray> generate_from_pos(
facebook::jni::alias_ref<jstring> prompt,
jint seq_len,
jlong start_pos,
facebook::jni::alias_ref<ExecuTorchLlmCallbackJni> callback,
jboolean echo) {
facebook::jni::local_ref<jlongArray> tuple_result =
facebook::jni::make_long_array(2);

if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
return static_cast<jint>(multi_modal_runner_->generate_from_pos(
prompt->toStdString(),
seq_len,
start_pos,
[callback](const std::string& result) { callback->onResult(result); },
[callback](const llm::Stats& stats) { callback->onStats(stats); },
echo));
tuple_result->pin()[0] =
static_cast<jlong>(multi_modal_runner_->generate_from_pos(
prompt->toStdString(),
seq_len,
start_pos,
[callback](const std::string& result) {
callback->onResult(result);
},
[callback](const llm::Stats& stats) { callback->onStats(stats); },
echo));
tuple_result->pin()[1] = static_cast<jlong>(start_pos);
return tuple_result;
} else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
executorch::extension::llm::GenerationConfig config{
.echo = static_cast<bool>(echo),
.seq_len = seq_len,
.temperature = temperature_,
};
return static_cast<jint>(runner_->generate_from_pos(
jlong start_pos_after_generate = start_pos;
tuple_result->pin()[0] = static_cast<jlong>(runner_->generate_from_pos(
prompt->toStdString(),
start_pos,
config,
[callback](std::string result) { callback->onResult(result); },
[callback](const llm::Stats& stats) { callback->onStats(stats); }));
[callback](const llm::Stats& stats) { callback->onStats(stats); },
[callback, &start_pos_after_generate](int updated_start_pos) {
start_pos_after_generate = updated_start_pos;
}));
tuple_result->pin()[1] = static_cast<jlong>(start_pos_after_generate);
return tuple_result;
}
}

@@ -348,7 +362,7 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
makeNativeMethod(
"prefillPromptNative", ExecuTorchLlmJni::prefill_prompt),
makeNativeMethod(
"generateFromPos", ExecuTorchLlmJni::generate_from_pos),
"generateFromPosNative", ExecuTorchLlmJni::generate_from_pos),
});
}
};
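Since the JNI entry point now returns a `{status, updatedStartPos}` pair and the Java wrapper turns a nonzero status into a `RuntimeException`, callers migrating off the deprecated `generateFromPos` also change how they handle failures. A hedged caller-side sketch, again assuming `module`, `callback`, and `startPos` exist as in the earlier example:

```java
String prompt = "Hello again."; // illustrative prompt

// Before: the deprecated wrapper reports failure via an int error code.
int err = module.generateFromPos(prompt, 128, startPos, callback, false);
if (err != 0) {
  // handle the error code
}

// After: the new wrapper returns the updated startPos and throws on failure;
// the native error code is embedded in the exception message.
try {
  startPos = module.generateFromPosWithPosUpdate(prompt, 128, startPos, callback, false);
} catch (RuntimeException e) {
  // handle the generation failure
}
```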
4 changes: 3 additions & 1 deletion extension/llm/runner/irunner.h
@@ -134,14 +134,16 @@ class ET_EXPERIMENTAL IRunner {
* @param config Generation configuration parameters
* @param token_callback Callback function called for each generated token
* @param stats_callback Callback function for generation statistics
* @param updated_start_pos Callback function invoked with the updated start_pos after generation
* @return Error::Ok if successful, an error otherwise
*/
virtual runtime::Error generate_from_pos(
const std::string& prompt,
int64_t start_pos,
const GenerationConfig& config,
std::function<void(const std::string&)> token_callback,
std::function<void(const Stats&)> stats_callback) = 0;
std::function<void(const Stats&)> stats_callback,
std::function<void(int)> updated_start_pos) = 0;
/**
* Stop the generation process.
*/
9 changes: 7 additions & 2 deletions extension/llm/runner/text_llm_runner.cpp
@@ -91,7 +91,8 @@ Error TextLLMRunner::generate_from_pos(
int64_t start_pos,
const GenerationConfig& config,
std::function<void(const std::string&)> token_callback,
std::function<void(const Stats&)> stats_callback) {
std::function<void(const Stats&)> stats_callback,
std::function<void(int)> updated_start_pos) {
// Prepare the inputs.
// Use ones-initialized inputs.
ET_CHECK_MSG(!prompt.empty(), "Prompt cannot be null");
@@ -230,6 +231,9 @@ Error TextLLMRunner::generate_from_pos(
if (stats_callback) {
stats_callback(*stats_);
}
if (updated_start_pos) {
updated_start_pos(pos);
}

return Error::Ok;
}
@@ -238,7 +242,8 @@ Error TextLLMRunner::generate(
const GenerationConfig& config,
std::function<void(const std::string&)> token_callback,
std::function<void(const Stats&)> stats_callback) {
return generate_from_pos(prompt, 0, config, token_callback, stats_callback);
return generate_from_pos(
prompt, 0, config, token_callback, stats_callback, {});
}

Error TextLLMRunner::warmup(const std::string& prompt, int32_t max_new_tokens) {
3 changes: 2 additions & 1 deletion extension/llm/runner/text_llm_runner.h
@@ -119,7 +119,8 @@ class ET_EXPERIMENTAL TextLLMRunner : public IRunner {
int64_t start_pos,
const GenerationConfig& config,
std::function<void(const std::string&)> token_callback = {},
std::function<void(const Stats&)> stats_callback = {}) override;
std::function<void(const Stats&)> stats_callback = {},
std::function<void(int)> updated_start_pos = {}) override;

/**
* @brief Warms up the model with a sample prompt