Skip to content

Commit db8d04f

Browse files
authored
[multimodal] Allow generate and prefill to take move sematics (#14643)
As titled
1 parent ebf4c12 commit db8d04f

File tree

2 files changed

+39
-0
lines changed

2 files changed

+39
-0
lines changed

extension/llm/runner/multimodal_runner.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,11 @@ Error MultimodalRunner::load() {
6262
ET_LOG(Info, format, __VA_ARGS__); \
6363
}
6464

65+
Error MultimodalRunner::prefill(std::vector<MultimodalInput>&& inputs) {
66+
// Forward to the const reference version
67+
return prefill(inputs);
68+
}
69+
6570
Error MultimodalRunner::prefill(const std::vector<MultimodalInput>& inputs) {
6671
if (!is_loaded()) {
6772
ET_CHECK_OK_OR_RETURN_ERROR(load());
@@ -72,6 +77,16 @@ Error MultimodalRunner::prefill(const std::vector<MultimodalInput>& inputs) {
7277
return Error::Ok;
7378
}
7479

80+
Error MultimodalRunner::generate(
81+
std::vector<MultimodalInput>&& inputs,
82+
const GenerationConfig& config,
83+
std::function<void(const std::string&)> token_callback,
84+
std::function<void(const Stats&)> stats_callback) {
85+
// Forward to the const reference version
86+
return generate(
87+
inputs, config, std::move(token_callback), std::move(stats_callback));
88+
}
89+
7590
Error MultimodalRunner::generate(
7691
const std::vector<MultimodalInput>& inputs,
7792
const GenerationConfig& config,

extension/llm/runner/multimodal_runner.h

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,21 @@ class ET_EXPERIMENTAL MultimodalRunner {
119119
std::function<void(const std::string&)> token_callback = {},
120120
std::function<void(const Stats&)> stats_callback = {});
121121

122+
/**
123+
* Generate tokens from multimodal inputs with move semantics.
124+
* This overload allows efficient transfer of temporary vectors.
125+
* @param inputs A vector of MultimodalInput objects (moved).
126+
* @param config Generation configuration parameters.
127+
* @param token_callback Callback function called for each generated token.
128+
* @param stats_callback Callback function for generation statistics.
129+
* @return The error code. KV cache position is tracked internally in pos_.
130+
*/
131+
virtual ::executorch::runtime::Error generate(
132+
std::vector<MultimodalInput>&& inputs,
133+
const GenerationConfig& config,
134+
std::function<void(const std::string&)> token_callback = {},
135+
std::function<void(const Stats&)> stats_callback = {});
136+
122137
/**
123138
* Prefill multimodal inputs, for example to reload chat history.
124139
* @param inputs A vector of MultimodalInput objects containing images and
@@ -128,6 +143,15 @@ class ET_EXPERIMENTAL MultimodalRunner {
128143
virtual ::executorch::runtime::Error prefill(
129144
const std::vector<MultimodalInput>& inputs);
130145

146+
/**
147+
* Prefill multimodal inputs with move semantics.
148+
* This overload allows efficient transfer of temporary vectors.
149+
* @param inputs A vector of MultimodalInput objects (moved).
150+
* @return The error code. KV cache position is tracked internally in pos_.
151+
*/
152+
virtual ::executorch::runtime::Error prefill(
153+
std::vector<MultimodalInput>&& inputs);
154+
131155
inline void stop() {
132156
text_token_generator_->stop();
133157
}

0 commit comments

Comments
 (0)