Skip to content

Commit 9238a2b

Browse files
Add prefill API to MultimodalRunner (#14489)
Co-authored-by: Rohan Joshi <[email protected]>
1 parent c1c5c84 commit 9238a2b

File tree

2 files changed

+19
-0
lines changed

2 files changed

+19
-0
lines changed

extension/llm/runner/multimodal_runner.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,16 @@ Error MultimodalRunner::load() {
6262
ET_LOG(Info, format, __VA_ARGS__); \
6363
}
6464

65+
/**
 * Prefills every multimodal input in order, loading the runner first if it
 * has not been loaded yet.
 *
 * @param inputs Inputs to hand to the prefiller one at a time. Each call also
 *     receives pos_ (presumably advanced by the prefiller to track the KV
 *     cache position -- confirm against MultimodalPrefiller).
 * @return Error::Ok once every input has been prefilled; otherwise the error
 *     reported by load() or by the first prefill call that fails. Inputs
 *     after a failing one are not processed.
 */
Error MultimodalRunner::prefill(std::vector<MultimodalInput>& inputs) {
  // Lazily load so callers can prefill (e.g. chat history) without an
  // explicit load() beforehand.
  if (!is_loaded()) {
    ET_CHECK_OK_OR_RETURN_ERROR(load());
  }
  for (auto& input : inputs) {
    // Only the status matters here; the prefilled value is intentionally
    // unused. Inspect the result explicitly rather than using ET_UNWRAP as a
    // bare statement, which discards the unwrapped value and (as defined in
    // ExecuTorch's result.h) depends on a compiler-specific statement
    // expression.
    auto prefill_result = multimodal_prefiller_->prefill(input, pos_);
    if (!prefill_result.ok()) {
      return prefill_result.error();
    }
  }
  return Error::Ok;
}
74+
6575
Error MultimodalRunner::generate(
6676
const std::vector<MultimodalInput>& inputs,
6777
const GenerationConfig& config,

extension/llm/runner/multimodal_runner.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,15 @@ class ET_EXPERIMENTAL MultimodalRunner {
119119
std::function<void(const std::string&)> token_callback = {},
120120
std::function<void(const Stats&)> stats_callback = {});
121121

122+
/**
123+
* Prefill multimodal inputs, for example to reload chat history. The runner
* is loaded first if it is not loaded already, and the inputs are then
* prefilled one by one in the order given.
124+
* @param inputs A vector of MultimodalInput objects containing images and
125+
* text.
126+
* @return Error::Ok once all inputs have been prefilled; otherwise the error
* from loading or from prefilling an input. KV cache position is tracked
* internally in pos_.
127+
*/
128+
virtual ::executorch::runtime::Error prefill(
129+
std::vector<MultimodalInput>& inputs);
130+
122131
/**
 * Forwards a stop request to the text token generator.
 * NOTE(review): text_token_generator_ is dereferenced without a null check;
 * presumably it is only valid after a successful load() -- confirm.
 */
inline void stop() {
123132
text_token_generator_->stop();
124133
}

0 commit comments

Comments
 (0)