Skip to content

Commit 63e407e

Browse files
committed
test
1 parent 463c4b5 commit 63e407e

File tree

1 file changed

+16
-49
lines changed

(duplicate of the file-change summary above: 1 file changed, +16 −49 lines changed)

extension/android/jni/jni_layer_llama.cpp

Lines changed: 16 additions & 49 deletions
Original file line number | Diff line number | Diff line change
@@ -13,10 +13,12 @@
1313
#include <unordered_map>
1414
#include <vector>
1515

16-
#include <executorch/examples/models/llama/runner/runner.h>
17-
#include <executorch/examples/models/llava/runner/llava_runner.h>
1816
#include <executorch/extension/llm/runner/image.h>
1917
#include <executorch/extension/llm/runner/irunner.h>
18+
#include <executorch/extension/llm/runner/llm_runner_helper.h>
19+
#include <executorch/extension/llm/runner/multimodal_input.h>
20+
#include <executorch/extension/llm/runner/multimodal_runner.h>
21+
#include <executorch/extension/llm/runner/text_llm_runner.h>
2022
#include <executorch/runtime/platform/log.h>
2123
#include <executorch/runtime/platform/platform.h>
2224
#include <executorch/runtime/platform/runtime.h>
@@ -119,7 +121,7 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
119121
float temperature_ = 0.0f;
120122
int model_type_category_;
121123
std::unique_ptr<llm::IRunner> runner_;
122-
std::unique_ptr<example::LlavaRunner> multi_modal_runner_;
124+
std::unique_ptr<executorch::extension::llm::MultimodalRunner> multi_modal_runner_;
123125

124126
public:
125127
constexpr static auto kJavaDescriptor =
@@ -165,19 +167,16 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
165167

166168
model_type_category_ = model_type_category;
167169
if (model_type_category == MODEL_TYPE_CATEGORY_MULTIMODAL) {
168-
multi_modal_runner_ = std::make_unique<example::LlavaRunner>(
170+
multi_modal_runner_ = llm::create_multimodal_runner(
169171
model_path->toStdString().c_str(),
170-
tokenizer_path->toStdString().c_str(),
171-
temperature);
172+
llm::load_tokenizer(tokenizer_path->toStdString()));
172173
} else if (model_type_category == MODEL_TYPE_CATEGORY_LLM) {
173174
std::optional<const std::string> data_path_str = data_path
174175
? std::optional<const std::string>{data_path->toStdString()}
175176
: std::nullopt;
176-
// TODO(larryliu0820): Use the API in text_llm_runner.h to create the
177-
// runner.
178-
runner_ = example::create_llama_runner(
177+
runner_ = executorch::extension::llm::create_text_llm_runner(
179178
model_path->toStdString(),
180-
tokenizer_path->toStdString(),
179+
llm::load_tokenizer(tokenizer_path->toStdString()),
181180
data_path_str);
182181
#if defined(EXECUTORCH_BUILD_QNN)
183182
} else if (model_type_category == MODEL_TYPE_QNN_LLAMA) {
@@ -260,17 +259,7 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
260259
jint eos) {
261260
facebook::jni::local_ref<jlongArray> tuple_result =
262261
facebook::jni::make_long_array(2);
263-
if (model_type_category_ != MODEL_TYPE_CATEGORY_MULTIMODAL) {
264-
tuple_result->pin()[0] = static_cast<jint>(Error::NotSupported);
265-
return tuple_result;
266-
}
267-
268-
auto&& result = multi_modal_runner_->prefill_prompt(
269-
prompt->toStdString(), start_pos, bos, eos);
270-
tuple_result->pin()[0] = static_cast<jint>(Error::Ok);
271-
if (result.ok()) {
272-
tuple_result->pin()[1] = static_cast<jlong>(start_pos);
273-
}
262+
tuple_result->pin()[0] = static_cast<jint>(Error::NotSupported);
274263
return tuple_result;
275264
}
276265

@@ -287,28 +276,7 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
287276
facebook::jni::local_ref<jlongArray> tuple_result =
288277
facebook::jni::make_long_array(2);
289278

290-
if (model_type_category_ != MODEL_TYPE_CATEGORY_MULTIMODAL) {
291-
tuple_result->pin()[0] = static_cast<jint>(Error::NotSupported);
292-
return tuple_result;
293-
}
294-
295-
auto image_size = image->size();
296-
std::vector<llm::Image> images;
297-
if (image_size != 0) {
298-
std::vector<jint> image_data_jint(image_size);
299-
std::vector<uint8_t> image_data(image_size);
300-
image->getRegion(0, image_size, image_data_jint.data());
301-
for (int i = 0; i < image_size; i++) {
302-
image_data[i] = image_data_jint[i];
303-
}
304-
llm::Image image_runner{image_data, width, height, channels};
305-
images.push_back(image_runner);
306-
}
307-
// TODO(hsz): make start_pos a reference and update it here
308-
jint result = static_cast<jint>(
309-
multi_modal_runner_->prefill_images(images, start_pos));
310-
tuple_result->pin()[0] = result;
311-
tuple_result->pin()[1] = static_cast<jlong>(start_pos);
279+
tuple_result->pin()[0] = static_cast<jint>(Error::NotSupported);
312280
return tuple_result;
313281
}
314282

@@ -319,13 +287,12 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
319287
facebook::jni::alias_ref<ExecuTorchLlmCallbackJni> callback,
320288
jboolean echo) {
321289
if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
322-
return static_cast<jint>(multi_modal_runner_->generate_from_pos(
323-
prompt->toStdString(),
324-
seq_len,
325-
start_pos,
290+
291+
return static_cast<jint>(multi_modal_runner_->generate(
292+
std::vector<llm::MultimodalInput>{llm::MultimodalInput{prompt->toStdString()}},
293+
llm::GenerationConfig {.echo = static_cast<bool>(echo), .seq_len = seq_len},
326294
[callback](const std::string& result) { callback->onResult(result); },
327-
[callback](const llm::Stats& stats) { callback->onStats(stats); },
328-
echo));
295+
[callback](const llm::Stats& stats) { callback->onStats(stats); }));
329296
} else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
330297
executorch::extension::llm::GenerationConfig config{
331298
.echo = static_cast<bool>(echo),

0 commit comments

Comments
 (0)