Skip to content

Commit 71dea2a

Browse files
committed
[multimodal] Add token support to MultimodalInput
1 parent 07d1092 commit 71dea2a

File tree

4 files changed

+357
-5
lines changed

4 files changed

+357
-5
lines changed

extension/llm/runner/multimodal_input.h

Lines changed: 88 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,10 @@
1414
#include <executorch/extension/llm/runner/audio.h>
1515
#include <executorch/extension/llm/runner/image.h>
1616
#include <executorch/runtime/platform/compiler.h>
17+
#include <cstdint>
1718
#include <string>
1819
#include <variant>
20+
#include <vector>
1921

2022
namespace executorch::extension::llm {
2123

@@ -29,15 +31,46 @@ class ET_EXPERIMENTAL MultimodalInput {
2931
/// Type of multimodal input data
3032
enum class Type {
3133
TEXT, ///< Text string input
34+
TOKENS, ///< Pre-tokenized input (vector of token IDs)
3235
IMAGE, ///< Processed image input
3336
AUDIO, ///< Processed audio input
3437
RAW_AUDIO, ///< Raw unprocessed audio input (straight from audio file)
3538
UNSUPPORTED ///< Unsupported input type
3639
};
3740

41+
/**
42+
* Return a human-readable name for a MultimodalInput::Type.
43+
* Preferred for logging and debugging; returns string literals.
44+
*/
45+
static constexpr const char* TypeName(Type t) noexcept {
46+
switch (t) {
47+
case Type::TEXT:
48+
return "text";
49+
case Type::TOKENS:
50+
return "tokens";
51+
case Type::IMAGE:
52+
return "image";
53+
case Type::AUDIO:
54+
return "audio";
55+
case Type::RAW_AUDIO:
56+
return "raw_audio";
57+
default:
58+
return "unknown";
59+
}
60+
}
61+
62+
/** Convenience wrapper that returns a std::string. */
63+
static inline std::string TypeToString(Type t) {
64+
return TypeName(t);
65+
}
66+
3867
// Constructors
3968
explicit MultimodalInput(const std::string& text) : data_(text) {}
4069
explicit MultimodalInput(std::string&& text) : data_(std::move(text)) {}
70+
explicit MultimodalInput(const std::vector<uint64_t>& tokens)
71+
: data_(tokens) {}
72+
explicit MultimodalInput(std::vector<uint64_t>&& tokens)
73+
: data_(std::move(tokens)) {}
4174
explicit MultimodalInput(const Image& image) : data_(image) {}
4275
explicit MultimodalInput(Image&& image) : data_(std::move(image)) {}
4376
explicit MultimodalInput(const Audio& audio) : data_(audio) {}
@@ -65,6 +98,13 @@ class ET_EXPERIMENTAL MultimodalInput {
6598
return std::holds_alternative<std::string>(data_);
6699
}
67100

101+
/**
102+
* Check if this input contains pre-tokenized data.
103+
*/
104+
bool is_tokens() const noexcept {
105+
return std::holds_alternative<std::vector<uint64_t>>(data_);
106+
}
107+
68108
/**
69109
* Check if this input contains image data.
70110
* @return true if this input contains an image, false otherwise.
@@ -97,6 +137,8 @@ class ET_EXPERIMENTAL MultimodalInput {
97137
Type get_type() const noexcept {
98138
if (is_text())
99139
return Type::TEXT;
140+
if (is_tokens())
141+
return Type::TOKENS;
100142
if (is_image())
101143
return Type::IMAGE;
102144
if (is_audio())
@@ -106,6 +148,15 @@ class ET_EXPERIMENTAL MultimodalInput {
106148
return Type::UNSUPPORTED;
107149
}
108150

151+
/**
152+
* Get a human-readable name for the contained input type.
153+
* Returns one of: "text", "tokens", "image", "audio", "raw_audio", or
154+
* "unknown".
155+
*/
156+
const char* type_name() const noexcept {
157+
return TypeName(get_type());
158+
}
159+
109160
/**
110161
* Get the text data from this input.
111162
* @return Reference to the stored text string.
@@ -133,6 +184,21 @@ class ET_EXPERIMENTAL MultimodalInput {
133184
return std::get<std::string>(std::move(data_));
134185
}
135186

187+
/**
188+
* Get the token vector from this input.
189+
*/
190+
const std::vector<uint64_t>& get_tokens() const& {
191+
return std::get<std::vector<uint64_t>>(data_);
192+
}
193+
194+
std::vector<uint64_t>& get_tokens() & {
195+
return std::get<std::vector<uint64_t>>(data_);
196+
}
197+
198+
std::vector<uint64_t>&& get_tokens() && {
199+
return std::get<std::vector<uint64_t>>(std::move(data_));
200+
}
201+
136202
/**
137203
* Get the image data from this input.
138204
* @return Reference to the stored Image object.
@@ -250,6 +316,16 @@ class ET_EXPERIMENTAL MultimodalInput {
250316
return std::get_if<Image>(&data_);
251317
}
252318

319+
/** Try to get the tokens from this input safely. */
320+
const std::vector<uint64_t>* try_get_tokens() const noexcept {
321+
return std::get_if<std::vector<uint64_t>>(&data_);
322+
}
323+
324+
/** Try to get the tokens from this input safely (mutable). */
325+
std::vector<uint64_t>* try_get_tokens() noexcept {
326+
return std::get_if<std::vector<uint64_t>>(&data_);
327+
}
328+
253329
/**
254330
* Try to get the audio data from this input safely.
255331
* @return Pointer to the Audio object if this input contains audio,
@@ -287,7 +363,8 @@ class ET_EXPERIMENTAL MultimodalInput {
287363
}
288364

289365
private:
290-
std::variant<std::string, Image, Audio, RawAudio> data_;
366+
std::variant<std::string, std::vector<uint64_t>, Image, Audio, RawAudio>
367+
data_;
291368
};
292369

293370
// Convenience factory functions
@@ -307,6 +384,16 @@ inline MultimodalInput make_image_input(Image&& image) noexcept {
307384
return MultimodalInput(std::move(image));
308385
}
309386

387+
inline MultimodalInput make_token_input(
388+
const std::vector<uint64_t>& tokens) noexcept {
389+
return MultimodalInput(tokens);
390+
}
391+
392+
inline MultimodalInput make_token_input(
393+
std::vector<uint64_t>&& tokens) noexcept {
394+
return MultimodalInput(std::move(tokens));
395+
}
396+
310397
inline MultimodalInput make_audio_input(const Audio& audio) noexcept {
311398
return MultimodalInput(audio);
312399
}

extension/llm/runner/multimodal_prefiller.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -110,10 +110,14 @@ Result<uint64_t> MultimodalPrefiller::prefill(
110110
auto audio_encoder_outputs = audio_encoder_result.get();
111111

112112
encoder_output = audio_encoder_outputs[0];
113-
} else if (input.is_text()) {
114-
auto& text = input.get_text();
115-
std::vector<uint64_t> tokens =
116-
ET_UNWRAP_TOKENIZER(tokenizer_->encode(text));
113+
} else if (input.is_text() || input.is_tokens()) {
114+
std::vector<uint64_t> tokens;
115+
if (input.is_text()) {
116+
auto& text = input.get_text();
117+
tokens = ET_UNWRAP_TOKENIZER(tokenizer_->encode(text));
118+
} else {
119+
tokens = input.get_tokens();
120+
}
117121

118122
auto text_tensor = executorch::extension::from_blob(
119123
tokens.data(),

extension/llm/runner/multimodal_runner.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,12 @@ Error MultimodalRunner::generate(
116116
// Process multimodal inputs in order
117117
for (size_t i = 0; i < inputs.size(); ++i) {
118118
const MultimodalInput& input = inputs[i];
119+
ET_LOG(
120+
Info,
121+
"Prefilling input %zu/%zu, type: %s",
122+
i,
123+
inputs.size(),
124+
input.type_name());
119125
if (config.echo && i == inputs.size() - 1 && input.is_text()) {
120126
wrapped_callback(input.get_text());
121127
}

0 commit comments

Comments
 (0)