Skip to content

Commit de0ff26

Browse files
authored
JNI layer: use multimodal runner (pytorch#13825)
Use the multimodal runner instead of the Llava runner
1 parent 823dea1 commit de0ff26

File tree

1 file changed

+35
-43
lines changed

1 file changed

+35
-43
lines changed

extension/android/jni/jni_layer_llama.cpp

Lines changed: 35 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
#include <unordered_map>
1414
#include <vector>
1515

16-
#include <executorch/examples/models/llava/runner/llava_runner.h>
1716
#include <executorch/extension/llm/runner/image.h>
1817
#include <executorch/extension/llm/runner/irunner.h>
1918
#include <executorch/extension/llm/runner/llm_runner_helper.h>
@@ -122,7 +121,9 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
122121
float temperature_ = 0.0f;
123122
int model_type_category_;
124123
std::unique_ptr<llm::IRunner> runner_;
125-
std::unique_ptr<example::LlavaRunner> multi_modal_runner_;
124+
std::unique_ptr<executorch::extension::llm::MultimodalRunner>
125+
multi_modal_runner_;
126+
std::vector<llm::MultimodalInput> prefill_inputs_;
126127

127128
public:
128129
constexpr static auto kJavaDescriptor =
@@ -168,10 +169,9 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
168169

169170
model_type_category_ = model_type_category;
170171
if (model_type_category == MODEL_TYPE_CATEGORY_MULTIMODAL) {
171-
multi_modal_runner_ = std::make_unique<example::LlavaRunner>(
172+
multi_modal_runner_ = llm::create_multimodal_runner(
172173
model_path->toStdString().c_str(),
173-
tokenizer_path->toStdString().c_str(),
174-
temperature);
174+
llm::load_tokenizer(tokenizer_path->toStdString()));
175175
} else if (model_type_category == MODEL_TYPE_CATEGORY_LLM) {
176176
std::optional<const std::string> data_path_str = data_path
177177
? std::optional<const std::string>{data_path->toStdString()}
@@ -217,6 +217,9 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
217217
facebook::jni::alias_ref<ExecuTorchLlmCallbackJni> callback,
218218
jboolean echo) {
219219
if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
220+
std::vector<llm::MultimodalInput> inputs = prefill_inputs_;
221+
prefill_inputs_.clear();
222+
inputs.emplace_back(llm::MultimodalInput{prompt->toStdString()});
220223
auto image_size = image->size();
221224
std::vector<llm::Image> images;
222225
if (image_size != 0) {
@@ -227,15 +230,18 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
227230
image_data[i] = image_data_jint[i];
228231
}
229232
llm::Image image_runner{image_data, width, height, channels};
230-
images.push_back(image_runner);
233+
inputs.emplace_back(llm::MultimodalInput{std::move(image_runner)});
231234
}
235+
executorch::extension::llm::GenerationConfig config{
236+
.echo = static_cast<bool>(echo),
237+
.seq_len = seq_len,
238+
.temperature = temperature_,
239+
};
232240
multi_modal_runner_->generate(
233-
std::move(images),
234-
prompt->toStdString(),
235-
seq_len,
236-
[callback](std::string result) { callback->onResult(result); },
237-
[callback](const llm::Stats& result) { callback->onStats(result); },
238-
echo);
241+
std::move(inputs),
242+
config,
243+
[callback](const std::string& result) { callback->onResult(result); },
244+
[callback](const llm::Stats& result) { callback->onStats(result); });
239245
} else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
240246
executorch::extension::llm::GenerationConfig config{
241247
.echo = static_cast<bool>(echo),
@@ -259,19 +265,10 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
259265
jlong start_pos,
260266
jint bos,
261267
jint eos) {
268+
prefill_inputs_.emplace_back(llm::MultimodalInput{prompt->toStdString()});
262269
facebook::jni::local_ref<jlongArray> tuple_result =
263270
facebook::jni::make_long_array(2);
264-
if (model_type_category_ != MODEL_TYPE_CATEGORY_MULTIMODAL) {
265-
tuple_result->pin()[0] = static_cast<jint>(Error::NotSupported);
266-
return tuple_result;
267-
}
268-
269-
auto&& result = multi_modal_runner_->prefill_prompt(
270-
prompt->toStdString(), start_pos, bos, eos);
271271
tuple_result->pin()[0] = static_cast<jint>(Error::Ok);
272-
if (result.ok()) {
273-
tuple_result->pin()[1] = static_cast<jlong>(start_pos);
274-
}
275272
return tuple_result;
276273
}
277274

@@ -285,16 +282,8 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
285282
jint height,
286283
jint channels,
287284
jlong start_pos) {
288-
facebook::jni::local_ref<jlongArray> tuple_result =
289-
facebook::jni::make_long_array(2);
290-
291-
if (model_type_category_ != MODEL_TYPE_CATEGORY_MULTIMODAL) {
292-
tuple_result->pin()[0] = static_cast<jint>(Error::NotSupported);
293-
return tuple_result;
294-
}
295-
296-
auto image_size = image->size();
297285
std::vector<llm::Image> images;
286+
auto image_size = image->size();
298287
if (image_size != 0) {
299288
std::vector<jint> image_data_jint(image_size);
300289
std::vector<uint8_t> image_data(image_size);
@@ -303,13 +292,14 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
303292
image_data[i] = image_data_jint[i];
304293
}
305294
llm::Image image_runner{image_data, width, height, channels};
306-
images.push_back(image_runner);
295+
prefill_inputs_.emplace_back(
296+
llm::MultimodalInput{std::move(image_runner)});
307297
}
308-
// TODO(hsz): make start_pos a reference and update it here
309-
jint result = static_cast<jint>(
310-
multi_modal_runner_->prefill_images(images, start_pos));
311-
tuple_result->pin()[0] = result;
312-
tuple_result->pin()[1] = static_cast<jlong>(start_pos);
298+
299+
facebook::jni::local_ref<jlongArray> tuple_result =
300+
facebook::jni::make_long_array(2);
301+
302+
tuple_result->pin()[0] = static_cast<jint>(Error::Ok);
313303
return tuple_result;
314304
}
315305

@@ -320,13 +310,15 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
320310
facebook::jni::alias_ref<ExecuTorchLlmCallbackJni> callback,
321311
jboolean echo) {
322312
if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
323-
return static_cast<jint>(multi_modal_runner_->generate_from_pos(
324-
prompt->toStdString(),
325-
seq_len,
326-
start_pos,
313+
std::vector<llm::MultimodalInput> inputs = prefill_inputs_;
314+
prefill_inputs_.clear();
315+
inputs.emplace_back(llm::MultimodalInput{prompt->toStdString()});
316+
return static_cast<jint>(multi_modal_runner_->generate(
317+
inputs,
318+
llm::GenerationConfig{
319+
.echo = static_cast<bool>(echo), .seq_len = seq_len},
327320
[callback](const std::string& result) { callback->onResult(result); },
328-
[callback](const llm::Stats& stats) { callback->onStats(stats); },
329-
echo));
321+
[callback](const llm::Stats& stats) { callback->onStats(stats); }));
330322
} else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
331323
executorch::extension::llm::GenerationConfig config{
332324
.echo = static_cast<bool>(echo),

0 commit comments

Comments
 (0)