Skip to content

Commit 56f24c6

Browse files
authored
Add a default image prefiller implementation
Differential Revision: D80063769 Pull Request resolved: #13310
1 parent 80250f8 commit 56f24c6

14 files changed

+1023
-24
lines changed

extension/llm/runner/CMakeLists.txt

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,18 @@ list(TRANSFORM _extension_llm_runner__srcs PREPEND "${EXECUTORCH_ROOT}/")
3939

4040
add_library(extension_llm_runner STATIC ${_extension_llm_runner__srcs})
4141

42-
set(runner_deps executorch_core extension_module extension_tensor tokenizers)
42+
set(runner_deps executorch_core extension_module extension_tensor
43+
tokenizers::tokenizers
44+
)
45+
46+
# depend on arange_utils
47+
if(NOT TARGET kernels_util_all_deps)
48+
add_subdirectory(
49+
${EXECUTORCH_ROOT}/kernels/portable/cpu/util
50+
${CMAKE_CURRENT_BINARY_DIR}/kernels_util
51+
)
52+
endif()
53+
list(APPEND runner_deps kernels_util_all_deps)
4354

4455
target_link_libraries(extension_llm_runner PUBLIC ${runner_deps})
4556
set_target_properties(
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/extension/llm/runner/constants.h>
10+
#include <executorch/extension/llm/runner/text_decoder_runner.h>
11+
12+
namespace executorch::extension::llm {
13+
14+
class ET_EXPERIMENTAL MultimodalDecoderRunner
15+
: public executorch::extension::llm::TextDecoderRunner {
16+
public:
17+
explicit MultimodalDecoderRunner(Module* module, IOManager* io_manager)
18+
: TextDecoderRunner(module, io_manager) {}
19+
20+
/**
21+
* Step the LLM Decoder with the given tokens and start position.
22+
* @param tokens The tokens to the LLM.
23+
* @param start_pos The start position of the tokens.
24+
* @return The logits tensor.
25+
*/
26+
inline executorch::runtime::Result<executorch::aten::Tensor> step(
27+
executorch::extension::TensorPtr& tokens,
28+
int64_t start_pos) override {
29+
// run token embedding
30+
auto token_embedding_outputs =
31+
ET_UNWRAP(module_->execute(kTokenEmbeddingMethod, tokens));
32+
33+
// Return the logits tensor
34+
return decode(token_embedding_outputs[0], start_pos);
35+
}
36+
37+
/**
38+
* Decode the embeddings to logits.
39+
* @param embeddings The embeddings tensor.
40+
* @param start_pos The start position of the embeddings.
41+
* @return The logits tensor.
42+
*/
43+
inline executorch::runtime::Result<executorch::aten::Tensor> decode(
44+
const runtime::EValue& embeddings,
45+
int64_t start_pos) {
46+
auto start_pos_tensor = ::executorch::extension::from_blob(
47+
&start_pos, {1}, executorch::aten::ScalarType::Long);
48+
// run text model
49+
auto outputs_res = ET_UNWRAP(
50+
module_->execute(kTextModelMethod, {start_pos_tensor, embeddings}));
51+
52+
ET_CHECK_MSG(
53+
outputs_res.size() == 1,
54+
"More then one output returned from executing LLM.");
55+
ET_CHECK_MSG(
56+
outputs_res[0].isTensor(),
57+
"Non Tensor Output returned from executing LLM");
58+
59+
// Return the logits tensor
60+
return outputs_res[0].toTensor();
61+
}
62+
63+
/**
64+
* Load the Module for text decode purpose.
65+
* @return The error code.
66+
*/
67+
inline executorch::runtime::Error load() override {
68+
if (is_method_loaded()) {
69+
return executorch::runtime::Error::Ok;
70+
}
71+
ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method(kTokenEmbeddingMethod));
72+
ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method(kTextModelMethod));
73+
return executorch::runtime::Error::Ok;
74+
}
75+
76+
/**
77+
* Check if the required methods in the Module is loaded.
78+
* @return True if the Module is loaded, false otherwise.
79+
*/
80+
inline bool is_method_loaded() override {
81+
executorch::runtime::Result<std::unordered_set<std::string>> methods_res =
82+
module_->method_names();
83+
if (methods_res.error() != executorch::runtime::Error::Ok) {
84+
ET_CHECK_MSG(false, "Failed to get method names");
85+
}
86+
std::unordered_set<std::string> methods = methods_res.get();
87+
bool methods_exist = methods.find(kTokenEmbeddingMethod) != methods.end() &&
88+
methods.find(kTextModelMethod) != methods.end();
89+
if (!methods_exist) {
90+
for (const auto& method : methods) {
91+
ET_LOG(Error, "Method: %s", method.c_str());
92+
}
93+
ET_CHECK_MSG(
94+
methods_exist,
95+
"Missing required methods (%s, %s) in the model",
96+
kTokenEmbeddingMethod,
97+
kTextModelMethod);
98+
}
99+
bool methods_loaded = module_->is_method_loaded(kTokenEmbeddingMethod) &&
100+
module_->is_method_loaded(kTextModelMethod);
101+
return methods_loaded;
102+
}
103+
};
104+
105+
} // namespace executorch::extension::llm
Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
// @lint-ignore-every CLANGTIDY facebook-hte-Deprecated
10+
// A generic multimodal input class that can hold either image or text data.
11+
12+
#pragma once
13+
14+
#include <executorch/extension/llm/runner/image.h>
15+
#include <executorch/runtime/platform/compiler.h>
16+
#include <string>
17+
#include <variant>
18+
19+
namespace executorch {
20+
namespace extension {
21+
namespace llm {
22+
23+
/**
24+
* A generic class to hold either image or text data for multimodal inputs.
25+
* This allows the generate() API to take a std::vector of these objects
26+
* instead of separate image and text parameters.
27+
*/
28+
class ET_EXPERIMENTAL MultimodalInput {
29+
public:
30+
enum class Type { TEXT, IMAGE };
31+
32+
// Constructors
33+
explicit MultimodalInput(const std::string& text) : data_(text) {}
34+
explicit MultimodalInput(std::string&& text) : data_(std::move(text)) {}
35+
explicit MultimodalInput(const Image& image) : data_(image) {}
36+
explicit MultimodalInput(Image&& image) : data_(std::move(image)) {}
37+
38+
// Copy constructor and assignment
39+
MultimodalInput(const MultimodalInput& other) = default;
40+
MultimodalInput& operator=(const MultimodalInput& other) = default;
41+
42+
// Move constructor and assignment
43+
MultimodalInput(MultimodalInput&& other) noexcept = default;
44+
MultimodalInput& operator=(MultimodalInput&& other) noexcept = default;
45+
46+
// Destructor
47+
~MultimodalInput() = default;
48+
49+
/**
50+
* Check if this input contains text data.
51+
* @return true if this input contains text, false otherwise.
52+
*/
53+
bool is_text() const noexcept {
54+
return std::holds_alternative<std::string>(data_);
55+
}
56+
57+
/**
58+
* Check if this input contains image data.
59+
* @return true if this input contains an image, false otherwise.
60+
*/
61+
bool is_image() const noexcept {
62+
return std::holds_alternative<Image>(data_);
63+
}
64+
65+
/**
66+
* Get the type of data stored in this input.
67+
* @return Type::TEXT if text data, Type::IMAGE if image data.
68+
*/
69+
Type get_type() const noexcept {
70+
return is_text() ? Type::TEXT : Type::IMAGE;
71+
}
72+
73+
/**
74+
* Get the text data from this input.
75+
* @return Reference to the stored text string.
76+
* @throws std::bad_variant_access if this input doesn't contain text.
77+
*/
78+
const std::string& get_text() const& {
79+
return std::get<std::string>(data_);
80+
}
81+
82+
/**
83+
* Get the text data from this input (mutable version).
84+
* @return Mutable reference to the stored text string.
85+
* @throws std::bad_variant_access if this input doesn't contain text.
86+
*/
87+
std::string& get_text() & {
88+
return std::get<std::string>(data_);
89+
}
90+
91+
/**
92+
* Get the text data from this input (rvalue version).
93+
* @return Rvalue reference to the stored text string for efficient moves.
94+
* @throws std::bad_variant_access if this input doesn't contain text.
95+
*/
96+
std::string&& get_text() && {
97+
return std::get<std::string>(std::move(data_));
98+
}
99+
100+
/**
101+
* Get the image data from this input.
102+
* @return Reference to the stored Image object.
103+
* @throws std::bad_variant_access if this input doesn't contain an image.
104+
*/
105+
const Image& get_image() const& {
106+
return std::get<Image>(data_);
107+
}
108+
109+
/**
110+
* Get the image data from this input (mutable version).
111+
* @return Mutable reference to the stored Image object.
112+
* @throws std::bad_variant_access if this input doesn't contain an image.
113+
*/
114+
Image& get_image() & {
115+
return std::get<Image>(data_);
116+
}
117+
118+
/**
119+
* Get the image data from this input (rvalue version).
120+
* @return Rvalue reference to the stored Image object for efficient moves.
121+
* @throws std::bad_variant_access if this input doesn't contain an image.
122+
*/
123+
Image&& get_image() && {
124+
return std::get<Image>(std::move(data_));
125+
}
126+
127+
/**
128+
* Try to get the text data from this input safely.
129+
* @return Pointer to the text string if this input contains text, nullptr
130+
* otherwise.
131+
*/
132+
const std::string* try_get_text() const noexcept {
133+
return std::get_if<std::string>(&data_);
134+
}
135+
136+
/**
137+
* Try to get the text data from this input safely (mutable version).
138+
* @return Pointer to the text string if this input contains text, nullptr
139+
* otherwise.
140+
*/
141+
std::string* try_get_text() noexcept {
142+
return std::get_if<std::string>(&data_);
143+
}
144+
145+
/**
146+
* Try to get the image data from this input safely.
147+
* @return Pointer to the Image object if this input contains an image,
148+
* nullptr otherwise.
149+
*/
150+
const Image* try_get_image() const noexcept {
151+
return std::get_if<Image>(&data_);
152+
}
153+
154+
/**
155+
* Try to get the image data from this input safely (mutable version).
156+
* @return Pointer to the Image object if this input contains an image,
157+
* nullptr otherwise.
158+
*/
159+
Image* try_get_image() noexcept {
160+
return std::get_if<Image>(&data_);
161+
}
162+
163+
private:
164+
std::variant<std::string, Image> data_;
165+
};
166+
167+
// Convenience factory functions
168+
inline MultimodalInput make_text_input(const std::string& text) noexcept {
169+
return MultimodalInput(text);
170+
}
171+
172+
inline MultimodalInput make_text_input(std::string&& text) noexcept {
173+
return MultimodalInput(std::move(text));
174+
}
175+
176+
inline MultimodalInput make_image_input(const Image& image) noexcept {
177+
return MultimodalInput(image);
178+
}
179+
180+
inline MultimodalInput make_image_input(Image&& image) noexcept {
181+
return MultimodalInput(std::move(image));
182+
}
183+
184+
} // namespace llm
185+
} // namespace extension
186+
} // namespace executorch

0 commit comments

Comments
 (0)