|
6 | 6 | * LICENSE file in the root directory of this source tree. |
7 | 7 | */ |
8 | 8 |
|
9 | | -// Generic encoder prefiller that handles multimodal inputs (text, image and |
10 | | -// audio (to be implemented)) to prefill the KV cache of a multimodal LLM. |
11 | | -// @lint-ignore-every CLANGTIDY facebook-hte-Deprecated |
| 9 | +// Generic encoder prefiller that handles multimodal inputs (image and audio) |
| 10 | +// to prefill the KV cache of a multimodal LLM. |
12 | 11 |
|
13 | | -#include <executorch/extension/llm/runner/constants.h> |
14 | | -#include <executorch/extension/llm/runner/multimodal_prefiller.h> |
15 | | -#include <executorch/extension/llm/runner/util.h> |
16 | | -#include <executorch/extension/tensor/tensor.h> |
| 12 | +#pragma once |
17 | 13 |
|
18 | | -namespace executorch::extension::llm { |
19 | | - |
20 | | -MultimodalPrefiller::MultimodalPrefiller( |
21 | | - Module* module, |
22 | | - MultimodalDecoderRunner* decoder_runner, |
23 | | - Tokenizer* tokenizer, |
24 | | - IOManager* io_manager) |
25 | | - : module_(module), |
26 | | - text_decoder_runner_(decoder_runner), |
27 | | - tokenizer_(tokenizer), |
28 | | - io_manager_(io_manager) {} |
29 | | - |
30 | | -/** |
31 | | - * Prefill an LLM Module with the given multimodal input. |
32 | | - * @param input The multimodal input (text, image or audio) to the multimodal |
33 | | - * LLM. |
34 | | - * @param start_pos The starting position in KV cache of the input in the LLM |
35 | | - * @return logits of the prefill. |
36 | | - */ |
37 | | -Result<uint64_t> MultimodalPrefiller::prefill( |
38 | | - const MultimodalInput& input, |
39 | | - int64_t& start_pos) { |
40 | | - // 1. Run encoder model. |
41 | | - ::executorch::runtime::EValue encoder_output; |
42 | | - if (input.is_image()) { |
43 | | - Image image = input.get_image(); |
44 | | - |
45 | | - auto method_meta = ET_UNWRAP( |
46 | | - module_->method_meta(kImageEncoderMethod), |
47 | | - "Failed to get method_meta for %s", |
48 | | - kImageEncoderMethod); |
49 | | - |
50 | | - ET_CHECK_MSG( |
51 | | - method_meta.num_inputs() > 0, |
52 | | - "Image encoder should have at least 1 input"); |
53 | | - auto input_meta = ET_UNWRAP( |
54 | | - method_meta.input_tensor_meta(0), |
55 | | - "Cannot get input tensor meta at index 0"); |
56 | | - auto expected_dtype = input_meta.scalar_type(); |
57 | | - |
58 | | - if (expected_dtype == ::executorch::aten::ScalarType::Float) { |
59 | | - ET_CHECK_MSG( |
60 | | - image.is_float(), |
61 | | - "Model expects float image data, but image has uint8_t data."); |
62 | | - } else if (expected_dtype == ::executorch::aten::ScalarType::Byte) { |
63 | | - ET_CHECK_MSG( |
64 | | - image.is_uint8(), |
65 | | - "Model expects uint8_t image data, but image has float data."); |
66 | | - } else { |
67 | | - ET_LOG( |
68 | | - Error, |
69 | | - "Unsupported image encoder input dtype: %s", |
70 | | - ::executorch::runtime::toString(expected_dtype)); |
71 | | - return ::executorch::runtime::Error::NotSupported; |
72 | | - } |
73 | | - |
74 | | - // The model might expect a 4D tensor (NCHW), but toTensor() returns a 3D |
75 | | - // tensor (CHW). Add a batch dimension of 1 if needed. |
76 | | - auto expected_dims = input_meta.sizes(); |
77 | | - auto image_tensor = ET_UNWRAP( |
78 | | - image.toTensor(/*with_batch*/ expected_dims.size() == 4), |
79 | | - "Failed to convert image to tensor"); |
80 | | - |
81 | | - // Run image encoder |
82 | | - auto image_encoder_outputs = |
83 | | - ET_UNWRAP(module_->execute(kImageEncoderMethod, image_tensor)); |
84 | | - |
85 | | - encoder_output = image_encoder_outputs[0]; |
86 | | - } else if (input.is_audio()) { |
87 | | - Audio audio = input.get_audio(); |
88 | | - |
89 | | - // Use the original tensor shape as intended |
90 | | - auto audio_tensor = executorch::extension::from_blob( |
91 | | - audio.data.data(), |
92 | | - {audio.batch_size, audio.n_bins, audio.n_frames}, |
93 | | - ::executorch::aten::ScalarType::Float); |
94 | | - |
95 | | - // Run audio encoder |
96 | | - auto audio_encoder_result = |
97 | | - module_->execute(kAudioEncoderMethod, audio_tensor); |
98 | | - if (audio_encoder_result.error() != ::executorch::runtime::Error::Ok) { |
99 | | - return ::executorch::runtime::Error::Internal; |
100 | | - } |
101 | | - auto audio_encoder_outputs = audio_encoder_result.get(); |
102 | | - |
103 | | - encoder_output = audio_encoder_outputs[0]; |
104 | | - } else if (input.is_text()) { |
105 | | - auto& text = input.get_text(); |
106 | | - std::vector<uint64_t> tokens = |
107 | | - ET_UNWRAP_TOKENIZER(tokenizer_->encode(text)); |
108 | | - |
109 | | - auto text_tensor = executorch::extension::from_blob( |
110 | | - tokens.data(), |
111 | | - {1, static_cast<aten::SizesType>(tokens.size())}, |
112 | | - ::executorch::aten::ScalarType::Long); |
113 | | - |
114 | | - // Run text encoder (token embeddings) |
115 | | - auto token_embedding_outputs = |
116 | | - ET_UNWRAP(module_->execute(kTokenEmbeddingMethod, text_tensor)); |
| 14 | +#include <executorch/extension/llm/runner/multimodal_decoder_runner.h> |
| 15 | +#include <executorch/extension/llm/runner/multimodal_input.h> |
| 16 | +#include <executorch/extension/llm/runner/text_decoder_runner.h> |
| 17 | +#include <executorch/extension/llm/sampler/sampler.h> |
| 18 | +#include <executorch/extension/module/module.h> |
| 19 | +#include <executorch/runtime/platform/compiler.h> |
| 20 | +#include <pytorch/tokenizers/tokenizer.h> |
117 | 21 |
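Note: as a hedged illustration of the contract the image branch above relies on, the sketch below shows how the encoder's MethodMeta could be queried up front for both the expected pixel dtype and whether `toTensor()` needs to add a batch dimension. The helper name and signature are hypothetical, not part of this change.

```cpp
#include <executorch/extension/llm/runner/constants.h>
#include <executorch/extension/module/module.h>

using executorch::aten::ScalarType;
using executorch::extension::Module;
using executorch::extension::llm::kImageEncoderMethod;

// Hypothetical helper: report the pixel dtype the exported image_encoder
// expects, and whether its first input is 4-D (NCHW), in which case
// Image::toTensor() should be called with with_batch = true.
executorch::runtime::Result<ScalarType> expected_image_input(
    Module& module,
    bool& needs_batch_dim) {
  auto method_meta = ET_UNWRAP(module.method_meta(kImageEncoderMethod));
  if (method_meta.num_inputs() == 0) {
    return executorch::runtime::Error::InvalidArgument;
  }
  auto input_meta = ET_UNWRAP(method_meta.input_tensor_meta(0));
  needs_batch_dim = (input_meta.sizes().size() == 4);
  return input_meta.scalar_type();
}
```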
|
118 | | - encoder_output = token_embedding_outputs[0]; |
119 | | - } else { |
120 | | - ET_LOG(Error, "Unsupported input type"); |
121 | | - // For any other input types, return error |
122 | | - return ::executorch::runtime::Error::NotSupported; |
123 | | - } |
124 | | - |
125 | | - // 2. Run decoder model for prefill. |
126 | | - |
127 | | - // Get expected shape of cache position tensor, which should be the second |
128 | | - // argument |
129 | | - |
130 | | - int64_t seq_len = encoder_output.toTensor().size(1); |
131 | | - if (seq_len == 0) { |
132 | | - ET_LOG(Error, "The encoder returned an empty output."); |
133 | | - return ::executorch::runtime::Error::InvalidState; |
134 | | - } |
135 | | - std::vector<int64_t> cache_positions; |
136 | | - |
137 | | - auto cache_position_tensor = ET_UNWRAP(populate_start_pos_or_cache_position( |
138 | | - module_, start_pos, cache_positions, seq_len, kTextModelMethod)); |
139 | | - |
140 | | - auto prefill_result = module_->execute( |
141 | | - kTextModelMethod, {encoder_output, cache_position_tensor}); |
142 | | - if (prefill_result.error() != ::executorch::runtime::Error::Ok) { |
143 | | - return prefill_result.error(); |
144 | | - } |
145 | | - // Check if prefill_outputs is empty, if it is return error and log that the |
146 | | - // specified encoder returned empty results when used to prefill decoder. |
147 | | - auto prefill_outputs = prefill_result.get(); |
148 | | - if (prefill_outputs.empty()) { |
149 | | - ET_LOG( |
150 | | - Error, "Encoder returned empty results when used to prefill decoder"); |
151 | | - return ::executorch::runtime::Error::InvalidState; |
152 | | - } |
153 | | - auto outputs_res = prefill_outputs[0].toTensor(); |
154 | | - |
155 | | - // Update start_pos, tracking the current cache position. |
156 | | - start_pos += seq_len; |
157 | | - |
158 | | - return static_cast<uint64_t>( |
159 | | - text_decoder_runner_->logits_to_token(outputs_res)); |
160 | | -} |
161 | | - |
162 | | -/** |
163 | | - * Load the Module for encoder prefill purpose. |
164 | | - * @return The error code. |
165 | | - */ |
166 | | -::executorch::runtime::Error MultimodalPrefiller::load() { |
167 | | - if (is_method_loaded()) { |
168 | | - return ::executorch::runtime::Error::Ok; |
169 | | - } |
170 | | - // token_embeddings and text_model have to show up in method names. |
171 | | - ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method(kTokenEmbeddingMethod)); |
172 | | - ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method(kTextModelMethod)); |
173 | | - |
174 | | - std::unordered_set<std::string> methods = |
175 | | - ET_UNWRAP(module_->method_names(), "Failed to get method names"); |
176 | | - |
177 | | - // Load image_encoder method if exists. |
178 | | - if (methods.find(kImageEncoderMethod) != methods.end()) { |
179 | | - ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method(kImageEncoderMethod)); |
180 | | - } |
181 | | - |
182 | | - if (methods.find(kAudioEncoderMethod) != methods.end()) { |
183 | | - ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method(kAudioEncoderMethod)); |
184 | | - } |
185 | | - |
186 | | - return ::executorch::runtime::Error::Ok; |
187 | | -} |
| 22 | +namespace executorch::extension::llm { |
188 | 23 |
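Note: the cache-position tensor produced by `populate_start_pos_or_cache_position` is not shown in this diff. The sketch below is an assumed, minimal reconstruction of what a parallel prefill of `seq_len` tokens starting at `start_pos` would pass as the decoder's second argument; the helper name and out-parameter are illustrative only.

```cpp
#include <numeric>
#include <vector>

#include <executorch/extension/tensor/tensor.h>

// Illustrative sketch (assumed behavior, not the util.h implementation):
// build a cache-position tensor enumerating the KV-cache slots this prefill
// chunk occupies, i.e. [start_pos, start_pos + seq_len).
executorch::extension::TensorPtr make_cache_positions(
    std::vector<int64_t>& storage,
    int64_t start_pos,
    int64_t seq_len) {
  storage.resize(seq_len);
  std::iota(storage.begin(), storage.end(), start_pos);
  return executorch::extension::from_blob(
      storage.data(),
      {static_cast<executorch::aten::SizesType>(seq_len)},
      executorch::aten::ScalarType::Long);
}
```

The caller keeps `storage` alive until `execute()` returns, mirroring how the removed implementation keeps `cache_positions` in scope for the duration of the prefill call.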
|
189 | | -/** |
190 | | - * Check if the required methods in the Module is loaded. |
191 | | - * @return True if the Module is loaded, false otherwise. |
192 | | - */ |
193 | | -bool MultimodalPrefiller::is_method_loaded() { |
194 | | - ::executorch::runtime::Result<std::unordered_set<std::string>> methods_res = |
195 | | - module_->method_names(); |
196 | | - if (!module_->is_method_loaded(kTokenEmbeddingMethod)) { |
197 | | - return false; |
198 | | - } |
199 | | - if (!module_->is_method_loaded(kTextModelMethod)) { |
200 | | - return false; |
201 | | - } |
202 | | - if (methods_res.error() != ::executorch::runtime::Error::Ok) { |
203 | | - ET_CHECK_MSG(false, "Failed to get method names"); |
204 | | - } |
205 | | - std::unordered_set<std::string> methods = methods_res.get(); |
206 | | - if (methods.find(kImageEncoderMethod) != methods.end()) { |
207 | | - return module_->is_method_loaded(kImageEncoderMethod); |
208 | | - } |
209 | | - return true; |
210 | | -} |
| 24 | +using runtime::Error; |
| 25 | +using runtime::Result; |
| 26 | +using tokenizers::Tokenizer; |
| 27 | + |
| 28 | +// Assumes KV cache and parallel prefill are enabled. |
| 29 | +// This prefiller supports both image and audio inputs. |
| 30 | +class ET_EXPERIMENTAL MultimodalPrefiller { |
| 31 | + public: |
| 32 | + explicit MultimodalPrefiller( |
| 33 | + Module* module, |
| 34 | + MultimodalDecoderRunner* decoder_runner, |
| 35 | + Tokenizer* tokenizer, |
| 36 | + IOManager* io_manager); |
| 37 | + |
| 38 | + /** |
| 39 | + * Prefill an LLM Module with the given multimodal input. |
| 40 | + * @param input The multimodal input (image or audio) to the multimodal LLM. |
| 41 | + * @param start_pos The starting position of the input in the LLM's KV cache. |
| 42 | + * It is passed by reference and updated by this function. |
| 43 | + * @return The next token predicted by the LLM Module after prefill. |
| 44 | + */ |
| 45 | + virtual Result<uint64_t> prefill( |
| 46 | + const MultimodalInput& input, |
| 47 | + int64_t& start_pos); |
| 48 | + |
| 49 | + virtual Error load(); |
| 50 | + virtual bool is_method_loaded(); |
| 51 | + |
| 52 | + virtual ~MultimodalPrefiller() = default; |
| 53 | + |
| 54 | + protected: |
| 55 | + Module* module_; |
| 56 | + MultimodalDecoderRunner* text_decoder_runner_; |
| 57 | + Tokenizer* tokenizer_; |
| 58 | + IOManager* io_manager_; |
| 59 | +}; |
211 | 60 |
|
212 | 61 | } // namespace executorch::extension::llm |
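Note: a minimal, hypothetical caller-side sketch of the new header, assuming the owning runner has already constructed the prefiller with a Module, MultimodalDecoderRunner, Tokenizer, and IOManager, and has collected the ordered multimodal inputs elsewhere.

```cpp
#include <cstdint>
#include <vector>

#include <executorch/extension/llm/runner/multimodal_prefiller.h>

using executorch::extension::llm::MultimodalInput;
using executorch::extension::llm::MultimodalPrefiller;

// Hypothetical driver (not part of this diff): prefill a mixed sequence of
// inputs, then hand the returned token to the autoregressive decode loop.
executorch::runtime::Error prefill_all(
    MultimodalPrefiller& prefiller,
    const std::vector<MultimodalInput>& inputs,
    int64_t& start_pos,
    uint64_t& next_token) {
  ET_CHECK_OK_OR_RETURN_ERROR(prefiller.load());
  for (const MultimodalInput& input : inputs) {
    // prefill() advances start_pos by the encoded sequence length.
    auto result = prefiller.prefill(input, start_pos);
    if (!result.ok()) {
      return result.error();
    }
    next_token = result.get();
  }
  return executorch::runtime::Error::Ok;
}
```

After the loop, `next_token` seeds autoregressive decoding and `start_pos` points at the first free KV-cache slot.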