
Commit 1cd3582

Android use new prefill API
1 parent 80c6378 commit 1cd3582

5 files changed: 29 additions & 16 deletions


extension/android/jni/jni_layer_llama.cpp

Lines changed: 13 additions & 11 deletions
@@ -123,7 +123,6 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
   std::unique_ptr<llm::IRunner> runner_;
   std::unique_ptr<executorch::extension::llm::MultimodalRunner>
       multi_modal_runner_;
-  std::vector<llm::MultimodalInput> prefill_inputs_;

  public:
   constexpr static auto kJavaDescriptor =
@@ -213,8 +212,7 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
       facebook::jni::alias_ref<ExecuTorchLlmCallbackJni> callback,
       jboolean echo) {
     if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
-      std::vector<llm::MultimodalInput> inputs = prefill_inputs_;
-      prefill_inputs_.clear();
+      std::vector<llm::MultimodalInput> inputs;
       if (!prompt->toStdString().empty()) {
         inputs.emplace_back(llm::MultimodalInput{prompt->toStdString()});
       }
@@ -247,9 +245,11 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
   // Contract is valid within an AAR (JNI + corresponding Java code)
   jint prefill_text_input(facebook::jni::alias_ref<jstring> prompt) {
     if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
-      prefill_inputs_.emplace_back(llm::MultimodalInput{prompt->toStdString()});
+      runner_->prefill(prompt->toStdString(), {});
+      return 0;
     } else if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
-      multi_modal_runner_->prefill(llm::MultimodalInput{prompt->toStdString()});
+      multi_modal_runner_->prefill(
+          {llm::MultimodalInput{prompt->toStdString()}});
       return 0;
     }
   }
@@ -260,10 +260,10 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
       jint height,
       jint channels) {
     if (model_type_category_ != MODEL_TYPE_CATEGORY_MULTIMODAL) {
-      return Error::InvalidArgument;
+      return static_cast<jint>(Error::InvalidArgument);
     }
     if (image == nullptr) {
-      return Error::InvalidArgument;
+      return static_cast<jint>(Error::InvalidArgument);
     }
     std::vector<llm::Image> images;
     if (image == nullptr) {
@@ -278,7 +278,8 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
         image_data[i] = image_data_jint[i];
       }
       llm::Image image_runner{std::move(image_data), width, height, channels};
-      multi_modal_runner_->prefill(llm::MultimodalInput{std::move(image_runner)});
+      multi_modal_runner_->prefill(
+          {llm::MultimodalInput{std::move(image_runner)}});
     }

     return 0;
@@ -290,10 +291,10 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
       jint n_channels,
       jint n_samples) {
     if (model_type_category_ != MODEL_TYPE_CATEGORY_MULTIMODAL) {
-      return Error::InvalidArgument;
+      return static_cast<jint>(Error::InvalidArgument);
     }
     if (audio == nullptr) {
-      return Error::InvalidArgument;
+      return static_cast<jint>(Error::InvalidArgument);
     }
     auto audio_size = audio->size();
     std::vector<uint8_t> audio_data(audio_size);
@@ -304,7 +305,8 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
         audio_data[i] = audio_data_jint[i];
       }
       llm::RawAudio audio_input{audio_data, batch_size, n_channels, n_samples};
-      multi_modal_runner_->prefill(llm::MultimodalInput{std::move(audio_input)});
+      multi_modal_runner_->prefill(
+          {llm::MultimodalInput{std::move(audio_input)}});
     }
     return 0;
   }
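
The net effect of the JNI change above is that prefill_text_input, prefill_images, and prefill_audio now push each input into the runner as soon as it arrives, instead of queuing it in the removed prefill_inputs_ member for the next generate() call. A minimal native-side sketch of that flow, assuming the usual executorch include path and the llm namespace alias used in the JNI file; the function name and the elided error handling are illustrative, not part of this commit:

// Sketch only: forward each piece of context to the runner as it arrives.
// The KV-cache position is tracked inside the runner (pos_), so nothing has
// to be buffered or re-sent before generate().
#include <executorch/extension/llm/runner/multimodal_runner.h>  // assumed path

#include <string>
#include <utility>

namespace llm = executorch::extension::llm;

void prefill_image_then_text(
    llm::MultimodalRunner& runner,  // illustrative: loaded elsewhere
    llm::Image image,
    const std::string& prompt) {
  // Temporaries are fine now that prefill() takes a const reference.
  (void)runner.prefill({llm::MultimodalInput{std::move(image)}});
  (void)runner.prefill({llm::MultimodalInput{std::string(prompt)}});
  // ...a later generate() call continues from the prefilled KV cache.
}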

extension/llm/runner/irunner.h

Lines changed: 11 additions & 0 deletions
@@ -125,6 +125,17 @@ class ET_EXPERIMENTAL IRunner {
       std::function<void(const std::string&)> token_callback,
       std::function<void(const Stats&)> stats_callback) = 0;

+  /**
+   * Prefill text inputs, for example to reload chat history.
+   * @param prompt Text prompt to prefill.
+   * @param config Configuration parameters (if non-zero num_bos and num_eos
+   * used)
+   * @return The error code. KV cache position is tracked internally in pos_.
+   */
+  virtual ::executorch::runtime::Error prefill(
+      const std::string& prompt,
+      const GenerationConfig& config = {}) = 0;
+
   /**
    * Stop the generation process.
    */
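
Because the added declaration is pure virtual, prefill(prompt, config) is now part of the IRunner contract and every concrete runner has to provide it (TextLLMRunner's override appears further down). A hedged sketch of the shape such an override takes; the class name and placeholder body are illustrative, not the real implementation:

// Illustrative only: shows the signature a concrete IRunner must now provide.
// Other pure virtuals of IRunner (load, generate, stop, ...) are omitted, so
// this class remains abstract as written.
#include <executorch/extension/llm/runner/irunner.h>  // assumed path

#include <string>

class MyTextRunner : public executorch::extension::llm::IRunner {
 public:
  ::executorch::runtime::Error prefill(
      const std::string& prompt,
      const executorch::extension::llm::GenerationConfig& config = {}) override {
    // A real implementation would tokenize `prompt`, run the prefill pass,
    // and advance the internally tracked KV-cache position (pos_).
    (void)prompt;
    (void)config;
    return ::executorch::runtime::Error::Ok;
  }
};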

extension/llm/runner/multimodal_runner.cpp

Lines changed: 1 addition & 1 deletion
@@ -62,7 +62,7 @@ Error MultimodalRunner::load() {
     ET_LOG(Info, format, __VA_ARGS__); \
   }

-Error MultimodalRunner::prefill(std::vector<MultimodalInput>& inputs) {
+Error MultimodalRunner::prefill(const std::vector<MultimodalInput>& inputs) {
   if (!is_loaded()) {
     ET_CHECK_OK_OR_RETURN_ERROR(load());
   }

extension/llm/runner/multimodal_runner.h

Lines changed: 1 addition & 1 deletion
@@ -126,7 +126,7 @@ class ET_EXPERIMENTAL MultimodalRunner {
    * @return The error code. KV cache position is tracked internally in pos_.
    */
   virtual ::executorch::runtime::Error prefill(
-      std::vector<MultimodalInput>& inputs);
+      const std::vector<MultimodalInput>& inputs);

   inline void stop() {
     text_token_generator_->stop();
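
The only change in these two files is the const-qualification of the inputs parameter. That is what lets the JNI code above pass a temporary vector built from a braced initializer list, and it also guarantees the caller's own vector is left untouched. A small sketch of both call styles under the same include-path and namespace-alias assumptions; the helper name is illustrative:

// Sketch: with prefill(const std::vector<MultimodalInput>&), both an existing
// vector (lvalue) and a temporary braced-init-list vector are accepted, and
// neither is modified by the call.
#include <executorch/extension/llm/runner/multimodal_runner.h>  // assumed path

#include <string>
#include <vector>

namespace llm = executorch::extension::llm;

::executorch::runtime::Error prefill_history(
    llm::MultimodalRunner& runner,
    const std::vector<llm::MultimodalInput>& history) {
  auto err = runner.prefill(history);  // lvalue: history stays valid afterwards
  if (err != ::executorch::runtime::Error::Ok) {
    return err;
  }
  // Temporary vector: the old non-const reference could not bind to this.
  return runner.prefill({llm::MultimodalInput{std::string("<system prompt>")}});
}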

extension/llm/runner/text_llm_runner.h

Lines changed: 3 additions & 3 deletions
@@ -104,13 +104,13 @@ class ET_EXPERIMENTAL TextLLMRunner : public IRunner {
   /**
    * Prefill text inputs, for example to reload chat history.
    * @param prompt Text prompt to prefill.
-   * @param config Configuration parameters for text generation (e.g.,
-   * max_new_tokens, temperature)
+   * @param config Configuration parameters (if non-zero num_bos and num_eos
+   * used)
    * @return The error code. KV cache position is tracked internally in pos_.
    */
   ::executorch::runtime::Error prefill(
       const std::string& prompt,
-      const GenerationConfig& config);
+      const GenerationConfig& config = {}) override;

   /**
    * @brief Warms up the model with a sample prompt
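
With the defaulted config and the override of the new IRunner method, reloading chat history through a TextLLMRunner reduces to a loop of plain prefill(text) calls; per the updated doc comment, BOS/EOS handling is driven by non-zero num_bos/num_eos in the supplied config. A brief caller sketch under the same include-path assumption; the function and variable names are illustrative:

// Sketch: replay stored chat turns through the new prefill API before
// resuming generation. Each call advances the runner's internal KV-cache
// position; no tokens are generated yet.
#include <executorch/extension/llm/runner/text_llm_runner.h>  // assumed path

#include <string>
#include <vector>

::executorch::runtime::Error replay_history(
    executorch::extension::llm::TextLLMRunner& runner,
    const std::vector<std::string>& turns) {
  for (const auto& turn : turns) {
    auto err = runner.prefill(turn);  // config defaults to {}
    if (err != ::executorch::runtime::Error::Ok) {
      return err;
    }
  }
  return ::executorch::runtime::Error::Ok;
}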
