
Commit 43d8e5e

Prefill
1 parent 0f868d7 commit 43d8e5e

3 files changed: +43 −17 lines

extension/android/jni/jni_layer_llama.cpp

Lines changed: 39 additions & 13 deletions
@@ -121,7 +121,9 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
   float temperature_ = 0.0f;
   int model_type_category_;
   std::unique_ptr<llm::IRunner> runner_;
-  std::unique_ptr<executorch::extension::llm::MultimodalRunner> multi_modal_runner_;
+  std::unique_ptr<executorch::extension::llm::MultimodalRunner>
+      multi_modal_runner_;
+  std::vector<llm::MultimodalInput> prefill_inputs_;
 
  public:
   constexpr static auto kJavaDescriptor =
@@ -215,6 +217,9 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
       facebook::jni::alias_ref<ExecuTorchLlmCallbackJni> callback,
       jboolean echo) {
     if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
+      std::vector<llm::MultimodalInput> inputs = prefill_inputs_;
+      prefill_inputs_.clear();
+      inputs.emplace_back(llm::MultimodalInput{prompt->toStdString()});
       auto image_size = image->size();
       std::vector<llm::Image> images;
       if (image_size != 0) {
@@ -225,15 +230,18 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
           image_data[i] = image_data_jint[i];
         }
         llm::Image image_runner{image_data, width, height, channels};
-        images.push_back(image_runner);
+        inputs.emplace_back(llm::MultimodalInput{std::move(image_runner)});
       }
+      executorch::extension::llm::GenerationConfig config{
+          .echo = static_cast<bool>(echo),
+          .seq_len = seq_len,
+          .temperature = temperature_,
+      };
       multi_modal_runner_->generate(
-          std::move(images),
-          prompt->toStdString(),
-          seq_len,
-          [callback](std::string result) { callback->onResult(result); },
-          [callback](const llm::Stats& result) { callback->onStats(result); },
-          echo);
+          std::move(inputs),
+          config,
+          [callback](const std::string& result) { callback->onResult(result); },
+          [callback](const llm::Stats& result) { callback->onStats(result); });
     } else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
       executorch::extension::llm::GenerationConfig config{
           .echo = static_cast<bool>(echo),
@@ -257,9 +265,10 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
       jlong start_pos,
       jint bos,
       jint eos) {
+    prefill_inputs_.emplace_back(llm::MultimodalInput{prompt->toStdString()});
     facebook::jni::local_ref<jlongArray> tuple_result =
         facebook::jni::make_long_array(2);
-    tuple_result->pin()[0] = static_cast<jint>(Error::NotSupported);
+    tuple_result->pin()[0] = static_cast<jint>(Error::Ok);
     return tuple_result;
   }
 
@@ -273,10 +282,24 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
       jint height,
       jint channels,
       jlong start_pos) {
+    std::vector<llm::Image> images;
+    auto image_size = image->size();
+    if (image_size != 0) {
+      std::vector<jint> image_data_jint(image_size);
+      std::vector<uint8_t> image_data(image_size);
+      image->getRegion(0, image_size, image_data_jint.data());
+      for (int i = 0; i < image_size; i++) {
+        image_data[i] = image_data_jint[i];
+      }
+      llm::Image image_runner{image_data, width, height, channels};
+      prefill_inputs_.emplace_back(
+          llm::MultimodalInput{std::move(image_runner)});
+    }
+
     facebook::jni::local_ref<jlongArray> tuple_result =
         facebook::jni::make_long_array(2);
 
-    tuple_result->pin()[0] = static_cast<jint>(Error::NotSupported);
+    tuple_result->pin()[0] = static_cast<jint>(Error::Ok);
     return tuple_result;
   }
 
@@ -287,10 +310,13 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
       facebook::jni::alias_ref<ExecuTorchLlmCallbackJni> callback,
       jboolean echo) {
     if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
-
+      std::vector<llm::MultimodalInput> inputs = prefill_inputs_;
+      prefill_inputs_.clear();
+      inputs.emplace_back(llm::MultimodalInput{prompt->toStdString()});
       return static_cast<jint>(multi_modal_runner_->generate(
-          std::vector<llm::MultimodalInput>{llm::MultimodalInput{prompt->toStdString()}},
-          llm::GenerationConfig {.echo = static_cast<bool>(echo), .seq_len = seq_len},
+          inputs,
+          llm::GenerationConfig{
+              .echo = static_cast<bool>(echo), .seq_len = seq_len},
          [callback](const std::string& result) { callback->onResult(result); },
          [callback](const llm::Stats& stats) { callback->onStats(stats); }));
     } else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
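
Note on the JNI flow above: prefill_prompt and prefill_images no longer report Error::NotSupported; they only queue text and image inputs in prefill_inputs_ and return Error::Ok, and generate() / generate_from_pos() drain that queue before appending the final prompt. Below is a minimal, self-contained sketch of that deferred-prefill pattern; the types here (Image, MultimodalInput, PrefillBuffer) are illustrative stand-ins, not the ExecuTorch classes.

#include <cstdint>
#include <iostream>
#include <string>
#include <variant>
#include <vector>

// Illustrative stand-ins for llm::Image / llm::MultimodalInput; not the
// ExecuTorch types.
struct Image {
  std::vector<uint8_t> data;
  int width = 0;
  int height = 0;
  int channels = 0;
};
using MultimodalInput = std::variant<std::string, Image>;

class PrefillBuffer {
 public:
  // prefill_* only enqueue inputs; nothing runs against the model yet.
  void prefill_prompt(std::string prompt) {
    pending_.emplace_back(std::move(prompt));
  }
  void prefill_image(Image image) {
    pending_.emplace_back(std::move(image));
  }

  // generate() drains everything queued so far and appends the final prompt,
  // mirroring the copy / clear / emplace_back sequence in the JNI layer above.
  void generate(const std::string& prompt) {
    std::vector<MultimodalInput> inputs = pending_;
    pending_.clear();
    inputs.emplace_back(prompt);
    for (const auto& input : inputs) {
      if (const auto* text = std::get_if<std::string>(&input)) {
        std::cout << "text input: " << *text << "\n";
      } else {
        const auto& img = std::get<Image>(input);
        std::cout << "image input: " << img.width << "x" << img.height << "\n";
      }
    }
  }

 private:
  std::vector<MultimodalInput> pending_;
};

int main() {
  PrefillBuffer buffer;
  buffer.prefill_prompt("system: you are a helpful assistant");
  buffer.prefill_image(Image{std::vector<uint8_t>(4 * 4 * 3), 4, 4, 3});
  buffer.generate("describe the image");
  return 0;
}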

extension/llm/runner/multimodal_runner.cpp

Lines changed: 2 additions & 2 deletions
@@ -65,8 +65,8 @@ Error MultimodalRunner::load() {
 Error MultimodalRunner::generate(
     const std::vector<MultimodalInput>& inputs,
     const GenerationConfig& config,
-    std::function<void(const std::string&)>& token_callback,
-    std::function<void(const Stats&)>& stats_callback) {
+    std::function<void(const std::string&)> token_callback,
+    std::function<void(const Stats&)> stats_callback) {
   if (inputs.empty()) {
     ET_LOG(Error, "MultimodalInput vector cannot be empty");
     return Error::InvalidArgument;

extension/llm/runner/multimodal_runner.h

Lines changed: 2 additions & 2 deletions
@@ -116,8 +116,8 @@ class ET_EXPERIMENTAL MultimodalRunner {
   virtual ::executorch::runtime::Error generate(
       const std::vector<MultimodalInput>& inputs,
       const GenerationConfig& config,
-      std::function<void(const std::string&)>& token_callback,
-      std::function<void(const Stats&)>& stats_callback);
+      std::function<void(const std::string&)> token_callback,
+      std::function<void(const Stats&)> stats_callback);
 
   inline void stop() {
     text_token_generator_->stop();
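
The signature change in multimodal_runner.cpp and multimodal_runner.h takes token_callback and stats_callback by value instead of by non-const lvalue reference. One practical effect (a general C++ point, not specific to ExecuTorch): lambda temporaries can now be passed directly at the call site, as the JNI layer above does, whereas a temporary std::function cannot bind to a non-const lvalue reference. A minimal sketch with stand-in types:

#include <functional>
#include <iostream>
#include <string>

// Illustrative stand-ins; not the ExecuTorch MultimodalRunner or Stats.
struct Stats {
  int num_generated_tokens = 0;
};

// Callbacks taken by value, matching the direction of this change: the callee
// holds its own copies, and temporaries bind without extra ceremony.
void generate(
    std::function<void(const std::string&)> token_callback,
    std::function<void(const Stats&)> stats_callback) {
  token_callback("hello");
  token_callback(" world");
  stats_callback(Stats{2});
}

int main() {
  // With the previous `std::function<...>&` parameters, these lambda
  // temporaries would not compile; callers had to build named std::function
  // objects first and pass those lvalues instead.
  generate(
      [](const std::string& token) { std::cout << token; },
      [](const Stats& stats) {
        std::cout << "\ntokens: " << stats.num_generated_tokens << "\n";
      });
  return 0;
}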
