
Commit 2980824

Test new api
1 parent 9e00a51 commit 2980824

File tree

5 files changed: +78 -15 lines changed


extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmModule.java

Lines changed: 42 additions & 1 deletion
@@ -213,14 +213,55 @@ public long prefillPrompt(String prompt, long startPos, int bos, int eos) {
   /**
    * Generate tokens from the given prompt, starting from the given position.
    *
+   * <p>This function is deprecated. Please use {@link #generateFromPosWithPosUpdate(String, int,
+   * long, LlmCallback, boolean)} instead.
+   *
    * @param prompt The text prompt to LLaVA.
    * @param seqLen The total sequence length, including the prompt tokens and new tokens.
    * @param startPos The starting position in KV cache of the input in the LLM.
    * @param callback callback object to receive results.
    * @param echo indicate whether to echo the input prompt or not.
    * @return The error code.
    */
-  public native int generateFromPos(
+  @Deprecated
+  public int generateFromPos(
+      String prompt, int seqLen, long startPos, LlmCallback callback, boolean echo) {
+    long[] nativeResult = generateFromPosNative(prompt, seqLen, startPos, callback, echo);
+    return (int) nativeResult[0];
+  }
+
+  /**
+   * Generate tokens from the given prompt, starting from the given position.
+   *
+   * <p>This will return the updated startPos.
+   *
+   * @param prompt The text prompt to LLaVA.
+   * @param seqLen The total sequence length, including the prompt tokens and new tokens.
+   * @param startPos The starting position in KV cache of the input in the LLM.
+   * @param callback callback object to receive results.
+   * @param echo indicate whether to echo the input prompt or not.
+   * @return the updated startPos
+   */
+  public long generateFromPosWithPosUpdate(
+      String prompt, int seqLen, long startPos, LlmCallback callback, boolean echo) {
+    long[] nativeResult = generateFromPosNative(prompt, seqLen, startPos, callback, echo);
+    if (nativeResult[0] != 0) {
+      throw new RuntimeException("Generate failed with error code: " + nativeResult[0]);
+    }
+    return nativeResult[1];
+  }
+
+  /**
+   * Generate tokens from the given prompt, starting from the given position.
+   *
+   * @param prompt The text prompt to LLaVA.
+   * @param seqLen The total sequence length, including the prompt tokens and new tokens.
+   * @param startPos The starting position in KV cache of the input in the LLM.
+   * @param callback callback object to receive results.
+   * @param echo indicate whether to echo the input prompt or not.
+   * @return a tuple of (status, updated startPos)
+   */
+  private native long[] generateFromPosNative(
       String prompt, int seqLen, long startPos, LlmCallback callback, boolean echo);

   /** Stop current generate() before it finishes. */
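
A minimal caller-side sketch of the new API may help. It assumes the long-standing LlmModule(modelPath, tokenizerPath, temperature) constructor and load() method; the artifact paths, the sequence length, and the exact onStats signature are illustrative assumptions, not part of this commit.

import org.pytorch.executorch.extension.llm.LlmCallback;
import org.pytorch.executorch.extension.llm.LlmModule;

public class GenerateFromPosDemo {
  public static void main(String[] args) {
    // Hypothetical artifact paths; substitute your own model and tokenizer.
    LlmModule module =
        new LlmModule("/data/local/tmp/llama.pte", "/data/local/tmp/tokenizer.bin", 0.8f);
    module.load();

    LlmCallback printer =
        new LlmCallback() {
          @Override
          public void onResult(String result) {
            System.out.print(result); // stream each decoded piece of text
          }

          @Override
          public void onStats(String stats) {
            // generation statistics, delivered once at the end (signature assumed)
          }
        };

    // The deprecated generateFromPos returned an error code; the new method
    // returns the advanced KV-cache position and throws a RuntimeException
    // on a non-zero native status.
    long startPos = module.generateFromPosWithPosUpdate("Hello!", 128, 0, printer, false);
    System.out.println("\nKV cache now at position " + startPos);
  }
}

Callers migrating from generateFromPos keep the same arguments; only the return value and the error handling change.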

extension/android/jni/jni_layer_llama.cpp

Lines changed: 24 additions & 10 deletions
@@ -291,32 +291,46 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
     return tuple_result;
   }

-  jint generate_from_pos(
+  facebook::jni::local_ref<jlongArray> generate_from_pos(
       facebook::jni::alias_ref<jstring> prompt,
       jint seq_len,
       jlong start_pos,
       facebook::jni::alias_ref<ExecuTorchLlmCallbackJni> callback,
       jboolean echo) {
+    facebook::jni::local_ref<jlongArray> tuple_result =
+        facebook::jni::make_long_array(2);
+
     if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
-      return static_cast<jint>(multi_modal_runner_->generate_from_pos(
-          prompt->toStdString(),
-          seq_len,
-          start_pos,
-          [callback](const std::string& result) { callback->onResult(result); },
-          [callback](const llm::Stats& stats) { callback->onStats(stats); },
-          echo));
+      tuple_result->pin()[0] =
+          static_cast<jlong>(multi_modal_runner_->generate_from_pos(
+              prompt->toStdString(),
+              seq_len,
+              start_pos,
+              [callback](const std::string& result) {
+                callback->onResult(result);
+              },
+              [callback](const llm::Stats& stats) { callback->onStats(stats); },
+              echo));
+      tuple_result->pin()[1] = static_cast<jlong>(start_pos);
+      return tuple_result;
     } else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
       executorch::extension::llm::GenerationConfig config{
           .echo = static_cast<bool>(echo),
           .seq_len = seq_len,
           .temperature = temperature_,
       };
-      return static_cast<jint>(runner_->generate_from_pos(
+      int start_pos_after_generate = start_pos;
+      tuple_result->pin()[0] = static_cast<jlong>(runner_->generate_from_pos(
          prompt->toStdString(),
          start_pos,
          config,
          [callback](std::string result) { callback->onResult(result); },
-         [callback](const llm::Stats& stats) { callback->onStats(stats); }));
+         [callback](const llm::Stats& stats) { callback->onStats(stats); },
+         [callback, &start_pos_after_generate](int updated_start_pos) {
+           start_pos_after_generate = updated_start_pos;
+         }));
+      tuple_result->pin()[1] = static_cast<jlong>(start_pos_after_generate);
+      return tuple_result;
     }
   }
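
The two-element array packed above is the whole cross-JNI contract: index 0 carries the runner's status, index 1 the position after generation (note the multimodal branch currently echoes the incoming start_pos back unchanged). A sketch of the multi-turn pattern this enables on the Java side, reusing the module and printer from the earlier sketch; the prompts and seqLen are illustrative:

// Thread the KV-cache position through consecutive turns so each call
// resumes exactly where the previous one stopped.
String[] turns = {"Tell me a joke.", "Now explain why it is funny."};
long startPos = 0;
for (String turn : turns) {
  // The wrapper unpacks the native long[2]: [0] = status (throws if non-zero),
  // [1] = updated startPos, which seeds the next turn.
  startPos = module.generateFromPosWithPosUpdate(turn, 256, startPos, printer, false);
}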

extension/llm/runner/irunner.h

Lines changed: 3 additions & 1 deletion
@@ -134,14 +134,16 @@ class ET_EXPERIMENTAL IRunner {
    * @param config Generation configuration parameters
    * @param token_callback Callback function called for each generated token
    * @param stats_callback Callback function for generation statistics
+   * @param updated_start_pos Callback function for the updated start_pos
    * @return Error::Ok if successful, an error otherwise
    */
   virtual runtime::Error generate_from_pos(
       const std::string& prompt,
       int64_t start_pos,
       const GenerationConfig& config,
       std::function<void(const std::string&)> token_callback,
-      std::function<void(const Stats&)> stats_callback) = 0;
+      std::function<void(const Stats&)> stats_callback,
+      std::function<void(int)> updated_start_pos) = 0;
   /**
    * Stop the generation process.
    */

extension/llm/runner/text_llm_runner.cpp

Lines changed: 7 additions & 2 deletions
@@ -91,7 +91,8 @@ Error TextLLMRunner::generate_from_pos(
     int64_t start_pos,
     const GenerationConfig& config,
     std::function<void(const std::string&)> token_callback,
-    std::function<void(const Stats&)> stats_callback) {
+    std::function<void(const Stats&)> stats_callback,
+    std::function<void(int)> updated_start_pos) {
   // Prepare the inputs.
   // Use ones-initialized inputs.
   ET_CHECK_MSG(!prompt.empty(), "Prompt cannot be null");
@@ -230,6 +231,9 @@ Error TextLLMRunner::generate_from_pos(
   if (stats_callback) {
     stats_callback(*stats_);
   }
+  if (updated_start_pos) {
+    updated_start_pos(pos);
+  }

   return Error::Ok;
 }
@@ -238,7 +242,8 @@ Error TextLLMRunner::generate(
     const GenerationConfig& config,
     std::function<void(const std::string&)> token_callback,
     std::function<void(const Stats&)> stats_callback) {
-  return generate_from_pos(prompt, 0, config, token_callback, stats_callback);
+  return generate_from_pos(
+      prompt, 0, config, token_callback, stats_callback, {});
 }

 Error TextLLMRunner::warmup(const std::string& prompt, int32_t max_new_tokens) {

extension/llm/runner/text_llm_runner.h

Lines changed: 2 additions & 1 deletion
@@ -119,7 +119,8 @@ class ET_EXPERIMENTAL TextLLMRunner : public IRunner {
       int64_t start_pos,
       const GenerationConfig& config,
       std::function<void(const std::string&)> token_callback = {},
-      std::function<void(const Stats&)> stats_callback = {}) override;
+      std::function<void(const Stats&)> stats_callback = {},
+      std::function<void(int)> updated_start_pos = {}) override;

   /**
    * @brief Warms up the model with a sample prompt

0 commit comments