Skip to content

Commit ad8aad7

Browse files
larryliu0820facebook-github-bot
authored andcommitted
Decouple LlavaRunner from multimodal runner
Summary: Make sure upcoming changes to `MultimodalRunner` class doesn't break existing `LlavaRunner`. Eventually `LlavaRunner` will be deprecated and we will use `MultimodalRunner` for Llava demo and app integration. Reviewed By: jackzhxng Differential Revision: D78997241
1 parent 8651d31 commit ad8aad7

File tree

4 files changed

+65
-24
lines changed

4 files changed

+65
-24
lines changed

examples/models/llava/runner/llava_image_prefiller.h

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,10 @@
1515

1616
namespace example {
1717

18-
class ET_EXPERIMENTAL LlavaImagePrefiller
19-
: public ::executorch::extension::llm::ImagePrefiller {
18+
class ET_EXPERIMENTAL LlavaImagePrefiller {
2019
public:
2120
explicit LlavaImagePrefiller(::executorch::extension::Module* module)
22-
: ImagePrefiller(module){};
21+
: module_(module){};
2322
/**
2423
* Prefill an LLM Module with the given image input.
2524
* @param image The image input to LLaVa.
@@ -28,7 +27,7 @@ class ET_EXPERIMENTAL LlavaImagePrefiller
2827
*/
2928
inline ::executorch::runtime::Result<executorch::aten::Tensor> prefill(
3029
::executorch::extension::llm::Image& image,
31-
int64_t& start_pos) override {
30+
int64_t& start_pos) {
3231
auto image_tensor = executorch::extension::from_blob(
3332
image.data.data(),
3433
{3, image.height, image.width},
@@ -59,7 +58,7 @@ class ET_EXPERIMENTAL LlavaImagePrefiller
5958
* Load the Module for image prefill purpose.
6059
* @return The error code.
6160
*/
62-
inline ::executorch::runtime::Error load() override {
61+
inline ::executorch::runtime::Error load() {
6362
if (is_method_loaded()) {
6463
return ::executorch::runtime::Error::Ok;
6564
}
@@ -72,7 +71,7 @@ class ET_EXPERIMENTAL LlavaImagePrefiller
7271
* Check if the required methods in the Module is loaded.
7372
* @return True if the Module is loaded, false otherwise.
7473
*/
75-
inline bool is_method_loaded() override {
74+
inline bool is_method_loaded() {
7675
::executorch::runtime::Result<std::unordered_set<std::string>> methods_res =
7776
module_->method_names();
7877
if (methods_res.error() != ::executorch::runtime::Error::Ok) {
@@ -88,16 +87,20 @@ class ET_EXPERIMENTAL LlavaImagePrefiller
8887
ET_CHECK_MSG(
8988
methods_exist,
9089
"Missing required methods (%s, %s) in the model",
91-
kImageEncoderMethod.c_str(),
92-
kTextModelMethod.c_str());
90+
kImageEncoderMethod,
91+
kTextModelMethod);
9392
}
9493
bool methods_loaded = module_->is_method_loaded(kImageEncoderMethod) &&
9594
module_->is_method_loaded(kTextModelMethod);
9695
return methods_loaded;
9796
}
9897

99-
inline static const std::string kImageEncoderMethod = "image_encoder";
100-
inline static const std::string kTextModelMethod = "text_model";
98+
inline static constexpr auto kImageEncoderMethod = "image_encoder";
99+
inline static constexpr auto kTextModelMethod = "text_model";
100+
101+
private:
102+
::executorch::extension::Module* module_;
103+
101104
};
102105

103106
} // namespace example

examples/models/llava/runner/llava_runner.h

Lines changed: 50 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,29 +10,50 @@
1010
// processing logic.
1111
#pragma once
1212

13+
#include <executorch/examples/models/llava/runner/llava_image_prefiller.h>
14+
#include <executorch/extension/llm/runner/image.h>
15+
#include <executorch/extension/llm/runner/io_manager/io_manager.h>
16+
#include <executorch/extension/llm/runner/irunner.h>
17+
#include <executorch/extension/llm/runner/stats.h>
18+
#include <executorch/extension/llm/runner/text_decoder_runner.h>
19+
#include <executorch/extension/llm/runner/text_prefiller.h>
20+
#include <executorch/extension/llm/runner/text_token_generator.h>
21+
#include <executorch/extension/module/module.h>
1322
#include <cstdint>
1423
#include <functional>
1524
#include <memory>
1625
#include <string>
17-
#include <type_traits>
18-
#include <unordered_map>
1926

20-
#include <executorch/extension/llm/runner/multimodal_runner.h>
27+
using executorch::extension::Module;
28+
using executorch::extension::llm::ImagePrefiller;
29+
using executorch::extension::llm::IOManager;
30+
using executorch::extension::llm::Stats;
31+
using executorch::extension::llm::TextDecoderRunner;
32+
using executorch::extension::llm::TextPrefiller;
33+
using executorch::extension::llm::TextTokenGenerator;
2134

2235
namespace example {
2336

24-
class ET_EXPERIMENTAL LlavaRunner
25-
: public ::executorch::extension::llm::MultimodalRunner {
37+
class ET_EXPERIMENTAL LlavaRunner {
2638
public:
2739
explicit LlavaRunner(
2840
const std::string& model_path,
2941
const std::string& tokenizer_path,
3042
const float temperature = 0.8f)
31-
: MultimodalRunner(model_path, tokenizer_path, temperature){};
43+
: temperature_(temperature),
44+
module_(std::make_unique<Module>(model_path, Module::LoadMode::File)),
45+
io_manager_(std::make_unique<IOManager>()),
46+
tokenizer_path_(tokenizer_path) {
47+
ET_LOG(
48+
Info,
49+
"Creating Llava runner: model_path=%s, tokenizer_path=%s",
50+
model_path.c_str(),
51+
tokenizer_path.c_str());
52+
}
3253

33-
bool is_loaded() override;
54+
bool is_loaded();
3455

35-
::executorch::runtime::Error load() override;
56+
::executorch::runtime::Error load();
3657

3758
::executorch::runtime::Error generate(
3859
std::vector<::executorch::extension::llm::Image> images,
@@ -41,17 +62,17 @@ class ET_EXPERIMENTAL LlavaRunner
4162
std::function<void(const std::string&)> token_callback = {},
4263
std::function<void(const ::executorch::extension::llm::Stats&)>
4364
stats_callback = {},
44-
bool echo = true) override;
65+
bool echo = true);
4566

4667
::executorch::runtime::Error prefill_images(
4768
std::vector<::executorch::extension::llm::Image>& images,
48-
int64_t& start_pos) override;
69+
int64_t& start_pos);
4970

5071
::executorch::runtime::Result<uint64_t> prefill_prompt(
5172
const std::string& prompt,
5273
int64_t& start_pos,
5374
int8_t bos = 0,
54-
int8_t eos = 0) override;
75+
int8_t eos = 0);
5576

5677
::executorch::runtime::Error generate_from_pos(
5778
const std::string& prompt,
@@ -60,9 +81,26 @@ class ET_EXPERIMENTAL LlavaRunner
6081
std::function<void(const std::string&)> token_callback = {},
6182
std::function<void(const ::executorch::extension::llm::Stats&)>
6283
stats_callback = {},
63-
bool echo = true) override;
84+
bool echo = true);
6485

6586
private:
87+
// metadata
88+
float temperature_;
89+
90+
// model
91+
std::unordered_set<std::string> model_methods_;
92+
std::unique_ptr<Module> module_;
93+
std::unique_ptr<TextDecoderRunner> text_decoder_runner_;
94+
std::unique_ptr<TextPrefiller> text_prefiller_;
95+
std::unique_ptr<LlavaImagePrefiller> image_prefiller_;
96+
std::unique_ptr<IOManager> io_manager_;
97+
std::unique_ptr<TextTokenGenerator> text_token_generator_;
98+
std::string tokenizer_path_;
99+
std::unique_ptr<::tokenizers::Tokenizer> tokenizer_;
100+
101+
// stats
102+
Stats stats_;
103+
66104
inline static const char* kPresetPrompt =
67105
"A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: ";
68106
};

examples/models/llava/runner/targets.bzl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def define_common_targets():
2020
"//executorch/kernels/quantized:generated_lib",
2121
"//executorch/runtime/core/exec_aten:lib",
2222
"//executorch/runtime/core/exec_aten/util:tensor_util",
23-
"//executorch/configurations:optimized_native_cpu_ops",
23+
"//executorch/configurations:optimized_native_cpu_ops",
2424
"//executorch/extension/llm/custom_ops:custom_ops",
2525
"//pytorch/tokenizers:llama2c_tokenizer",
2626
],

extension/android/jni/jni_layer_llama.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
115115
float temperature_ = 0.0f;
116116
int model_type_category_;
117117
std::unique_ptr<llm::IRunner> runner_;
118-
std::unique_ptr<llm::MultimodalRunner> multi_modal_runner_;
118+
std::unique_ptr<example::LlavaRunner> multi_modal_runner_;
119119

120120
public:
121121
constexpr static auto kJavaDescriptor =

0 commit comments

Comments
 (0)