diff --git a/examples/models/llava/runner/llava_image_prefiller.h b/examples/models/llava/runner/llava_image_prefiller.h
index 762a28d0d07..972db2998b8 100644
--- a/examples/models/llava/runner/llava_image_prefiller.h
+++ b/examples/models/llava/runner/llava_image_prefiller.h
@@ -15,11 +15,11 @@
 
 namespace example {
 
-class ET_EXPERIMENTAL LlavaImagePrefiller
-    : public ::executorch::extension::llm::ImagePrefiller {
+class ET_EXPERIMENTAL LlavaImagePrefiller {
  public:
   explicit LlavaImagePrefiller(::executorch::extension::Module* module)
-      : ImagePrefiller(module){};
+      : module_(module) {}
+
   /**
    * Prefill an LLM Module with the given image input.
    * @param image The image input to LLaVa.
@@ -28,7 +28,7 @@ class ET_EXPERIMENTAL LlavaImagePrefiller
    */
   inline ::executorch::runtime::Result<executorch::aten::Tensor> prefill(
       ::executorch::extension::llm::Image& image,
-      int64_t& start_pos) override {
+      int64_t& start_pos) {
     auto image_tensor = executorch::extension::from_blob(
         image.data.data(),
         {3, image.height, image.width},
@@ -59,7 +59,7 @@ class ET_EXPERIMENTAL LlavaImagePrefiller
    * Load the Module for image prefill purpose.
    * @return The error code.
    */
-  inline ::executorch::runtime::Error load() override {
+  inline ::executorch::runtime::Error load() {
     if (is_method_loaded()) {
       return ::executorch::runtime::Error::Ok;
     }
@@ -72,7 +72,7 @@ class ET_EXPERIMENTAL LlavaImagePrefiller
    * Check if the required methods in the Module is loaded.
    * @return True if the Module is loaded, false otherwise.
    */
-  inline bool is_method_loaded() override {
+  inline bool is_method_loaded() {
     ::executorch::runtime::Result<std::unordered_set<std::string>>
         methods_res = module_->method_names();
     if (methods_res.error() != ::executorch::runtime::Error::Ok) {
@@ -88,16 +88,19 @@ class ET_EXPERIMENTAL LlavaImagePrefiller
       ET_CHECK_MSG(
           methods_exist,
           "Missing required methods (%s, %s) in the model",
-          kImageEncoderMethod.c_str(),
-          kTextModelMethod.c_str());
+          kImageEncoderMethod,
+          kTextModelMethod);
     }
     bool methods_loaded = module_->is_method_loaded(kImageEncoderMethod) &&
         module_->is_method_loaded(kTextModelMethod);
     return methods_loaded;
   }
 
-  inline static const std::string kImageEncoderMethod = "image_encoder";
-  inline static const std::string kTextModelMethod = "text_model";
+  inline static constexpr auto kImageEncoderMethod = "image_encoder";
+  inline static constexpr auto kTextModelMethod = "text_model";
+
+ private:
+  ::executorch::extension::Module* module_;
 };
 
 } // namespace example
diff --git a/examples/models/llava/runner/llava_runner.h b/examples/models/llava/runner/llava_runner.h
index 29e3097c6cf..184522c2cf1 100644
--- a/examples/models/llava/runner/llava_runner.h
+++ b/examples/models/llava/runner/llava_runner.h
@@ -10,29 +10,50 @@
 // processing logic.
 #pragma once
 
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
 #include
 #include
 #include
 #include
-#include
-#include
-
-#include
 
 namespace example {
 
-class ET_EXPERIMENTAL LlavaRunner
-    : public ::executorch::extension::llm::MultimodalRunner {
+using executorch::extension::Module;
+using executorch::extension::llm::ImagePrefiller;
+using executorch::extension::llm::IOManager;
+using executorch::extension::llm::Stats;
+using executorch::extension::llm::TextDecoderRunner;
+using executorch::extension::llm::TextPrefiller;
+using executorch::extension::llm::TextTokenGenerator;
+
+class ET_EXPERIMENTAL LlavaRunner {
  public:
   explicit LlavaRunner(
       const std::string& model_path,
       const std::string& tokenizer_path,
       const float temperature = 0.8f)
-      : MultimodalRunner(model_path, tokenizer_path, temperature){};
+      : temperature_(temperature),
+        module_(std::make_unique<Module>(model_path, Module::LoadMode::File)),
+        io_manager_(std::make_unique<IOManager>()),
+        tokenizer_path_(tokenizer_path) {
+    ET_LOG(
+        Info,
+        "Creating Llava runner: model_path=%s, tokenizer_path=%s",
+        model_path.c_str(),
+        tokenizer_path.c_str());
+  }
 
-  bool is_loaded() override;
+  bool is_loaded();
 
-  ::executorch::runtime::Error load() override;
+  ::executorch::runtime::Error load();
 
   ::executorch::runtime::Error generate(
       std::vector<::executorch::extension::llm::Image> images,
@@ -41,17 +62,17 @@ class ET_EXPERIMENTAL LlavaRunner
       std::function<void(const std::string&)> token_callback = {},
       std::function<void(const ::executorch::extension::llm::Stats&)>
           stats_callback = {},
-      bool echo = true) override;
+      bool echo = true);
 
   ::executorch::runtime::Error prefill_images(
       std::vector<::executorch::extension::llm::Image>& images,
-      int64_t& start_pos) override;
+      int64_t& start_pos);
 
   ::executorch::runtime::Result<uint64_t> prefill_prompt(
       const std::string& prompt,
       int64_t& start_pos,
       int8_t bos = 0,
-      int8_t eos = 0) override;
+      int8_t eos = 0);
 
   ::executorch::runtime::Error generate_from_pos(
       const std::string& prompt,
@@ -60,9 +81,30 @@ class ET_EXPERIMENTAL LlavaRunner
       std::function<void(const std::string&)> token_callback = {},
       std::function<void(const ::executorch::extension::llm::Stats&)>
          stats_callback = {},
-      bool echo = true) override;
+      bool echo = true);
+
+  inline void stop() {
+    text_token_generator_->stop();
+  }
 
  private:
+  // metadata
+  float temperature_;
+
+  // model
+  std::unordered_set<std::string> model_methods_;
+  std::unique_ptr<Module> module_;
+  std::unique_ptr<TextDecoderRunner> text_decoder_runner_;
+  std::unique_ptr<TextPrefiller> text_prefiller_;
+  std::unique_ptr<ImagePrefiller> image_prefiller_;
+  std::unique_ptr<IOManager> io_manager_;
+  std::unique_ptr<TextTokenGenerator> text_token_generator_;
+  std::string tokenizer_path_;
+  std::unique_ptr<::tokenizers::Tokenizer> tokenizer_;
+
+  // stats
+  Stats stats_;
+
   inline static const char* kPresetPrompt =
       "A chat between a curious human and an artificial intelligence "
       "assistant. The assistant gives helpful, detailed, and polite answers "
       "to the human's questions. USER: ";
 };
 
 } // namespace example
diff --git a/examples/models/llava/runner/targets.bzl b/examples/models/llava/runner/targets.bzl
index 074c92b35e3..6a02e59c6ae 100644
--- a/examples/models/llava/runner/targets.bzl
+++ b/examples/models/llava/runner/targets.bzl
@@ -20,7 +20,7 @@ def define_common_targets():
             "//executorch/kernels/quantized:generated_lib",
             "//executorch/runtime/core/exec_aten:lib",
             "//executorch/runtime/core/exec_aten/util:tensor_util",
-            "//executorch/configurations:optimized_native_cpu_ops",
+            "//executorch/configurations:optimized_native_cpu_ops",
             "//executorch/extension/llm/custom_ops:custom_ops",
             "//pytorch/tokenizers:llama2c_tokenizer",
         ],
diff --git a/extension/android/jni/jni_layer_llama.cpp b/extension/android/jni/jni_layer_llama.cpp
index 257f7282c65..48bc62141a2 100644
--- a/extension/android/jni/jni_layer_llama.cpp
+++ b/extension/android/jni/jni_layer_llama.cpp
@@ -115,7 +115,7 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
   float temperature_ = 0.0f;
   int model_type_category_;
   std::unique_ptr<llm::IRunner> runner_;
-  std::unique_ptr<llm::MultimodalRunner> multi_modal_runner_;
+  std::unique_ptr<example::LlavaRunner> multi_modal_runner_;
 
  public:
   constexpr static auto kJavaDescriptor =
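
Note: with the MultimodalRunner base class removed, callers construct and drive LlavaRunner directly. Below is a minimal sketch of the call sequence, based only on the signatures visible in the header above; the include path, file paths, prompt, and the seq_len parameter (assumed to sit between prompt and the callbacks, as in the upstream MultimodalRunner interface) are assumptions, not part of this diff.

// Hypothetical driver for the standalone LlavaRunner (sketch, not part of this patch).
#include <executorch/examples/models/llava/runner/llava_runner.h> // path assumed from this diff

#include <cstdio>
#include <string>
#include <utility>
#include <vector>

int main() {
  // Hypothetical model/tokenizer paths; any valid exported artifacts would do.
  example::LlavaRunner runner(
      "/tmp/llava.pte", "/tmp/tokenizer.bin", /*temperature=*/0.8f);

  // load() is now a plain member function rather than an override.
  if (runner.load() != ::executorch::runtime::Error::Ok) {
    return 1;
  }

  // Populate with RGB image data before calling generate().
  std::vector<::executorch::extension::llm::Image> images;

  // generate() streams decoded tokens through the callback; stats_callback
  // and echo keep their defaults from the header above.
  auto err = runner.generate(
      std::move(images),
      "What is on the table?", // hypothetical prompt
      /*seq_len=*/1024, // assumed middle parameter, elided between hunks
      [](const std::string& token) { std::fputs(token.c_str(), stdout); });
  return err == ::executorch::runtime::Error::Ok ? 0 : 1;
}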