Skip to content

Commit 9e7a264

Browse files
authored
[multimodal] Add token support to MultimodalInput (#14451)
This pull request adds support for tokenizer-encoded input (as vectors of token IDs) to the `MultimodalInput` class, enabling more flexible and efficient handling of multimodal data. The update includes new constructors, type checks, getters, and factory functions for token inputs, as well as unit tests to ensure correct behavior and compatibility with existing code paths. **MultimodalInput class changes:** * Added a new `TOKENS` type to the `MultimodalInput::Type` enum and updated the internal `std::variant` to support storing `std::vector<uint64_t>` as token data. [[1]](diffhunk://#diff-db31b7448019ab4684675434f5b6e8054ff5d995ffa18e7adee15b5a694a7fb1R34-R73) [[2]](diffhunk://#diff-db31b7448019ab4684675434f5b6e8054ff5d995ffa18e7adee15b5a694a7fb1L290-R367) * Implemented new constructors, type checks (`is_tokens()`), getters (`get_tokens()`), and safe accessors (`try_get_tokens()`) for token inputs, along with static and instance methods for type name conversion. [[1]](diffhunk://#diff-db31b7448019ab4684675434f5b6e8054ff5d995ffa18e7adee15b5a694a7fb1R34-R73) [[2]](diffhunk://#diff-db31b7448019ab4684675434f5b6e8054ff5d995ffa18e7adee15b5a694a7fb1R101-R107) [[3]](diffhunk://#diff-db31b7448019ab4684675434f5b6e8054ff5d995ffa18e7adee15b5a694a7fb1R151-R159) [[4]](diffhunk://#diff-db31b7448019ab4684675434f5b6e8054ff5d995ffa18e7adee15b5a694a7fb1R187-R201) [[5]](diffhunk://#diff-db31b7448019ab4684675434f5b6e8054ff5d995ffa18e7adee15b5a694a7fb1R319-R328) * Added factory functions `make_token_input` for easily creating token-based inputs. **Integration and logging:** * Updated `MultimodalPrefiller::prefill` to handle both text and token inputs, bypassing tokenization when tokens are provided directly. * Added logging in `MultimodalRunner::generate` to include the type name of each input for easier debugging. **Tests:** * Introduced a comprehensive suite of unit tests covering construction, type checking, getters, copy/move semantics, and edge cases for the new token input functionality in `MultimodalInput`.
1 parent 088836c commit 9e7a264

File tree

4 files changed

+357
-5
lines changed

4 files changed

+357
-5
lines changed

extension/llm/runner/multimodal_input.h

Lines changed: 88 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,10 @@
1414
#include <executorch/extension/llm/runner/audio.h>
1515
#include <executorch/extension/llm/runner/image.h>
1616
#include <executorch/runtime/platform/compiler.h>
17+
#include <cstdint>
1718
#include <string>
1819
#include <variant>
20+
#include <vector>
1921

2022
namespace executorch::extension::llm {
2123

@@ -29,15 +31,46 @@ class ET_EXPERIMENTAL MultimodalInput {
2931
/// Type of multimodal input data
3032
enum class Type {
3133
TEXT, ///< Text string input
34+
TOKENS, ///< Pre-tokenized input (vector of token IDs)
3235
IMAGE, ///< Processed image input
3336
AUDIO, ///< Processed audio input
3437
RAW_AUDIO, ///< Raw unprocessed audio input (straight from audio file)
3538
UNSUPPORTED ///< Unsupported input type
3639
};
3740

41+
/**
42+
* Return a human-readable name for a MultimodalInput::Type.
43+
* Preferred for logging and debugging; returns string literals.
44+
*/
45+
static constexpr const char* TypeName(Type t) noexcept {
46+
switch (t) {
47+
case Type::TEXT:
48+
return "text";
49+
case Type::TOKENS:
50+
return "tokens";
51+
case Type::IMAGE:
52+
return "image";
53+
case Type::AUDIO:
54+
return "audio";
55+
case Type::RAW_AUDIO:
56+
return "raw_audio";
57+
default:
58+
return "unknown";
59+
}
60+
}
61+
62+
/** Convenience wrapper that returns a std::string. */
63+
static inline std::string TypeToString(Type t) {
64+
return TypeName(t);
65+
}
66+
3867
// Constructors
3968
explicit MultimodalInput(const std::string& text) : data_(text) {}
4069
explicit MultimodalInput(std::string&& text) : data_(std::move(text)) {}
70+
explicit MultimodalInput(const std::vector<uint64_t>& tokens)
71+
: data_(tokens) {}
72+
explicit MultimodalInput(std::vector<uint64_t>&& tokens)
73+
: data_(std::move(tokens)) {}
4174
explicit MultimodalInput(const Image& image) : data_(image) {}
4275
explicit MultimodalInput(Image&& image) : data_(std::move(image)) {}
4376
explicit MultimodalInput(const Audio& audio) : data_(audio) {}
@@ -65,6 +98,13 @@ class ET_EXPERIMENTAL MultimodalInput {
6598
return std::holds_alternative<std::string>(data_);
6699
}
67100

101+
/**
102+
* Check if this input contains pre-tokenized data.
103+
*/
104+
bool is_tokens() const noexcept {
105+
return std::holds_alternative<std::vector<uint64_t>>(data_);
106+
}
107+
68108
/**
69109
* Check if this input contains image data.
70110
* @return true if this input contains an image, false otherwise.
@@ -97,6 +137,8 @@ class ET_EXPERIMENTAL MultimodalInput {
97137
Type get_type() const noexcept {
98138
if (is_text())
99139
return Type::TEXT;
140+
if (is_tokens())
141+
return Type::TOKENS;
100142
if (is_image())
101143
return Type::IMAGE;
102144
if (is_audio())
@@ -106,6 +148,15 @@ class ET_EXPERIMENTAL MultimodalInput {
106148
return Type::UNSUPPORTED;
107149
}
108150

151+
/**
152+
* Get a human-readable name for the contained input type.
153+
* Returns one of: "text", "tokens", "image", "audio", "raw_audio", or
154+
* "unknown".
155+
*/
156+
const char* type_name() const noexcept {
157+
return TypeName(get_type());
158+
}
159+
109160
/**
110161
* Get the text data from this input.
111162
* @return Reference to the stored text string.
@@ -133,6 +184,21 @@ class ET_EXPERIMENTAL MultimodalInput {
133184
return std::get<std::string>(std::move(data_));
134185
}
135186

187+
/**
188+
* Get the token vector from this input.
189+
*/
190+
const std::vector<uint64_t>& get_tokens() const& {
191+
return std::get<std::vector<uint64_t>>(data_);
192+
}
193+
194+
std::vector<uint64_t>& get_tokens() & {
195+
return std::get<std::vector<uint64_t>>(data_);
196+
}
197+
198+
std::vector<uint64_t>&& get_tokens() && {
199+
return std::get<std::vector<uint64_t>>(std::move(data_));
200+
}
201+
136202
/**
137203
* Get the image data from this input.
138204
* @return Reference to the stored Image object.
@@ -250,6 +316,16 @@ class ET_EXPERIMENTAL MultimodalInput {
250316
return std::get_if<Image>(&data_);
251317
}
252318

319+
/** Try to get the tokens from this input safely. */
320+
const std::vector<uint64_t>* try_get_tokens() const noexcept {
321+
return std::get_if<std::vector<uint64_t>>(&data_);
322+
}
323+
324+
/** Try to get the tokens from this input safely (mutable). */
325+
std::vector<uint64_t>* try_get_tokens() noexcept {
326+
return std::get_if<std::vector<uint64_t>>(&data_);
327+
}
328+
253329
/**
254330
* Try to get the audio data from this input safely.
255331
* @return Pointer to the Audio object if this input contains audio,
@@ -287,7 +363,8 @@ class ET_EXPERIMENTAL MultimodalInput {
287363
}
288364

289365
private:
290-
std::variant<std::string, Image, Audio, RawAudio> data_;
366+
std::variant<std::string, std::vector<uint64_t>, Image, Audio, RawAudio>
367+
data_;
291368
};
292369

293370
// Convenience factory functions
@@ -307,6 +384,16 @@ inline MultimodalInput make_image_input(Image&& image) noexcept {
307384
return MultimodalInput(std::move(image));
308385
}
309386

387+
inline MultimodalInput make_token_input(
388+
const std::vector<uint64_t>& tokens) noexcept {
389+
return MultimodalInput(tokens);
390+
}
391+
392+
inline MultimodalInput make_token_input(
393+
std::vector<uint64_t>&& tokens) noexcept {
394+
return MultimodalInput(std::move(tokens));
395+
}
396+
310397
inline MultimodalInput make_audio_input(const Audio& audio) noexcept {
311398
return MultimodalInput(audio);
312399
}

extension/llm/runner/multimodal_prefiller.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -110,10 +110,14 @@ Result<uint64_t> MultimodalPrefiller::prefill(
110110
auto audio_encoder_outputs = audio_encoder_result.get();
111111

112112
encoder_output = audio_encoder_outputs[0];
113-
} else if (input.is_text()) {
114-
auto& text = input.get_text();
115-
std::vector<uint64_t> tokens =
116-
ET_UNWRAP_TOKENIZER(tokenizer_->encode(text));
113+
} else if (input.is_text() || input.is_tokens()) {
114+
std::vector<uint64_t> tokens;
115+
if (input.is_text()) {
116+
auto& text = input.get_text();
117+
tokens = ET_UNWRAP_TOKENIZER(tokenizer_->encode(text));
118+
} else {
119+
tokens = input.get_tokens();
120+
}
117121

118122
auto text_tensor = executorch::extension::from_blob(
119123
tokens.data(),

extension/llm/runner/multimodal_runner.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,12 @@ Error MultimodalRunner::generate(
116116
// Process multimodal inputs in order
117117
for (size_t i = 0; i < inputs.size(); ++i) {
118118
const MultimodalInput& input = inputs[i];
119+
ET_LOG(
120+
Info,
121+
"Prefilling input %zu/%zu, type: %s",
122+
i,
123+
inputs.size(),
124+
input.type_name());
119125
if (config.echo && i == inputs.size() - 1 && input.is_text()) {
120126
wrapped_callback(input.get_text());
121127
}

0 commit comments

Comments
 (0)