78 changes: 35 additions & 43 deletions extension/android/jni/jni_layer_llama.cpp
@@ -13,7 +13,6 @@
 #include <unordered_map>
 #include <vector>
 
-#include <executorch/examples/models/llava/runner/llava_runner.h>
 #include <executorch/extension/llm/runner/image.h>
 #include <executorch/extension/llm/runner/irunner.h>
 #include <executorch/extension/llm/runner/llm_runner_helper.h>
@@ -122,7 +121,9 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
   float temperature_ = 0.0f;
   int model_type_category_;
   std::unique_ptr<llm::IRunner> runner_;
-  std::unique_ptr<example::LlavaRunner> multi_modal_runner_;
+  std::unique_ptr<executorch::extension::llm::MultimodalRunner>
+      multi_modal_runner_;
+  std::vector<llm::MultimodalInput> prefill_inputs_;
 
  public:
  constexpr static auto kJavaDescriptor =
@@ -168,10 +169,9 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
 
     model_type_category_ = model_type_category;
     if (model_type_category == MODEL_TYPE_CATEGORY_MULTIMODAL) {
-      multi_modal_runner_ = std::make_unique<example::LlavaRunner>(
+      multi_modal_runner_ = llm::create_multimodal_runner(
           model_path->toStdString().c_str(),
-          tokenizer_path->toStdString().c_str(),
-          temperature);
+          llm::load_tokenizer(tokenizer_path->toStdString()));
     } else if (model_type_category == MODEL_TYPE_CATEGORY_LLM) {
       std::optional<const std::string> data_path_str = data_path
           ? std::optional<const std::string>{data_path->toStdString()}
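Reviewer note: for readers new to this API, here is a minimal sketch of the construction path this hunk switches to. It assumes only the two calls visible above, `llm::create_multimodal_runner()` taking a model path plus an owned tokenizer, and `llm::load_tokenizer()` producing that tokenizer; the `multimodal_runner.h` header path is an assumption.

```cpp
// Minimal sketch, assuming the signatures used in the hunk above.
#include <memory>
#include <string>

#include <executorch/extension/llm/runner/llm_runner_helper.h>
#include <executorch/extension/llm/runner/multimodal_runner.h>  // assumed path

namespace llm = executorch::extension::llm;

std::unique_ptr<llm::MultimodalRunner> make_runner(
    const std::string& model_path,
    const std::string& tokenizer_path) {
  // The runner takes ownership of the tokenizer returned by
  // load_tokenizer(), mirroring the constructor change above.
  return llm::create_multimodal_runner(
      model_path.c_str(), llm::load_tokenizer(tokenizer_path));
}
```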
@@ -217,6 +217,9 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
       facebook::jni::alias_ref<ExecuTorchLlmCallbackJni> callback,
       jboolean echo) {
     if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
+      std::vector<llm::MultimodalInput> inputs = prefill_inputs_;
+      prefill_inputs_.clear();
+      inputs.emplace_back(llm::MultimodalInput{prompt->toStdString()});
       auto image_size = image->size();
       std::vector<llm::Image> images;
       if (image_size != 0) {
@@ -227,15 +230,18 @@
           image_data[i] = image_data_jint[i];
         }
         llm::Image image_runner{image_data, width, height, channels};
-        images.push_back(image_runner);
+        inputs.emplace_back(llm::MultimodalInput{std::move(image_runner)});
       }
+      executorch::extension::llm::GenerationConfig config{
+          .echo = static_cast<bool>(echo),
+          .seq_len = seq_len,
+          .temperature = temperature_,
+      };
       multi_modal_runner_->generate(
-          std::move(images),
-          prompt->toStdString(),
-          seq_len,
-          [callback](std::string result) { callback->onResult(result); },
-          [callback](const llm::Stats& result) { callback->onStats(result); },
-          echo);
+          std::move(inputs),
+          config,
+          [callback](const std::string& result) { callback->onResult(result); },
+          [callback](const llm::Stats& result) { callback->onStats(result); });
     } else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
       executorch::extension::llm::GenerationConfig config{
           .echo = static_cast<bool>(echo),
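Reviewer note: the old positional argument list (images, prompt, seq_len, two callbacks, echo) collapses into (inputs, config, callbacks). Below is a sketch of one complete call through the new interface, assuming only the shapes used in this diff: `MultimodalInput` constructible from a `std::string` or an `llm::Image`, and `GenerationConfig` declaring `.echo`, `.seq_len`, `.temperature` in that order (designated initializers require declaration order). All values are illustrative.

```cpp
// Sketch of a full call through the new multimodal generate() interface.
void describe(llm::MultimodalRunner& runner, llm::Image image) {
  std::vector<llm::MultimodalInput> inputs;
  inputs.emplace_back(llm::MultimodalInput{std::move(image)});
  inputs.emplace_back(llm::MultimodalInput{std::string("Describe the image.")});

  llm::GenerationConfig config{
      .echo = false,
      .seq_len = 128,       // illustrative cap
      .temperature = 0.8f,  // illustrative
  };
  auto err = runner.generate(
      std::move(inputs),
      config,
      [](const std::string& token) { /* stream token text to the UI */ },
      [](const llm::Stats& stats) { /* record timing stats */ });
  (void)err;  // error handling elided in this sketch
}
```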
@@ -259,19 +265,10 @@
       jlong start_pos,
       jint bos,
       jint eos) {
+    prefill_inputs_.emplace_back(llm::MultimodalInput{prompt->toStdString()});
     facebook::jni::local_ref<jlongArray> tuple_result =
         facebook::jni::make_long_array(2);
-    if (model_type_category_ != MODEL_TYPE_CATEGORY_MULTIMODAL) {
-      tuple_result->pin()[0] = static_cast<jint>(Error::NotSupported);
-      return tuple_result;
-    }
-
-    auto&& result = multi_modal_runner_->prefill_prompt(
-        prompt->toStdString(), start_pos, bos, eos);
     tuple_result->pin()[0] = static_cast<jint>(Error::Ok);
-    if (result.ok()) {
-      tuple_result->pin()[1] = static_cast<jlong>(start_pos);
-    }
     return tuple_result;
   }
 
@@ -285,16 +282,8 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
       jint height,
       jint channels,
       jlong start_pos) {
-    facebook::jni::local_ref<jlongArray> tuple_result =
-        facebook::jni::make_long_array(2);
-
-    if (model_type_category_ != MODEL_TYPE_CATEGORY_MULTIMODAL) {
-      tuple_result->pin()[0] = static_cast<jint>(Error::NotSupported);
-      return tuple_result;
-    }
-
-    auto image_size = image->size();
     std::vector<llm::Image> images;
+    auto image_size = image->size();
     if (image_size != 0) {
       std::vector<jint> image_data_jint(image_size);
       std::vector<uint8_t> image_data(image_size);
@@ -303,13 +292,14 @@
         image_data[i] = image_data_jint[i];
       }
       llm::Image image_runner{image_data, width, height, channels};
-      images.push_back(image_runner);
+      prefill_inputs_.emplace_back(
+          llm::MultimodalInput{std::move(image_runner)});
     }
-    // TODO(hsz): make start_pos a reference and update it here
-    jint result = static_cast<jint>(
-        multi_modal_runner_->prefill_images(images, start_pos));
-    tuple_result->pin()[0] = result;
-    tuple_result->pin()[1] = static_cast<jlong>(start_pos);
+
+    facebook::jni::local_ref<jlongArray> tuple_result =
+        facebook::jni::make_long_array(2);
+
+    tuple_result->pin()[0] = static_cast<jint>(Error::Ok);
     return tuple_result;
   }
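Reviewer note: with these two hunks, both JNI prefill entry points stop doing any model work. They only enqueue a `MultimodalInput` and unconditionally report `Error::Ok`; the actual prefill happens when the next generate call drains the queue. A distilled, hypothetical version of the pattern (the `InputQueue` class is illustrative, not code from this PR):

```cpp
// Deferred-prefill pattern: prefill calls enqueue, generate drains in
// arrival order and appends the prompt last. Assumes the llm namespace
// alias used in the JNI file (namespace llm = executorch::extension::llm).
#include <string>
#include <utility>
#include <vector>

class InputQueue {
 public:
  void enqueue(llm::MultimodalInput input) {
    pending_.push_back(std::move(input));
  }

  // Called at generate time: everything queued so far, then the prompt.
  std::vector<llm::MultimodalInput> drain(std::string prompt) {
    std::vector<llm::MultimodalInput> inputs = std::move(pending_);
    pending_.clear();  // leave the moved-from vector in a known-empty state
    inputs.emplace_back(llm::MultimodalInput{std::move(prompt)});
    return inputs;
  }

 private:
  std::vector<llm::MultimodalInput> pending_;
};
```

The JNI code copies `prefill_inputs_` and then clears it rather than moving; a move (as sketched here) would avoid the copy, at the cost of being slightly less forgiving if generate is ever re-entered mid-drain.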

@@ -320,13 +310,15 @@
       facebook::jni::alias_ref<ExecuTorchLlmCallbackJni> callback,
       jboolean echo) {
     if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
-      return static_cast<jint>(multi_modal_runner_->generate_from_pos(
-          prompt->toStdString(),
-          seq_len,
-          start_pos,
+      std::vector<llm::MultimodalInput> inputs = prefill_inputs_;
+      prefill_inputs_.clear();
+      inputs.emplace_back(llm::MultimodalInput{prompt->toStdString()});
+      return static_cast<jint>(multi_modal_runner_->generate(
+          inputs,
+          llm::GenerationConfig{
+              .echo = static_cast<bool>(echo), .seq_len = seq_len},
           [callback](const std::string& result) { callback->onResult(result); },
-          [callback](const llm::Stats& stats) { callback->onStats(stats); },
-          echo));
+          [callback](const llm::Stats& stats) { callback->onStats(stats); }));
     } else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
       executorch::extension::llm::GenerationConfig config{
           .echo = static_cast<bool>(echo),
4 changes: 2 additions & 2 deletions extension/llm/runner/llm_runner_helper.cpp
@@ -111,9 +111,9 @@ get_llm_metadata(tokenizers::Tokenizer* tokenizer, Module* module) {
   if (!method_names.count(llm::kMaxSeqLen)) {
     ET_LOG(
         Error,
-        "Required metadata method %s not found in model",
+        "Metadata method %s not found in model; continuing without it",
         llm::kMaxSeqLen);
-    return ::executorch::runtime::Error::InvalidArgument;
+    // Intentionally no longer returning Error::InvalidArgument here.
   }
 
   for (auto& pair : metadata) {
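Reviewer note: this hunk downgrades a hard failure to log-and-continue. If the bypass is meant to stay, a cleaner shape is to treat the metadata method as explicitly optional with a named fallback, rather than a dangling log message. A hypothetical sketch (`kDefaultMaxSeqLen` and `resolve_max_seq_len` are assumptions, not ExecuTorch APIs):

```cpp
// Hypothetical optional-metadata fallback, sketching the intent of the
// bypass above.
#include <cstdint>
#include <string>
#include <unordered_set>

constexpr int64_t kDefaultMaxSeqLen = 128;  // assumed fallback value

int64_t resolve_max_seq_len(
    const std::unordered_set<std::string>& method_names,
    const std::string& key,
    int64_t value_from_module) {
  // Missing metadata is tolerated: fall back instead of erroring out.
  return method_names.count(key) ? value_from_module : kDefaultMaxSeqLen;
}
```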
7 changes: 7 additions & 0 deletions extension/llm/runner/multimodal_prefiller.cpp
@@ -37,6 +37,8 @@ MultimodalPrefiller::MultimodalPrefiller(
 Result<uint64_t> MultimodalPrefiller::prefill(
     const MultimodalInput& input,
     int64_t& start_pos) {
+  ET_LOG(Debug, "MultimodalPrefiller::prefill: start_pos=%d", (int)start_pos);
+  ET_LOG(Debug, "Input type: %s", input.is_image() ? "image" : "text");
   // 1. Run encoder model.
   ::executorch::runtime::EValue encoder_output;
   if (input.is_image()) {
@@ -73,12 +75,14 @@ Result<uint64_t> MultimodalPrefiller::prefill(
     auto& text = input.get_text();
     std::vector<uint64_t> tokens =
         ET_UNWRAP_TOKENIZER(tokenizer_->encode(text));
+    ET_LOG(Debug, "Encoded text prompt into %zu tokens", tokens.size());
 
     auto text_tensor = executorch::extension::from_blob(
         tokens.data(),
         {1, static_cast<aten::SizesType>(tokens.size())},
         ::executorch::aten::ScalarType::Long);
 
+    ET_LOG(Debug, "Running token embedding on a [1, %zu] tensor", tokens.size());
     // Run text encoder (token embeddings)
     auto token_embedding_outputs =
         ET_UNWRAP(module_->execute(kTokenEmbeddingMethod, text_tensor));
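Reviewer note on the `from_blob()` call above: it wraps caller-owned memory without copying, so the backing `tokens` vector must stay alive until the module call completes. A minimal sketch using the same call shape:

```cpp
// from_blob() aliases existing storage; keep the vector alive while the
// tensor is in use.
#include <vector>

void from_blob_sketch() {
  std::vector<uint64_t> tokens = {1, 2, 3};
  auto text_tensor = executorch::extension::from_blob(
      tokens.data(),
      {1, static_cast<executorch::aten::SizesType>(tokens.size())},
      ::executorch::aten::ScalarType::Long);
  // ... execute the module here, while `tokens` is still in scope ...
}
```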
@@ -89,6 +93,7 @@ Result<uint64_t> MultimodalPrefiller::prefill(
     // For any other input types, return error
     return ::executorch::runtime::Error::NotSupported;
   }
+  ET_LOG(Debug, "Encoder pass complete; running decoder prefill");
 
   // 2. Run decoder model for prefill.
   // `cache_position` goes from start_pos to start_pos + encoder_output.size(1).
@@ -107,6 +112,7 @@ Result<uint64_t> MultimodalPrefiller::prefill(
       cache_positions.data(),
       {static_cast<int>(seq_len)},
       executorch::aten::ScalarType::Long);
+  ET_LOG(Debug, "Prefilling decoder with %d positions", (int)seq_len);
   auto prefill_result = module_->execute(
       kTextModelMethod, {cache_position_tensor, encoder_output});
   if (prefill_result.error() != ::executorch::runtime::Error::Ok) {
@@ -121,6 +127,7 @@ Result<uint64_t> MultimodalPrefiller::prefill(
     return ::executorch::runtime::Error::InvalidState;
   }
   auto outputs_res = prefill_outputs[0].toTensor();
+  ET_LOG(Debug, "Decoder prefill produced a logits tensor");
 
   // Update start_pos, tracking the current cache position.
   start_pos += seq_len;
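Reviewer note: the `cache_position` input fed to the decoder above is, per the comment in this file, the contiguous range from `start_pos` to `start_pos + encoder_output.size(1)`. A sketch of how such a position vector is typically built (assumed from the surrounding code, not copied from it):

```cpp
// Build the contiguous int64 range [start_pos, start_pos + seq_len) that
// backs the cache_position tensor; the vector must outlive any
// from_blob() tensor created over it.
#include <cstdint>
#include <numeric>
#include <vector>

std::vector<int64_t> make_cache_positions(int64_t start_pos, int64_t seq_len) {
  std::vector<int64_t> cache_positions(static_cast<size_t>(seq_len));
  std::iota(cache_positions.begin(), cache_positions.end(), start_pos);
  return cache_positions;
}
```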
8 changes: 6 additions & 2 deletions extension/llm/runner/multimodal_runner.cpp
@@ -57,9 +57,9 @@ Error MultimodalRunner::load() {
 // Don't print with the same priority during warmup
 #define RUNNER_ET_LOG(warmup, format, ...) \
   if (warmup) {                            \
-    ET_LOG(Debug, format, __VA_ARGS__);    \
+    ET_LOG(Error, format, __VA_ARGS__);    \
   } else {                                 \
-    ET_LOG(Info, format, __VA_ARGS__);     \
+    ET_LOG(Error, format, __VA_ARGS__);    \
   }
 
 Error MultimodalRunner::generate(
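Reviewer note on the macro above (independent of this hunk's Debug/Error change): plain `__VA_ARGS__` leaves a dangling comma whenever a caller passes only `format`. C++20's `__VA_OPT__`, or the GNU `##__VA_ARGS__` extension, handles the zero-argument case. A sketch of the safer shape (`RUNNER_ET_LOG_SAFE` is a hypothetical name):

```cpp
// Variadic-macro sketch: __VA_OPT__(,) emits the comma only when the
// caller actually passed variadic arguments.
#define RUNNER_ET_LOG_SAFE(warmup, format, ...)       \
  if (warmup) {                                       \
    ET_LOG(Debug, format __VA_OPT__(, ) __VA_ARGS__); \
  } else {                                            \
    ET_LOG(Info, format __VA_OPT__(, ) __VA_ARGS__);  \
  }
```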
@@ -104,16 +104,20 @@ Error MultimodalRunner::generate(
 
   uint64_t prefill_next_token = 0;
   // Process multimodal inputs in order
+  ET_LOG(Debug, "Prefilling %zu multimodal inputs", inputs.size());
   for (const MultimodalInput& input : inputs) {
+    ET_LOG(Debug, "Prefilling next multimodal input at pos %d", (int)pos_);
     prefill_next_token = ET_UNWRAP(multimodal_prefiller_->prefill(input, pos_));
   }
+  ET_LOG(Debug, "Prefill done; %d prompt tokens in the KV cache", (int)pos_);
 
   stats_->first_token_ms = time_in_ms();
   stats_->prompt_eval_end_ms = time_in_ms();
   stats_->num_prompt_tokens = pos_;
 
   wrapped_callback(ET_UNWRAP_TOKENIZER(
       tokenizer_->decode(prefill_next_token, prefill_next_token)));
+  ET_LOG(Debug, "Emitted first prefill token to the callback");
 
   RUNNER_ET_LOG(
       config.warming,