@@ -177,15 +177,15 @@ public native int generate(
* @throws RuntimeException if the prefill failed
*/
public long prefillImages(int[] image, int width, int height, int channels, long startPos) {
long[] nativeResult = prefillImagesNative(image, width, height, channels, startPos);
long[] nativeResult = addImageInputNative(image, width, height, channels, startPos);
if (nativeResult[0] != 0) {
throw new RuntimeException("Prefill failed with error code: " + nativeResult[0]);
}
return nativeResult[1];
}

// returns a tuple of (status, updated startPos)
private native long[] prefillImagesNative(
private native long[] addImageInputNative(
int[] image, int width, int height, int channels, long startPos);

/**
@@ -200,15 +200,18 @@ private native long[] prefillImagesNative(
* @throws RuntimeException if the prefill failed
*/
public long prefillPrompt(String prompt, long startPos, int bos, int eos) {
long[] nativeResult = prefillPromptNative(prompt, startPos, bos, eos);
long[] nativeResult = addTextInputNative(prompt, startPos, bos, eos);
if (nativeResult[0] != 0) {
throw new RuntimeException("Prefill failed with error code: " + nativeResult[0]);
}
return nativeResult[1];
}

// returns a tuple of (status, updated startPos)
private native long[] prefillPromptNative(String prompt, long startPos, int bos, int eos);
private native long[] addTextInputNative(String prompt, long startPos, int bos, int eos);

// returns the status code
private native int addAudioInputNative(int[] audio, int batch_size, int n_bins, int n_frames);

/**
* Generate tokens from the given prompt, starting from the given position.
@@ -217,6 +220,12 @@ public long prefillPrompt(String prompt, long startPos, int bos, int eos) {
* @param seqLen The total sequence length, including the prompt tokens and new tokens.
* @param startPos The starting position in KV cache of the input in the LLM.
* @param callback callback object to receive results.
* @param echo indicate whether to echo the input prompt or not.
* <p>Generate tokens from the given prompt, starting from the given position.
* @param prompt The text prompt to LLaVA.
* @param seqLen The total sequence length, including the prompt tokens and new tokens.
* @param startPos The starting position in KV cache of the input in the LLM.
* @param callback callback object to receive results.
* @param echo indicate whether to echo the input prompt or not.
* @return The error code.
*/
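This hunk adds the addAudioInputNative declaration, but no public wrapper for it appears in the excerpt above. A wrapper following the existing prefillImages/prefillPrompt pattern might look like the sketch below; the method name addAudioInput and its error handling are assumptions for illustration, not part of this diff.

  /**
   * Queue an audio input to be consumed when generation starts.
   *
   * @throws RuntimeException if adding the audio input failed
   */
  public void addAudioInput(int[] audio, int batchSize, int nBins, int nFrames) {
    // addAudioInputNative returns a status code; 0 means success (see the declaration above).
    int status = addAudioInputNative(audio, batchSize, nBins, nFrames);
    if (status != 0) {
      throw new RuntimeException("Add audio input failed with error code: " + status);
    }
  }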
110 changes: 63 additions & 47 deletions extension/android/jni/jni_layer_llama.cpp
@@ -13,7 +13,6 @@
#include <unordered_map>
#include <vector>

#include <executorch/examples/models/llava/runner/llava_runner.h>
#include <executorch/extension/llm/runner/image.h>
#include <executorch/extension/llm/runner/irunner.h>
#include <executorch/extension/llm/runner/llm_runner_helper.h>
@@ -122,7 +121,9 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
float temperature_ = 0.0f;
int model_type_category_;
std::unique_ptr<llm::IRunner> runner_;
std::unique_ptr<example::LlavaRunner> multi_modal_runner_;
std::unique_ptr<executorch::extension::llm::MultimodalRunner>
multi_modal_runner_;
std::vector<llm::MultimodalInput> prefill_inputs_;

public:
constexpr static auto kJavaDescriptor =
@@ -168,10 +169,9 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {

model_type_category_ = model_type_category;
if (model_type_category == MODEL_TYPE_CATEGORY_MULTIMODAL) {
multi_modal_runner_ = std::make_unique<example::LlavaRunner>(
multi_modal_runner_ = llm::create_multimodal_runner(
model_path->toStdString().c_str(),
tokenizer_path->toStdString().c_str(),
temperature);
llm::load_tokenizer(tokenizer_path->toStdString()));
} else if (model_type_category == MODEL_TYPE_CATEGORY_LLM) {
std::optional<const std::string> data_path_str = data_path
? std::optional<const std::string>{data_path->toStdString()}
@@ -217,6 +217,9 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
facebook::jni::alias_ref<ExecuTorchLlmCallbackJni> callback,
jboolean echo) {
if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
std::vector<llm::MultimodalInput> inputs = prefill_inputs_;
prefill_inputs_.clear();
inputs.emplace_back(llm::MultimodalInput{prompt->toStdString()});
auto image_size = image->size();
std::vector<llm::Image> images;
if (image_size != 0) {
@@ -227,15 +230,18 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
image_data[i] = image_data_jint[i];
}
llm::Image image_runner{image_data, width, height, channels};
images.push_back(image_runner);
inputs.emplace_back(llm::MultimodalInput{std::move(image_runner)});
}
executorch::extension::llm::GenerationConfig config{
.echo = static_cast<bool>(echo),
.seq_len = seq_len,
.temperature = temperature_,
};
multi_modal_runner_->generate(
std::move(images),
prompt->toStdString(),
seq_len,
[callback](std::string result) { callback->onResult(result); },
[callback](const llm::Stats& result) { callback->onStats(result); },
echo);
std::move(inputs),
config,
[callback](const std::string& result) { callback->onResult(result); },
[callback](const llm::Stats& result) { callback->onStats(result); });
} else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
executorch::extension::llm::GenerationConfig config{
.echo = static_cast<bool>(echo),
@@ -254,47 +260,30 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
// Returns a tuple of (error, start_pos)
// Contract is valid within an AAR (JNI + corresponding Java code)
// If the first element is not Error::Ok, the other element is undefined.
facebook::jni::local_ref<jlongArray> prefill_prompt(
facebook::jni::local_ref<jlongArray> add_text_input(
facebook::jni::alias_ref<jstring> prompt,
jlong start_pos,
jint bos,
jint eos) {
prefill_inputs_.emplace_back(llm::MultimodalInput{prompt->toStdString()});
facebook::jni::local_ref<jlongArray> tuple_result =
facebook::jni::make_long_array(2);
if (model_type_category_ != MODEL_TYPE_CATEGORY_MULTIMODAL) {
tuple_result->pin()[0] = static_cast<jint>(Error::NotSupported);
return tuple_result;
}

auto&& result = multi_modal_runner_->prefill_prompt(
prompt->toStdString(), start_pos, bos, eos);
tuple_result->pin()[0] = static_cast<jint>(Error::Ok);
if (result.ok()) {
tuple_result->pin()[1] = static_cast<jlong>(start_pos);
}
return tuple_result;
}

// Returns a tuple of (error, start_pos)
// Contract is valid within an AAR (JNI + corresponding Java code)
// If the first element is not Error::Ok, the other element is undefined.

facebook::jni::local_ref<jlongArray> prefill_images(
facebook::jni::local_ref<jlongArray> add_images_input(
facebook::jni::alias_ref<jintArray> image,
jint width,
jint height,
jint channels,
jlong start_pos) {
facebook::jni::local_ref<jlongArray> tuple_result =
facebook::jni::make_long_array(2);

if (model_type_category_ != MODEL_TYPE_CATEGORY_MULTIMODAL) {
tuple_result->pin()[0] = static_cast<jint>(Error::NotSupported);
return tuple_result;
}

auto image_size = image->size();
std::vector<llm::Image> images;
auto image_size = image->size();
if (image_size != 0) {
std::vector<jint> image_data_jint(image_size);
std::vector<uint8_t> image_data(image_size);
@@ -303,30 +292,55 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
image_data[i] = image_data_jint[i];
}
llm::Image image_runner{image_data, width, height, channels};
images.push_back(image_runner);
prefill_inputs_.emplace_back(
llm::MultimodalInput{std::move(image_runner)});
}
// TODO(hsz): make start_pos a reference and update it here
jint result = static_cast<jint>(
multi_modal_runner_->prefill_images(images, start_pos));
tuple_result->pin()[0] = result;
tuple_result->pin()[1] = static_cast<jlong>(start_pos);

facebook::jni::local_ref<jlongArray> tuple_result =
facebook::jni::make_long_array(2);

tuple_result->pin()[0] = static_cast<jint>(Error::Ok);
return tuple_result;
}

// Returns the status code
jint add_audio_input(
facebook::jni::alias_ref<jintArray> audio,
jint batch_size,
jint n_bins,
jint n_frames) {
auto audio_size = audio->size();
if (audio_size != 0) {
std::vector<jint> audio_data_jint(audio_size);
std::vector<uint8_t> audio_data(audio_size);
audio->getRegion(0, audio_size, audio_data_jint.data());
for (int i = 0; i < audio_size; i++) {
audio_data[i] = audio_data_jint[i];
}
auto&& audio_input = llm::make_audio_input(
llm::Audio{audio_data, batch_size, n_bins, n_frames});
prefill_inputs_.emplace_back(audio_input);
}

return 0;
}

jint generate_from_pos(
facebook::jni::alias_ref<jstring> prompt,
jint seq_len,
jlong start_pos,
facebook::jni::alias_ref<ExecuTorchLlmCallbackJni> callback,
jboolean echo) {
if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) {
return static_cast<jint>(multi_modal_runner_->generate_from_pos(
prompt->toStdString(),
seq_len,
start_pos,
std::vector<llm::MultimodalInput> inputs = prefill_inputs_;
prefill_inputs_.clear();
inputs.emplace_back(llm::MultimodalInput{prompt->toStdString()});
return static_cast<jint>(multi_modal_runner_->generate(
inputs,
llm::GenerationConfig{
.echo = static_cast<bool>(echo), .seq_len = seq_len},
[callback](const std::string& result) { callback->onResult(result); },
[callback](const llm::Stats& stats) { callback->onStats(stats); },
echo));
[callback](const llm::Stats& stats) { callback->onStats(stats); }));
} else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) {
executorch::extension::llm::GenerationConfig config{
.echo = static_cast<bool>(echo),
@@ -367,9 +381,11 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
makeNativeMethod("stop", ExecuTorchLlmJni::stop),
makeNativeMethod("load", ExecuTorchLlmJni::load),
makeNativeMethod(
"prefillImagesNative", ExecuTorchLlmJni::prefill_images),
"addImageInputNative", ExecuTorchLlmJni::add_images_input),
makeNativeMethod(
"addTextInputNative", ExecuTorchLlmJni::add_text_input),
makeNativeMethod(
"prefillPromptNative", ExecuTorchLlmJni::prefill_prompt),
"addAudioInputNative", ExecuTorchLlmJni::add_audio_input),
makeNativeMethod(
"generateFromPos", ExecuTorchLlmJni::generate_from_pos),
});
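With this change, addTextInputNative / addImageInputNative / addAudioInputNative only queue llm::MultimodalInput values in prefill_inputs_, and the queue is flushed together with the final prompt when generate or generateFromPos runs. A minimal caller-side sketch in Java, assuming the wrapper methods shown in the first file and that generateFromPos is callable with the signature documented there; llmModule, pixels, and callback are placeholders, and the bos/eos/seqLen values are arbitrary:

  long pos = 0;
  // Queue one image and one text segment; the JNI layer stores both as MultimodalInput.
  pos = llmModule.prefillImages(pixels, width, height, channels, pos);
  pos = llmModule.prefillPrompt("Describe the image.", pos, /* bos= */ 1, /* eos= */ 0);
  // Generation consumes the queued inputs plus this final prompt.
  llmModule.generateFromPos("Answer briefly.", /* seqLen= */ 768, pos, callback, /* echo= */ false);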