Commit 48e4822

Decouple LlavaRunner from multimodal runner
Differential Revision: D78997241
Pull Request resolved: #13067
1 parent d36e83a commit 48e4822

File tree

4 files changed: +70 -25 lines changed

examples/models/llava/runner/llava_image_prefiller.h

Lines changed: 13 additions & 10 deletions
@@ -15,11 +15,11 @@

 namespace example {

-class ET_EXPERIMENTAL LlavaImagePrefiller
-    : public ::executorch::extension::llm::ImagePrefiller {
+class ET_EXPERIMENTAL LlavaImagePrefiller {
  public:
   explicit LlavaImagePrefiller(::executorch::extension::Module* module)
-      : ImagePrefiller(module){};
+      : module_(module) {}
+
   /**
    * Prefill an LLM Module with the given image input.
    * @param image The image input to LLaVa.
@@ -28,7 +28,7 @@ class ET_EXPERIMENTAL LlavaImagePrefiller
    */
   inline ::executorch::runtime::Result<executorch::aten::Tensor> prefill(
       ::executorch::extension::llm::Image& image,
-      int64_t& start_pos) override {
+      int64_t& start_pos) {
     auto image_tensor = executorch::extension::from_blob(
         image.data.data(),
         {3, image.height, image.width},
@@ -59,7 +59,7 @@ class ET_EXPERIMENTAL LlavaImagePrefiller
    * Load the Module for image prefill purpose.
    * @return The error code.
    */
-  inline ::executorch::runtime::Error load() override {
+  inline ::executorch::runtime::Error load() {
     if (is_method_loaded()) {
       return ::executorch::runtime::Error::Ok;
     }
@@ -72,7 +72,7 @@ class ET_EXPERIMENTAL LlavaImagePrefiller
    * Check if the required methods in the Module is loaded.
    * @return True if the Module is loaded, false otherwise.
    */
-  inline bool is_method_loaded() override {
+  inline bool is_method_loaded() {
     ::executorch::runtime::Result<std::unordered_set<std::string>> methods_res =
         module_->method_names();
     if (methods_res.error() != ::executorch::runtime::Error::Ok) {
@@ -88,16 +88,19 @@ class ET_EXPERIMENTAL LlavaImagePrefiller
     ET_CHECK_MSG(
         methods_exist,
         "Missing required methods (%s, %s) in the model",
-        kImageEncoderMethod.c_str(),
-        kTextModelMethod.c_str());
+        kImageEncoderMethod,
+        kTextModelMethod);
   }
   bool methods_loaded = module_->is_method_loaded(kImageEncoderMethod) &&
       module_->is_method_loaded(kTextModelMethod);
   return methods_loaded;
 }

-  inline static const std::string kImageEncoderMethod = "image_encoder";
-  inline static const std::string kTextModelMethod = "text_model";
+  inline static constexpr auto kImageEncoderMethod = "image_encoder";
+  inline static constexpr auto kTextModelMethod = "text_model";
+
+ private:
+  ::executorch::extension::Module* module_;
 };

 } // namespace example
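
With ImagePrefiller removed from its inheritance chain, LlavaImagePrefiller only needs a raw Module pointer. A minimal usage sketch under that assumption; the model path, image size, and flow below are hypothetical and not part of this commit:

// Hypothetical sketch: drive the decoupled LlavaImagePrefiller directly.
#include <executorch/examples/models/llava/runner/llava_image_prefiller.h>
#include <executorch/extension/llm/runner/image.h>
#include <executorch/extension/module/module.h>

#include <cstdint>

using ::executorch::extension::Module;
using ::executorch::extension::llm::Image;

int main() {
  // Assumed artifact name: any exported LLaVa .pte exposing the
  // image_encoder and text_model methods checked by is_method_loaded().
  Module module("llava.pte", Module::LoadMode::File);
  example::LlavaImagePrefiller prefiller(&module);

  if (prefiller.load() != ::executorch::runtime::Error::Ok) {
    return 1;
  }

  Image image;
  image.width = 336;   // placeholder resolution
  image.height = 336;
  image.data.resize(3 * image.width * image.height);  // CHW byte buffer

  int64_t start_pos = 0;  // position in the KV cache, passed by reference
  auto encoded = prefiller.prefill(image, start_pos);
  return encoded.ok() ? 0 : 1;
}
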

examples/models/llava/runner/llava_runner.h

Lines changed: 55 additions & 13 deletions
@@ -10,29 +10,50 @@
 // processing logic.
 #pragma once

+#include <executorch/examples/models/llava/runner/llava_image_prefiller.h>
+#include <executorch/extension/llm/runner/image.h>
+#include <executorch/extension/llm/runner/io_manager/io_manager.h>
+#include <executorch/extension/llm/runner/irunner.h>
+#include <executorch/extension/llm/runner/stats.h>
+#include <executorch/extension/llm/runner/text_decoder_runner.h>
+#include <executorch/extension/llm/runner/text_prefiller.h>
+#include <executorch/extension/llm/runner/text_token_generator.h>
+#include <executorch/extension/module/module.h>
 #include <cstdint>
 #include <functional>
 #include <memory>
 #include <string>
-#include <type_traits>
-#include <unordered_map>
-
-#include <executorch/extension/llm/runner/multimodal_runner.h>

 namespace example {

-class ET_EXPERIMENTAL LlavaRunner
-    : public ::executorch::extension::llm::MultimodalRunner {
+using executorch::extension::Module;
+using executorch::extension::llm::ImagePrefiller;
+using executorch::extension::llm::IOManager;
+using executorch::extension::llm::Stats;
+using executorch::extension::llm::TextDecoderRunner;
+using executorch::extension::llm::TextPrefiller;
+using executorch::extension::llm::TextTokenGenerator;
+
+class ET_EXPERIMENTAL LlavaRunner {
  public:
   explicit LlavaRunner(
       const std::string& model_path,
       const std::string& tokenizer_path,
       const float temperature = 0.8f)
-      : MultimodalRunner(model_path, tokenizer_path, temperature){};
+      : temperature_(temperature),
+        module_(std::make_unique<Module>(model_path, Module::LoadMode::File)),
+        io_manager_(std::make_unique<IOManager>()),
+        tokenizer_path_(tokenizer_path) {
+    ET_LOG(
+        Info,
+        "Creating Llava runner: model_path=%s, tokenizer_path=%s",
+        model_path.c_str(),
+        tokenizer_path.c_str());
+  }

-  bool is_loaded() override;
+  bool is_loaded();

-  ::executorch::runtime::Error load() override;
+  ::executorch::runtime::Error load();

   ::executorch::runtime::Error generate(
       std::vector<::executorch::extension::llm::Image> images,
@@ -41,17 +62,17 @@ class ET_EXPERIMENTAL LlavaRunner
       std::function<void(const std::string&)> token_callback = {},
       std::function<void(const ::executorch::extension::llm::Stats&)>
           stats_callback = {},
-      bool echo = true) override;
+      bool echo = true);

   ::executorch::runtime::Error prefill_images(
       std::vector<::executorch::extension::llm::Image>& images,
-      int64_t& start_pos) override;
+      int64_t& start_pos);

   ::executorch::runtime::Result<uint64_t> prefill_prompt(
       const std::string& prompt,
       int64_t& start_pos,
       int8_t bos = 0,
-      int8_t eos = 0) override;
+      int8_t eos = 0);

   ::executorch::runtime::Error generate_from_pos(
       const std::string& prompt,
@@ -60,9 +81,30 @@ class ET_EXPERIMENTAL LlavaRunner
       std::function<void(const std::string&)> token_callback = {},
       std::function<void(const ::executorch::extension::llm::Stats&)>
           stats_callback = {},
-      bool echo = true) override;
+      bool echo = true);
+
+  inline void stop() {
+    text_token_generator_->stop();
+  }

  private:
+  // metadata
+  float temperature_;
+
+  // model
+  std::unordered_set<std::string> model_methods_;
+  std::unique_ptr<Module> module_;
+  std::unique_ptr<TextDecoderRunner> text_decoder_runner_;
+  std::unique_ptr<TextPrefiller> text_prefiller_;
+  std::unique_ptr<LlavaImagePrefiller> image_prefiller_;
+  std::unique_ptr<IOManager> io_manager_;
+  std::unique_ptr<TextTokenGenerator> text_token_generator_;
+  std::string tokenizer_path_;
+  std::unique_ptr<::tokenizers::Tokenizer> tokenizer_;
+
+  // stats
+  Stats stats_;
+
   inline static const char* kPresetPrompt =
       "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: ";
 };
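
Now that LlavaRunner is a standalone class rather than a MultimodalRunner subclass, it is driven purely through the methods declared above. A minimal sketch of that surface, assuming placeholder model/tokenizer paths, prompt text, and image dimensions; the full chat flow (preset prompt, generate(), token callbacks) is omitted:

// Hypothetical sketch: exercise the decoupled LlavaRunner API.
#include <executorch/examples/models/llava/runner/llava_runner.h>

#include <cstdint>
#include <vector>

using ::executorch::extension::llm::Image;

int main() {
  // Placeholder paths; substitute real exported artifacts.
  example::LlavaRunner runner("llava.pte", "tokenizer.bin", /*temperature=*/0.8f);

  if (runner.load() != ::executorch::runtime::Error::Ok || !runner.is_loaded()) {
    return 1;
  }

  // Prefill the image(s), then the text prompt, sharing one position counter.
  int64_t start_pos = 0;
  std::vector<Image> images(1);
  images[0].width = 336;   // placeholder resolution
  images[0].height = 336;
  images[0].data.resize(3 * 336 * 336);

  if (runner.prefill_images(images, start_pos) != ::executorch::runtime::Error::Ok) {
    return 1;
  }

  auto next_token = runner.prefill_prompt("What is in this image? ASSISTANT:", start_pos);
  return next_token.ok() ? 0 : 1;
}
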

examples/models/llava/runner/targets.bzl

Lines changed: 1 addition & 1 deletion
@@ -20,7 +20,7 @@ def define_common_targets():
         "//executorch/kernels/quantized:generated_lib",
         "//executorch/runtime/core/exec_aten:lib",
         "//executorch/runtime/core/exec_aten/util:tensor_util",
-        "//executorch/configurations:optimized_native_cpu_ops",
+        "//executorch/configurations:optimized_native_cpu_ops",
         "//executorch/extension/llm/custom_ops:custom_ops",
         "//pytorch/tokenizers:llama2c_tokenizer",
     ],

extension/android/jni/jni_layer_llama.cpp

Lines changed: 1 addition & 1 deletion
@@ -115,7 +115,7 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
   float temperature_ = 0.0f;
   int model_type_category_;
   std::unique_ptr<llm::IRunner> runner_;
-  std::unique_ptr<llm::MultimodalRunner> multi_modal_runner_;
+  std::unique_ptr<example::LlavaRunner> multi_modal_runner_;

  public:
   constexpr static auto kJavaDescriptor =
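
Holding the concrete example::LlavaRunner instead of the llm::MultimodalRunner interface means the JNI class constructs the runner type directly. The construction site is not part of this hunk; a hypothetical stand-in for the surrounding fields might look like:

// Hypothetical stand-in for the ExecuTorchLlmJni fields shown above; not the actual JNI code.
#include <executorch/examples/models/llava/runner/llava_runner.h>
#include <executorch/extension/llm/runner/irunner.h>

#include <memory>
#include <string>

namespace llm = ::executorch::extension::llm;

struct RunnerHolder {
  std::unique_ptr<llm::IRunner> runner_;                      // text-only path
  std::unique_ptr<example::LlavaRunner> multi_modal_runner_;  // multimodal path

  void load_multimodal(
      const std::string& model_path,
      const std::string& tokenizer_path,
      float temperature) {
    multi_modal_runner_ = std::make_unique<example::LlavaRunner>(
        model_path, tokenizer_path, temperature);
  }
};
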
