11
11
12
12
#pragma once
13
13
14
+ #include < executorch/extension/llm/runner/audio.h>
14
15
#include < executorch/extension/llm/runner/image.h>
15
16
#include < executorch/runtime/platform/compiler.h>
16
17
#include < string>
19
20
namespace executorch ::extension::llm {
20
21
21
22
/* *
22
- * A generic class to hold either image or text data for multimodal inputs.
23
- * This allows the generate() API to take a std::vector of these objects
24
- * instead of separate image and text parameters.
23
+ * A generic class to hold either image, text, or audio data for multimodal
24
+ * inputs. This allows the generate() API to take a std::vector of these objects
25
+ * instead of separate image, text, and audio parameters.
25
26
*/
26
27
class ET_EXPERIMENTAL MultimodalInput {
27
28
public:
28
- enum class Type { TEXT, IMAGE };
29
+ // / Type of multimodal input data
30
+ enum class Type {
31
+ TEXT, // /< Text string input
32
+ IMAGE, // /< Processed image input
33
+ AUDIO, // /< Processed audio input
34
+ RAW_AUDIO, // /< Raw unprocessed audio input (straight from audio file)
35
+ UNSUPPORTED // /< Unsupported input type
36
+ };
29
37
30
38
// Constructors
31
39
explicit MultimodalInput (const std::string& text) : data_(text) {}
32
40
explicit MultimodalInput (std::string&& text) : data_(std::move(text)) {}
33
41
explicit MultimodalInput (const Image& image) : data_(image) {}
34
42
explicit MultimodalInput (Image&& image) : data_(std::move(image)) {}
43
+ explicit MultimodalInput (const Audio& audio) : data_(audio) {}
44
+ explicit MultimodalInput (Audio&& audio) : data_(std::move(audio)) {}
45
+ explicit MultimodalInput (const RawAudio& raw_audio) : data_(raw_audio) {}
46
+ explicit MultimodalInput (RawAudio&& raw_audio)
47
+ : data_(std::move(raw_audio)) {}
35
48
36
49
// Copy constructor and assignment
37
50
MultimodalInput (const MultimodalInput& other) = default ;
@@ -60,12 +73,37 @@ class ET_EXPERIMENTAL MultimodalInput {
60
73
return std::holds_alternative<Image>(data_);
61
74
}
62
75
76
+ /* *
77
+ * Check if this input contains audio data.
78
+ * @return true if this input contains audio, false otherwise.
79
+ */
80
+ bool is_audio () const noexcept {
81
+ return std::holds_alternative<Audio>(data_);
82
+ }
83
+
84
+ /* *
85
+ * Check if this input contains raw audio data.
86
+ * @return true if this input contains raw audio, false otherwise.
87
+ */
88
+ bool is_raw_audio () const noexcept {
89
+ return std::holds_alternative<RawAudio>(data_);
90
+ }
91
+
63
92
/* *
64
93
* Get the type of data stored in this input.
65
- * @return Type::TEXT if text data, Type::IMAGE if image data.
94
+ * @return Type::TEXT if text data, Type::IMAGE if image data, Type::AUDIO if
95
+ * audio data, Type::RAW_AUDIO if raw audio data.
66
96
*/
67
97
Type get_type () const noexcept {
68
- return is_text () ? Type::TEXT : Type::IMAGE;
98
+ if (is_text ())
99
+ return Type::TEXT;
100
+ if (is_image ())
101
+ return Type::IMAGE;
102
+ if (is_audio ())
103
+ return Type::AUDIO;
104
+ if (is_raw_audio ())
105
+ return Type::RAW_AUDIO;
106
+ return Type::UNSUPPORTED;
69
107
}
70
108
71
109
/* *
@@ -122,6 +160,60 @@ class ET_EXPERIMENTAL MultimodalInput {
122
160
return std::get<Image>(std::move (data_));
123
161
}
124
162
163
+ /* *
164
+ * Get the audio data from this input.
165
+ * @return Reference to the stored Audio object.
166
+ * @throws std::bad_variant_access if this input doesn't contain audio.
167
+ */
168
+ const Audio& get_audio () const & {
169
+ return std::get<Audio>(data_);
170
+ }
171
+
172
+ /* *
173
+ * Get the audio data from this input (mutable version).
174
+ * @return Mutable reference to the stored Audio object.
175
+ * @throws std::bad_variant_access if this input doesn't contain audio.
176
+ */
177
+ Audio& get_audio () & {
178
+ return std::get<Audio>(data_);
179
+ }
180
+
181
+ /* *
182
+ * Get the audio data from this input (rvalue version).
183
+ * @return Rvalue reference to the stored Audio object for efficient moves.
184
+ * @throws std::bad_variant_access if this input doesn't contain audio.
185
+ */
186
+ Audio&& get_audio() && {
187
+ return std::get<Audio>(std::move (data_));
188
+ }
189
+
190
+ /* *
191
+ * Get the raw audio data from this input.
192
+ * @return Reference to the stored RawAudio object.
193
+ * @throws std::bad_variant_access if this input doesn't contain raw audio.
194
+ */
195
+ const RawAudio& get_raw_audio () const & {
196
+ return std::get<RawAudio>(data_);
197
+ }
198
+
199
+ /* *
200
+ * Get the raw audio data from this input (mutable version).
201
+ * @return Mutable reference to the stored RawAudio object.
202
+ * @throws std::bad_variant_access if this input doesn't contain raw audio.
203
+ */
204
+ RawAudio& get_raw_audio () & {
205
+ return std::get<RawAudio>(data_);
206
+ }
207
+
208
+ /* *
209
+ * Get the raw audio data from this input (rvalue version).
210
+ * @return Rvalue reference to the stored RawAudio object for efficient moves.
211
+ * @throws std::bad_variant_access if this input doesn't contain raw audio.
212
+ */
213
+ RawAudio&& get_raw_audio() && {
214
+ return std::get<RawAudio>(std::move (data_));
215
+ }
216
+
125
217
/* *
126
218
* Try to get the text data from this input safely.
127
219
* @return Pointer to the text string if this input contains text, nullptr
@@ -158,8 +250,44 @@ class ET_EXPERIMENTAL MultimodalInput {
158
250
return std::get_if<Image>(&data_);
159
251
}
160
252
253
+ /* *
254
+ * Try to get the audio data from this input safely.
255
+ * @return Pointer to the Audio object if this input contains audio,
256
+ * nullptr otherwise.
257
+ */
258
+ const Audio* try_get_audio () const noexcept {
259
+ return std::get_if<Audio>(&data_);
260
+ }
261
+
262
+ /* *
263
+ * Try to get the audio data from this input safely (mutable version).
264
+ * @return Pointer to the Audio object if this input contains audio,
265
+ * nullptr otherwise.
266
+ */
267
+ Audio* try_get_audio () noexcept {
268
+ return std::get_if<Audio>(&data_);
269
+ }
270
+
271
+ /* *
272
+ * Try to get the raw audio data from this input safely.
273
+ * @return Pointer to the RawAudio object if this input contains raw audio,
274
+ * nullptr otherwise.
275
+ */
276
+ const RawAudio* try_get_raw_audio () const noexcept {
277
+ return std::get_if<RawAudio>(&data_);
278
+ }
279
+
280
+ /* *
281
+ * Try to get the raw audio data from this input safely (mutable version).
282
+ * @return Pointer to the RawAudio object if this input contains raw audio,
283
+ * nullptr otherwise.
284
+ */
285
+ RawAudio* try_get_raw_audio () noexcept {
286
+ return std::get_if<RawAudio>(&data_);
287
+ }
288
+
161
289
private:
162
- std::variant<std::string, Image> data_;
290
+ std::variant<std::string, Image, Audio, RawAudio > data_;
163
291
};
164
292
165
293
// Convenience factory functions
@@ -179,4 +307,21 @@ inline MultimodalInput make_image_input(Image&& image) noexcept {
179
307
return MultimodalInput (std::move (image));
180
308
}
181
309
182
- } // namespace executorch::extension::llm
310
+ inline MultimodalInput make_audio_input (const Audio& audio) noexcept {
311
+ return MultimodalInput (audio);
312
+ }
313
+
314
+ inline MultimodalInput make_audio_input (Audio&& audio) noexcept {
315
+ return MultimodalInput (std::move (audio));
316
+ }
317
+
318
+ inline MultimodalInput make_raw_audio_input (
319
+ const RawAudio& raw_audio) noexcept {
320
+ return MultimodalInput (raw_audio);
321
+ }
322
+
323
+ inline MultimodalInput make_raw_audio_input (RawAudio&& raw_audio) noexcept {
324
+ return MultimodalInput (std::move (raw_audio));
325
+ }
326
+
327
+ } // namespace executorch::extension::llm
0 commit comments