Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 13 additions & 10 deletions examples/models/llava/runner/llava_image_prefiller.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@

namespace example {

class ET_EXPERIMENTAL LlavaImagePrefiller
: public ::executorch::extension::llm::ImagePrefiller {
class ET_EXPERIMENTAL LlavaImagePrefiller {
public:
explicit LlavaImagePrefiller(::executorch::extension::Module* module)
: ImagePrefiller(module){};
: module_(module) {}

/**
* Prefill an LLM Module with the given image input.
* @param image The image input to LLaVa.
Expand All @@ -28,7 +28,7 @@ class ET_EXPERIMENTAL LlavaImagePrefiller
*/
inline ::executorch::runtime::Result<executorch::aten::Tensor> prefill(
::executorch::extension::llm::Image& image,
int64_t& start_pos) override {
int64_t& start_pos) {
auto image_tensor = executorch::extension::from_blob(
image.data.data(),
{3, image.height, image.width},
Expand Down Expand Up @@ -59,7 +59,7 @@ class ET_EXPERIMENTAL LlavaImagePrefiller
* Load the Module for image prefill purpose.
* @return The error code.
*/
inline ::executorch::runtime::Error load() override {
inline ::executorch::runtime::Error load() {
if (is_method_loaded()) {
return ::executorch::runtime::Error::Ok;
}
Expand All @@ -72,7 +72,7 @@ class ET_EXPERIMENTAL LlavaImagePrefiller
* Check whether the required methods in the Module are loaded.
* @return True if the Module is loaded, false otherwise.
*/
inline bool is_method_loaded() override {
inline bool is_method_loaded() {
::executorch::runtime::Result<std::unordered_set<std::string>> methods_res =
module_->method_names();
if (methods_res.error() != ::executorch::runtime::Error::Ok) {
Expand All @@ -88,16 +88,19 @@ class ET_EXPERIMENTAL LlavaImagePrefiller
ET_CHECK_MSG(
methods_exist,
"Missing required methods (%s, %s) in the model",
kImageEncoderMethod.c_str(),
kTextModelMethod.c_str());
kImageEncoderMethod,
kTextModelMethod);
}
bool methods_loaded = module_->is_method_loaded(kImageEncoderMethod) &&
module_->is_method_loaded(kTextModelMethod);
return methods_loaded;
}

inline static const std::string kImageEncoderMethod = "image_encoder";
inline static const std::string kTextModelMethod = "text_model";
inline static constexpr auto kImageEncoderMethod = "image_encoder";
inline static constexpr auto kTextModelMethod = "text_model";

private:
::executorch::extension::Module* module_;
};

} // namespace example
68 changes: 55 additions & 13 deletions examples/models/llava/runner/llava_runner.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,29 +10,50 @@
// processing logic.
#pragma once

#include <executorch/examples/models/llava/runner/llava_image_prefiller.h>
#include <executorch/extension/llm/runner/image.h>
#include <executorch/extension/llm/runner/io_manager/io_manager.h>
#include <executorch/extension/llm/runner/irunner.h>
#include <executorch/extension/llm/runner/stats.h>
#include <executorch/extension/llm/runner/text_decoder_runner.h>
#include <executorch/extension/llm/runner/text_prefiller.h>
#include <executorch/extension/llm/runner/text_token_generator.h>
#include <executorch/extension/module/module.h>
#include <cstdint>
#include <functional>
#include <memory>
#include <string>
#include <type_traits>
#include <unordered_map>

#include <executorch/extension/llm/runner/multimodal_runner.h>

namespace example {

class ET_EXPERIMENTAL LlavaRunner
: public ::executorch::extension::llm::MultimodalRunner {
using executorch::extension::Module;
using executorch::extension::llm::ImagePrefiller;
using executorch::extension::llm::IOManager;
using executorch::extension::llm::Stats;
using executorch::extension::llm::TextDecoderRunner;
using executorch::extension::llm::TextPrefiller;
using executorch::extension::llm::TextTokenGenerator;

class ET_EXPERIMENTAL LlavaRunner {
public:
explicit LlavaRunner(
const std::string& model_path,
const std::string& tokenizer_path,
const float temperature = 0.8f)
: MultimodalRunner(model_path, tokenizer_path, temperature){};
: temperature_(temperature),
module_(std::make_unique<Module>(model_path, Module::LoadMode::File)),
io_manager_(std::make_unique<IOManager>()),
tokenizer_path_(tokenizer_path) {
ET_LOG(
Info,
"Creating Llava runner: model_path=%s, tokenizer_path=%s",
model_path.c_str(),
tokenizer_path.c_str());
}

bool is_loaded() override;
bool is_loaded();

::executorch::runtime::Error load() override;
::executorch::runtime::Error load();

::executorch::runtime::Error generate(
std::vector<::executorch::extension::llm::Image> images,
Expand All @@ -41,17 +62,17 @@ class ET_EXPERIMENTAL LlavaRunner
std::function<void(const std::string&)> token_callback = {},
std::function<void(const ::executorch::extension::llm::Stats&)>
stats_callback = {},
bool echo = true) override;
bool echo = true);

::executorch::runtime::Error prefill_images(
std::vector<::executorch::extension::llm::Image>& images,
int64_t& start_pos) override;
int64_t& start_pos);

::executorch::runtime::Result<uint64_t> prefill_prompt(
const std::string& prompt,
int64_t& start_pos,
int8_t bos = 0,
int8_t eos = 0) override;
int8_t eos = 0);

::executorch::runtime::Error generate_from_pos(
const std::string& prompt,
Expand All @@ -60,9 +81,30 @@ class ET_EXPERIMENTAL LlavaRunner
std::function<void(const std::string&)> token_callback = {},
std::function<void(const ::executorch::extension::llm::Stats&)>
stats_callback = {},
bool echo = true) override;
bool echo = true);

/**
 * Forward a stop request to the text token generator.
 *
 * The constructor does not create text_token_generator_, so the pointer is
 * null until it is assigned elsewhere (presumably during load() — confirm).
 * Calling stop() in that state is a no-op instead of a null dereference.
 */
inline void stop() {
  if (text_token_generator_ != nullptr) {
    text_token_generator_->stop();
  }
}

private:
// metadata
float temperature_;

// model
std::unordered_set<std::string> model_methods_;
std::unique_ptr<Module> module_;
std::unique_ptr<TextDecoderRunner> text_decoder_runner_;
std::unique_ptr<TextPrefiller> text_prefiller_;
std::unique_ptr<LlavaImagePrefiller> image_prefiller_;
std::unique_ptr<IOManager> io_manager_;
std::unique_ptr<TextTokenGenerator> text_token_generator_;
std::string tokenizer_path_;
std::unique_ptr<::tokenizers::Tokenizer> tokenizer_;

// stats
Stats stats_;

inline static const char* kPresetPrompt =
"A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: ";
};
Expand Down
2 changes: 1 addition & 1 deletion examples/models/llava/runner/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def define_common_targets():
"//executorch/kernels/quantized:generated_lib",
"//executorch/runtime/core/exec_aten:lib",
"//executorch/runtime/core/exec_aten/util:tensor_util",
"//executorch/configurations:optimized_native_cpu_ops",
"//executorch/configurations:optimized_native_cpu_ops",
"//executorch/extension/llm/custom_ops:custom_ops",
"//pytorch/tokenizers:llama2c_tokenizer",
],
Expand Down
2 changes: 1 addition & 1 deletion extension/android/jni/jni_layer_llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
float temperature_ = 0.0f;
int model_type_category_;
std::unique_ptr<llm::IRunner> runner_;
std::unique_ptr<llm::MultimodalRunner> multi_modal_runner_;
std::unique_ptr<example::LlavaRunner> multi_modal_runner_;

public:
constexpr static auto kJavaDescriptor =
Expand Down
Loading