Commit 80c6378

Merge branch 'android-use-prefill-api' into start-pos-api-llava-7-9
2 parents 392c157 + beb1784

File tree

2 files changed: +33 -0 lines changed

extension/llm/runner/text_llm_runner.cpp

Lines changed: 22 additions & 0 deletions
@@ -217,6 +217,28 @@ Error TextLLMRunner::generate(
   return Error::Ok;
 }
 
+Error TextLLMRunner::prefill(
+    const std::string& prompt,
+    const GenerationConfig& config) {
+  if (!is_loaded()) {
+    ET_CHECK_OK_OR_RETURN_ERROR(load());
+  }
+
+  ::tokenizers::Result<std::vector<uint64_t>> encode_res = tokenizer_->encode(
+      prompt,
+      /*bos=*/config.num_bos,
+      /*eos=*/config.num_eos);
+
+  ET_CHECK_TK_OK_OR_RETURN_ERROR(
+      encode_res.error(), "Failed to encode prompt %s", prompt.c_str());
+
+  // encode the (string) prompt into tokens sequence
+  std::vector<uint64_t> prompt_tokens = encode_res.get();
+  auto prefill_res = text_prefiller_->prefill(prompt_tokens, pos_);
+  ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error());
+  return Error::Ok;
+}
+
 Error TextLLMRunner::warmup(const std::string& prompt, int32_t max_new_tokens) {
   // Create a GenerationConfig for warmup
   GenerationConfig config{
extension/llm/runner/text_llm_runner.h

Lines changed: 11 additions & 0 deletions
@@ -101,6 +101,17 @@ class ET_EXPERIMENTAL TextLLMRunner : public IRunner {
       std::function<void(const std::string&)> token_callback = {},
       std::function<void(const Stats&)> stats_callback = {}) override;
 
+  /**
+   * Prefill text inputs, for example to reload chat history.
+   * @param prompt Text prompt to prefill.
+   * @param config Configuration parameters for text generation (e.g.,
+   * max_new_tokens, temperature)
+   * @return The error code. KV cache position is tracked internally in pos_.
+   */
+  ::executorch::runtime::Error prefill(
+      const std::string& prompt,
+      const GenerationConfig& config);
+
   /**
    * @brief Warms up the model with a sample prompt
    *
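
Since the doc comment says the cache position persists internally in pos_ between calls, prefill() can plausibly be invoked once per chat message instead of once for the whole transcript, with each call appending its tokens after the previous ones. A sketch under that assumption; replay_history and messages are hypothetical:

#include <executorch/extension/llm/runner/text_llm_runner.h>

#include <string>
#include <vector>

using executorch::extension::llm::GenerationConfig;
using executorch::extension::llm::TextLLMRunner;
using executorch::runtime::Error;

// Hedged sketch: replay a multi-turn transcript one message at a time.
// Assumes each prefill() call appends that message's tokens to the KV
// cache via the internally tracked pos_.
Error replay_history(
    TextLLMRunner& runner,
    const std::vector<std::string>& messages,
    const GenerationConfig& config) {
  for (const std::string& message : messages) {
    ET_CHECK_OK_OR_RETURN_ERROR(runner.prefill(message, config));
  }
  return Error::Ok;
}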
