Skip to content

Commit 39594fa

Browse files
Add audio to multimodal runner (#13948)
This PR was created by the merge bot to help merge the original PR into the main branch. ghstack PR number: #13870 by @jackzhxng ^ Please use this as the source of truth for the PR details, comments, and reviews ghstack PR base: https://github.com/pytorch/executorch/tree/gh/jackzhxng/34/base ghstack PR head: https://github.com/pytorch/executorch/tree/gh/jackzhxng/34/head Merge bot PR base: https://github.com/pytorch/executorch/tree/main Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/jackzhxng/34/orig @diff-train-skip-merge Co-authored-by: Jack Zhang <[email protected]>
1 parent c5ff74c commit 39594fa

File tree

9 files changed

+282
-31
lines changed

9 files changed

+282
-31
lines changed

examples/models/llava/export_llava.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -226,11 +226,11 @@ def export_all(llava_model: LlavaModel):
226226
{
227227
"image_encoder": image_encoder_ep,
228228
"token_embedding": token_embedding_ep,
229-
"text_model": text_model_ep,
229+
"text_decoder": text_model_ep,
230230
},
231231
partitioner={
232232
"image_encoder": [XnnpackPartitioner()],
233-
"text_model": [
233+
"text_decoder": [
234234
# First partition the DQLinear nodes, then partition the rest of the nodes,
235235
# to avoid multiple DQLinear nodes in the same partition,
236236
# to avoid holding multiple unpacked and packed weight buffers in memory,
@@ -254,7 +254,7 @@ def export_all(llava_model: LlavaModel):
254254
memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
255255
sym_shape_eval_pass={
256256
"image_encoder": ConstraintBasedSymShapeEvalPass(),
257-
"text_model": ConstraintBasedSymShapeEvalPass(),
257+
"text_decoder": ConstraintBasedSymShapeEvalPass(),
258258
"token_embedding": HintBasedSymShapeEvalPass(),
259259
},
260260
)

examples/models/llava/runner/llava_text_decoder_runner.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ class ET_EXPERIMENTAL LlavaTextDecoderRunner
8989
}
9090

9191
inline static const std::string kTokenEmbeddingMethod = "token_embedding";
92-
inline static const std::string kTextModelMethod = "text_model";
92+
inline static const std::string kTextModelMethod = "text_decoder";
9393
};
9494

9595
} // namespace example

examples/models/llava/test/test_llava.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def test_llava_export(self):
9696
"token_embedding", (prompt_before_image,)
9797
)[0]
9898
llava_module.run_method(
99-
"text_model",
99+
"text_decoder",
100100
(torch.tensor([start_pos], dtype=torch.int64), pte_embeds_before_img),
101101
)
102102

@@ -107,7 +107,7 @@ def test_llava_export(self):
107107
# pte prefill image
108108
pte_embeds_img = llava_module.run_method("image_encoder", (resized,))[0]
109109
llava_module.run_method(
110-
"text_model",
110+
"text_decoder",
111111
(
112112
torch.tensor([start_pos], dtype=torch.int64),
113113
pte_embeds_img,
@@ -122,7 +122,7 @@ def test_llava_export(self):
122122
"token_embedding", (prompt_after_image,)
123123
)[0]
124124
pte_prefill_after_img = llava_module.run_method(
125-
"text_model",
125+
"text_decoder",
126126
(torch.tensor([start_pos], dtype=torch.int64), pte_embeds_after_img),
127127
)[0]
128128

@@ -139,7 +139,7 @@ def test_llava_export(self):
139139
"token_embedding", (torch.tensor([[new_tokens[i]]], dtype=torch.int64),)
140140
)[0]
141141
logits = llava_module.run_method(
142-
"text_model",
142+
"text_decoder",
143143
(torch.tensor([start_pos + i], dtype=torch.int64), token_embeds),
144144
)[0]
145145
new_tokens.append(torch.argmax(logits).item())

examples/models/llava/test/test_pte.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def main():
4747
"token_embedding", (prompt_before_image,)
4848
)[0]
4949
pte_prefill_before_img = llava_module.run_method(
50-
"text_model",
50+
"text_decoder",
5151
(torch.tensor([start_pos], dtype=torch.int64), pte_embeds_before_img),
5252
)[0]
5353
print(pte_prefill_before_img)
@@ -60,7 +60,7 @@ def main():
6060
logging.warning("Image encoder finished")
6161
logging.warning("Image token prefill started")
6262
pte_prefill_img = llava_module.run_method(
63-
"text_model",
63+
"text_decoder",
6464
(
6565
torch.tensor([start_pos], dtype=torch.int64),
6666
pte_embeds_img,
@@ -77,7 +77,7 @@ def main():
7777
"token_embedding", (prompt_after_image,)
7878
)[0]
7979
pte_prefill_after_img = llava_module.run_method(
80-
"text_model",
80+
"text_decoder",
8181
(torch.tensor([start_pos], dtype=torch.int64), pte_embeds_after_img),
8282
)[0]
8383
logging.warning("Text token prefill finished")
@@ -91,7 +91,7 @@ def main():
9191
"token_embedding", (torch.tensor([[new_tokens[i]]], dtype=torch.int64),)
9292
)[0]
9393
logits = llava_module.run_method(
94-
"text_model",
94+
"text_decoder",
9595
(torch.tensor([start_pos + i], dtype=torch.int64), token_embeds),
9696
)[0]
9797
new_tokens.append(torch.argmax(logits[..., -1, :]).item())

extension/llm/runner/audio.h

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
// Simple audio structs: raw (unprocessed) audio and pre-processed audio inputs.
10+
11+
#pragma once
12+
#include <executorch/runtime/platform/compiler.h>
13+
#include <cstdint>
14+
#include <vector>
15+
16+
namespace executorch {
17+
namespace extension {
18+
namespace llm {
19+
20+
/**
21+
* Audio inputs as a raw audio tensor, for use when the audio processing
22+
* into a mel spectrogram is baked into the audio encoder with torch.export.
23+
*
* None of the integer fields has a default member initializer, so their
* values are indeterminate unless the caller sets them (e.g. through
* aggregate initialization).
*/
24+
struct ET_EXPERIMENTAL RawAudio {
25+
std::vector<uint8_t> data; // Audio payload as bytes. NOTE(review): sample format/layout is not stated here — confirm against the audio encoder.
26+
int32_t batch_size; // Presumably the batch dimension of the tensor — TODO confirm.
27+
int32_t n_channels; // For mono, use n_channels = 1.
28+
int32_t n_samples; // Presumably the number of samples per channel — TODO confirm.
29+
};
30+
31+
/**
32+
* Pre-processed audio inputs, ready to feed directly into an audio
33+
* encoder.
*
* None of the integer fields has a default member initializer, so their
* values are indeterminate unless the caller sets them (e.g. through
* aggregate initialization).
34+
*/
35+
struct ET_EXPERIMENTAL Audio {
36+
std::vector<uint8_t> data; // Pre-processed payload as bytes. NOTE(review): element type/layout is not stated here — confirm against the audio encoder.
37+
int32_t batch_size; // Presumably the batch dimension of the tensor — TODO confirm.
38+
int32_t n_bins; // Presumably the number of frequency (mel) bins — TODO confirm.
39+
int32_t n_frames; // Presumably the number of time frames — TODO confirm.
40+
};
41+
42+
} // namespace llm
43+
} // namespace extension
44+
} // namespace executorch
45+
46+
namespace torch {
47+
namespace executor {
48+
// TODO(T197294990): Remove these deprecated aliases once all users have moved
49+
// to the new `::executorch` namespaces.
// NOTE(review): only Audio is re-exported below; RawAudio gets no alias in
// this namespace — confirm that the omission is intentional.
50+
using ::executorch::extension::llm::Audio;
51+
} // namespace executor
52+
} // namespace torch

extension/llm/runner/constants.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@ inline constexpr auto kUseSDPAWithKVCache = "use_sdpa_with_kv_cache";
2121

2222
// Multimodal method name conventions
// NOTE(review): kTextModelMethod keeps its "TextModel" identifier although its
// value changes from "text_model" to "text_decoder" in this diff — confirm
// that all exported programs use the new method name.
2323
inline constexpr auto kImageEncoderMethod = "image_encoder";
24+
inline constexpr auto kAudioEncoderMethod = "audio_encoder";
2425
inline constexpr auto kTokenEmbeddingMethod = "token_embedding";
25-
inline constexpr auto kTextModelMethod = "text_model";
26+
inline constexpr auto kTextModelMethod = "text_decoder";
2627

2728
} // namespace executorch::extension::llm

extension/llm/runner/multimodal_input.h

Lines changed: 153 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
#pragma once
1313

14+
#include <executorch/extension/llm/runner/audio.h>
1415
#include <executorch/extension/llm/runner/image.h>
1516
#include <executorch/runtime/platform/compiler.h>
1617
#include <string>
@@ -19,19 +20,31 @@
1920
namespace executorch::extension::llm {
2021

2122
/**
22-
* A generic class to hold either image or text data for multimodal inputs.
23-
* This allows the generate() API to take a std::vector of these objects
24-
* instead of separate image and text parameters.
23+
* A generic class to hold either image, text, or audio data for multimodal
24+
* inputs. This allows the generate() API to take a std::vector of these objects
25+
* instead of separate image, text, and audio parameters.
2526
*/
2627
class ET_EXPERIMENTAL MultimodalInput {
2728
public:
28-
enum class Type { TEXT, IMAGE };
29+
/// Type of multimodal input data, as reported by get_type().
30+
enum class Type {
31+
TEXT, ///< Text string input
32+
IMAGE, ///< Processed image input
33+
AUDIO, ///< Processed audio input
34+
RAW_AUDIO, ///< Raw unprocessed audio input (straight from audio file)
35+
UNSUPPORTED ///< Fallback returned by get_type() when none of the above matches
36+
};
2937

3038
// Constructors
3139
explicit MultimodalInput(const std::string& text) : data_(text) {}
3240
explicit MultimodalInput(std::string&& text) : data_(std::move(text)) {}
3341
explicit MultimodalInput(const Image& image) : data_(image) {}
3442
explicit MultimodalInput(Image&& image) : data_(std::move(image)) {}
43+
// Audio and raw-audio constructors. The const& overloads copy the argument
// into data_; the && overloads move it into data_. All are explicit, so no
// implicit conversion from Audio/RawAudio to MultimodalInput takes place.
explicit MultimodalInput(const Audio& audio) : data_(audio) {}
44+
explicit MultimodalInput(Audio&& audio) : data_(std::move(audio)) {}
45+
explicit MultimodalInput(const RawAudio& raw_audio) : data_(raw_audio) {}
46+
explicit MultimodalInput(RawAudio&& raw_audio)
47+
: data_(std::move(raw_audio)) {}
3548

3649
// Copy constructor and assignment
3750
MultimodalInput(const MultimodalInput& other) = default;
@@ -60,12 +73,37 @@ class ET_EXPERIMENTAL MultimodalInput {
6073
return std::holds_alternative<Image>(data_);
6174
}
6275

76+
/**
77+
* Check if this input contains (pre-processed) audio data.
* Only the Audio alternative counts; an input holding RawAudio returns false.
78+
* @return true if this input contains audio, false otherwise.
79+
*/
80+
bool is_audio() const noexcept {
81+
return std::holds_alternative<Audio>(data_);
82+
}
83+
84+
/**
85+
* Check if this input contains raw audio data.
* Only the RawAudio alternative counts; an input holding Audio returns false.
86+
* @return true if this input contains raw audio, false otherwise.
87+
*/
88+
bool is_raw_audio() const noexcept {
89+
return std::holds_alternative<RawAudio>(data_);
90+
}
91+
6392
/**
6493
* Get the type of data stored in this input.
65-
* @return Type::TEXT if text data, Type::IMAGE if image data.
94+
* @return Type::TEXT if text data, Type::IMAGE if image data, Type::AUDIO if
95+
* audio data, Type::RAW_AUDIO if raw audio data, Type::UNSUPPORTED otherwise.
6696
*/
6797
Type get_type() const noexcept {
68-
return is_text() ? Type::TEXT : Type::IMAGE;
98+
if (is_text())
99+
return Type::TEXT;
100+
if (is_image())
101+
return Type::IMAGE;
102+
if (is_audio())
103+
return Type::AUDIO;
104+
if (is_raw_audio())
105+
return Type::RAW_AUDIO;
106+
return Type::UNSUPPORTED; // Reached only when none of the four checks above matched.
69107
}
70108

71109
/**
@@ -122,6 +160,60 @@ class ET_EXPERIMENTAL MultimodalInput {
122160
return std::get<Image>(std::move(data_));
123161
}
124162

163+
/**
164+
* Get the audio data from this input (read-only; selected for const lvalues).
165+
* @return Reference to the stored Audio object.
166+
* @throws std::bad_variant_access if this input doesn't contain audio.
167+
*/
168+
const Audio& get_audio() const& {
169+
return std::get<Audio>(data_);
170+
}
171+
172+
/**
173+
* Get the audio data from this input (mutable version, selected for
* non-const lvalues). Changes made through the reference modify the stored
* Audio in place.
174+
* @return Mutable reference to the stored Audio object.
175+
* @throws std::bad_variant_access if this input doesn't contain audio.
176+
*/
177+
Audio& get_audio() & {
178+
return std::get<Audio>(data_);
179+
}
180+
181+
/**
182+
* Get the audio data from this input (rvalue version, selected when the
* MultimodalInput itself is an rvalue). The call only yields an rvalue
* reference; the stored Audio is moved from only if the caller consumes it.
183+
* @return Rvalue reference to the stored Audio object for efficient moves.
184+
* @throws std::bad_variant_access if this input doesn't contain audio.
185+
*/
186+
Audio&& get_audio() && {
187+
return std::get<Audio>(std::move(data_));
188+
}
189+
190+
/**
191+
* Get the raw audio data from this input (read-only; selected for const
* lvalues).
192+
* @return Reference to the stored RawAudio object.
193+
* @throws std::bad_variant_access if this input doesn't contain raw audio.
194+
*/
195+
const RawAudio& get_raw_audio() const& {
196+
return std::get<RawAudio>(data_);
197+
}
198+
199+
/**
200+
* Get the raw audio data from this input (mutable version, selected for
* non-const lvalues). Changes made through the reference modify the stored
* RawAudio in place.
201+
* @return Mutable reference to the stored RawAudio object.
202+
* @throws std::bad_variant_access if this input doesn't contain raw audio.
203+
*/
204+
RawAudio& get_raw_audio() & {
205+
return std::get<RawAudio>(data_);
206+
}
207+
208+
/**
209+
* Get the raw audio data from this input (rvalue version, selected when the
* MultimodalInput itself is an rvalue). The call only yields an rvalue
* reference; the stored RawAudio is moved from only if the caller consumes it.
210+
* @return Rvalue reference to the stored RawAudio object for efficient moves.
211+
* @throws std::bad_variant_access if this input doesn't contain raw audio.
212+
*/
213+
RawAudio&& get_raw_audio() && {
214+
return std::get<RawAudio>(std::move(data_));
215+
}
216+
125217
/**
126218
* Try to get the text data from this input safely.
127219
* @return Pointer to the text string if this input contains text, nullptr
@@ -158,8 +250,44 @@ class ET_EXPERIMENTAL MultimodalInput {
158250
return std::get_if<Image>(&data_);
159251
}
160252

253+
/**
254+
* Try to get the audio data from this input safely.
* Non-throwing alternative to get_audio(). The returned pointer refers to
* the value stored inside this object.
255+
* @return Pointer to the Audio object if this input contains audio,
256+
* nullptr otherwise.
257+
*/
258+
const Audio* try_get_audio() const noexcept {
259+
return std::get_if<Audio>(&data_);
260+
}
261+
262+
/**
263+
* Try to get the audio data from this input safely (mutable version).
* Non-throwing alternative to get_audio(). The returned pointer refers to
* the value stored inside this object, so writes through it modify this input.
264+
* @return Pointer to the Audio object if this input contains audio,
265+
* nullptr otherwise.
266+
*/
267+
Audio* try_get_audio() noexcept {
268+
return std::get_if<Audio>(&data_);
269+
}
270+
271+
/**
272+
* Try to get the raw audio data from this input safely.
* Non-throwing alternative to get_raw_audio(). The returned pointer refers
* to the value stored inside this object.
273+
* @return Pointer to the RawAudio object if this input contains raw audio,
274+
* nullptr otherwise.
275+
*/
276+
const RawAudio* try_get_raw_audio() const noexcept {
277+
return std::get_if<RawAudio>(&data_);
278+
}
279+
280+
/**
281+
* Try to get the raw audio data from this input safely (mutable version).
* Non-throwing alternative to get_raw_audio(). The returned pointer refers
* to the value stored inside this object, so writes through it modify this
* input.
282+
* @return Pointer to the RawAudio object if this input contains raw audio,
283+
* nullptr otherwise.
284+
*/
285+
RawAudio* try_get_raw_audio() noexcept {
286+
return std::get_if<RawAudio>(&data_);
287+
}
288+
161289
private:
162-
std::variant<std::string, Image> data_;
290+
std::variant<std::string, Image, Audio, RawAudio> data_; // The stored input; its active alternative is what is_*() and get_type() inspect.
163291
};
164292

165293
// Convenience factory functions
@@ -179,4 +307,21 @@ inline MultimodalInput make_image_input(Image&& image) noexcept {
179307
return MultimodalInput(std::move(image));
180308
}
181309

182-
} // namespace executorch::extension::llm
310+
// Build a MultimodalInput holding a copy of the given Audio.
// NOTE(review): declared noexcept, yet copying an Audio copies its
// std::vector payload, which can allocate. With exceptions enabled an
// allocation failure here would call std::terminate — confirm this is
// acceptable (make_image_input follows the same pattern).
inline MultimodalInput make_audio_input(const Audio& audio) noexcept {
311+
return MultimodalInput(audio);
312+
}
313+
314+
// Build a MultimodalInput by moving the given Audio into it; the argument is
// left in a moved-from state.
inline MultimodalInput make_audio_input(Audio&& audio) noexcept {
315+
return MultimodalInput(std::move(audio));
316+
}
317+
318+
// Build a MultimodalInput holding a copy of the given RawAudio.
// NOTE(review): declared noexcept, yet copying a RawAudio copies its
// std::vector payload, which can allocate. With exceptions enabled an
// allocation failure here would call std::terminate — confirm this is
// acceptable (make_image_input follows the same pattern).
inline MultimodalInput make_raw_audio_input(
319+
const RawAudio& raw_audio) noexcept {
320+
return MultimodalInput(raw_audio);
321+
}
322+
323+
// Build a MultimodalInput by moving the given RawAudio into it; the argument
// is left in a moved-from state.
inline MultimodalInput make_raw_audio_input(RawAudio&& raw_audio) noexcept {
324+
return MultimodalInput(std::move(raw_audio));
325+
}
326+
327+
} // namespace executorch::extension::llm

0 commit comments

Comments
 (0)