-
Notifications
You must be signed in to change notification settings - Fork 698
Add audio to multimodal runner #13662
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
127ff2e
9520ed2
9d82591
4640124
b9feedf
7207d1d
be6eb00
fb87bbf
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -11,6 +11,7 @@ | |
|
|
||
| #pragma once | ||
|
|
||
| #include <executorch/extension/llm/runner/audio.h> | ||
jackzhxng marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| #include <executorch/extension/llm/runner/image.h> | ||
| #include <executorch/runtime/platform/compiler.h> | ||
| #include <string> | ||
|
|
@@ -19,19 +20,24 @@ | |
| namespace executorch::extension::llm { | ||
|
|
||
| /** | ||
| * A generic class to hold either image or text data for multimodal inputs. | ||
| * This allows the generate() API to take a std::vector of these objects | ||
| * instead of separate image and text parameters. | ||
| * A generic class to hold either image, text, or audio data for multimodal | ||
| * inputs. This allows the generate() API to take a std::vector of these objects | ||
| * instead of separate image, text, and audio parameters. | ||
| */ | ||
| class ET_EXPERIMENTAL MultimodalInput { | ||
| public: | ||
| enum class Type { TEXT, IMAGE }; | ||
| enum class Type { TEXT, IMAGE, AUDIO, RAW_AUDIO }; | ||
jackzhxng marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
jackzhxng marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| // Constructors | ||
| explicit MultimodalInput(const std::string& text) : data_(text) {} | ||
| explicit MultimodalInput(std::string&& text) : data_(std::move(text)) {} | ||
| explicit MultimodalInput(const Image& image) : data_(image) {} | ||
| explicit MultimodalInput(Image&& image) : data_(std::move(image)) {} | ||
| explicit MultimodalInput(const Audio& audio) : data_(audio) {} | ||
| explicit MultimodalInput(Audio&& audio) : data_(std::move(audio)) {} | ||
| explicit MultimodalInput(const RawAudio& raw_audio) : data_(raw_audio) {} | ||
| explicit MultimodalInput(RawAudio&& raw_audio) | ||
| : data_(std::move(raw_audio)) {} | ||
|
|
||
| // Copy constructor and assignment | ||
| MultimodalInput(const MultimodalInput& other) = default; | ||
|
|
@@ -60,12 +66,35 @@ class ET_EXPERIMENTAL MultimodalInput { | |
| return std::holds_alternative<Image>(data_); | ||
| } | ||
|
|
||
| /** | ||
| * Check if this input contains audio data. | ||
| * @return true if this input contains audio, false otherwise. | ||
| */ | ||
| bool is_audio() const noexcept { | ||
| return std::holds_alternative<Audio>(data_); | ||
| } | ||
|
|
||
| /** | ||
| * Check if this input contains raw audio data. | ||
| * @return true if this input contains raw audio, false otherwise. | ||
| */ | ||
| bool is_raw_audio() const noexcept { | ||
| return std::holds_alternative<RawAudio>(data_); | ||
| } | ||
|
|
||
| /** | ||
| * Get the type of data stored in this input. | ||
| * @return Type::TEXT if text data, Type::IMAGE if image data. | ||
| * @return Type::TEXT if text data, Type::IMAGE if image data, Type::AUDIO if | ||
| * audio data, Type::RAW_AUDIO if raw audio data. | ||
| */ | ||
| Type get_type() const noexcept { | ||
| return is_text() ? Type::TEXT : Type::IMAGE; | ||
| if (is_text()) | ||
| return Type::TEXT; | ||
| if (is_image()) | ||
| return Type::IMAGE; | ||
| if (is_audio()) | ||
| return Type::AUDIO; | ||
| return Type::RAW_AUDIO; | ||
jackzhxng marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| } | ||
|
|
||
| /** | ||
|
|
@@ -122,6 +151,60 @@ class ET_EXPERIMENTAL MultimodalInput { | |
| return std::get<Image>(std::move(data_)); | ||
| } | ||
|
|
||
| /** | ||
| * Get the audio data from this input. | ||
| * @return Reference to the stored Audio object. | ||
| * @throws std::bad_variant_access if this input doesn't contain audio. | ||
| */ | ||
| const Audio& get_audio() const& { | ||
| return std::get<Audio>(data_); | ||
| } | ||
|
|
||
| /** | ||
| * Get the audio data from this input (mutable version). | ||
| * @return Mutable reference to the stored Audio object. | ||
| * @throws std::bad_variant_access if this input doesn't contain audio. | ||
| */ | ||
| Audio& get_audio() & { | ||
| return std::get<Audio>(data_); | ||
| } | ||
|
Comment on lines
+177
to
+179
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is this needed? like do we ever return mutable Audio? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah I was thinking the same too, this is following already established pattern, I'm thinking of getting rid of all of these get_ variants later |
||
|
|
||
| /** | ||
| * Get the audio data from this input (rvalue version). | ||
| * @return Rvalue reference to the stored Audio object for efficient moves. | ||
| * @throws std::bad_variant_access if this input doesn't contain audio. | ||
| */ | ||
| Audio&& get_audio() && { | ||
| return std::get<Audio>(std::move(data_)); | ||
| } | ||
|
|
||
| /** | ||
| * Get the raw audio data from this input. | ||
| * @return Reference to the stored RawAudio object. | ||
| * @throws std::bad_variant_access if this input doesn't contain raw audio. | ||
| */ | ||
| const RawAudio& get_raw_audio() const& { | ||
| return std::get<RawAudio>(data_); | ||
| } | ||
|
|
||
| /** | ||
| * Get the raw audio data from this input (mutable version). | ||
| * @return Mutable reference to the stored RawAudio object. | ||
| * @throws std::bad_variant_access if this input doesn't contain raw audio. | ||
| */ | ||
| RawAudio& get_raw_audio() & { | ||
| return std::get<RawAudio>(data_); | ||
| } | ||
|
|
||
| /** | ||
| * Get the raw audio data from this input (rvalue version). | ||
| * @return Rvalue reference to the stored RawAudio object for efficient moves. | ||
| * @throws std::bad_variant_access if this input doesn't contain raw audio. | ||
| */ | ||
| RawAudio&& get_raw_audio() && { | ||
| return std::get<RawAudio>(std::move(data_)); | ||
| } | ||
|
|
||
| /** | ||
| * Try to get the text data from this input safely. | ||
| * @return Pointer to the text string if this input contains text, nullptr | ||
|
|
@@ -158,8 +241,44 @@ class ET_EXPERIMENTAL MultimodalInput { | |
| return std::get_if<Image>(&data_); | ||
| } | ||
|
|
||
| /** | ||
| * Try to get the audio data from this input safely. | ||
| * @return Pointer to the Audio object if this input contains audio, | ||
| * nullptr otherwise. | ||
| */ | ||
| const Audio* try_get_audio() const noexcept { | ||
| return std::get_if<Audio>(&data_); | ||
| } | ||
|
|
||
| /** | ||
| * Try to get the audio data from this input safely (mutable version). | ||
| * @return Pointer to the Audio object if this input contains audio, | ||
| * nullptr otherwise. | ||
| */ | ||
| Audio* try_get_audio() noexcept { | ||
| return std::get_if<Audio>(&data_); | ||
| } | ||
|
|
||
| /** | ||
| * Try to get the raw audio data from this input safely. | ||
| * @return Pointer to the RawAudio object if this input contains raw audio, | ||
| * nullptr otherwise. | ||
| */ | ||
| const RawAudio* try_get_raw_audio() const noexcept { | ||
| return std::get_if<RawAudio>(&data_); | ||
| } | ||
|
|
||
| /** | ||
| * Try to get the raw audio data from this input safely (mutable version). | ||
| * @return Pointer to the RawAudio object if this input contains raw audio, | ||
| * nullptr otherwise. | ||
| */ | ||
| RawAudio* try_get_raw_audio() noexcept { | ||
| return std::get_if<RawAudio>(&data_); | ||
| } | ||
|
|
||
| private: | ||
| std::variant<std::string, Image> data_; | ||
| std::variant<std::string, Image, Audio, RawAudio> data_; | ||
| }; | ||
|
|
||
| // Convenience factory functions | ||
|
|
@@ -179,4 +298,21 @@ inline MultimodalInput make_image_input(Image&& image) noexcept { | |
| return MultimodalInput(std::move(image)); | ||
| } | ||
|
|
||
| } // namespace executorch::extension::llm | ||
| inline MultimodalInput make_audio_input(const Audio& audio) noexcept { | ||
| return MultimodalInput(audio); | ||
| } | ||
|
|
||
| inline MultimodalInput make_audio_input(Audio&& audio) noexcept { | ||
| return MultimodalInput(std::move(audio)); | ||
| } | ||
|
|
||
| inline MultimodalInput make_raw_audio_input( | ||
| const RawAudio& raw_audio) noexcept { | ||
| return MultimodalInput(raw_audio); | ||
| } | ||
|
|
||
| inline MultimodalInput make_raw_audio_input(RawAudio&& raw_audio) noexcept { | ||
| return MultimodalInput(std::move(raw_audio)); | ||
| } | ||
|
|
||
| } // namespace executorch::extension::llm | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -37,7 +37,7 @@ MultimodalPrefiller::MultimodalPrefiller( | |
| Result<uint64_t> MultimodalPrefiller::prefill( | ||
| const MultimodalInput& input, | ||
| int64_t& start_pos) { | ||
| // Check if input is image | ||
| // 1. Run encoder model. | ||
| ::executorch::runtime::EValue encoder_output; | ||
| if (input.is_image()) { | ||
| Image image = input.get_image(); | ||
|
|
@@ -51,33 +51,65 @@ Result<uint64_t> MultimodalPrefiller::prefill( | |
| ET_UNWRAP(module_->execute(kImageEncoderMethod, image_tensor)); | ||
|
|
||
| encoder_output = image_encoder_outputs[0]; | ||
| } else if (input.is_audio()) { | ||
| Audio audio = input.get_audio(); | ||
|
|
||
| // Use the original tensor shape as intended | ||
| auto audio_tensor = executorch::extension::from_blob( | ||
| audio.data.data(), | ||
| {audio.batch_size, audio.n_bins, audio.n_frames}, | ||
| ::executorch::aten::ScalarType::Float); | ||
|
|
||
jackzhxng marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| // Run audio encoder | ||
| auto audio_encoder_result = | ||
| module_->execute(kAudioEncoderMethod, audio_tensor); | ||
| if (audio_encoder_result.error() != ::executorch::runtime::Error::Ok) { | ||
| return ::executorch::runtime::Error::Internal; | ||
| } | ||
| auto audio_encoder_outputs = audio_encoder_result.get(); | ||
|
|
||
| encoder_output = audio_encoder_outputs[0]; | ||
| } else if (input.is_text()) { | ||
| // For text input, we don't need to run the image encoder. | ||
| // Instead, we run the text encoder to get the encoder output. | ||
| auto& text = input.get_text(); | ||
| std::vector<uint64_t> tokens = | ||
| ET_UNWRAP_TOKENIZER(tokenizer_->encode(text)); | ||
|
|
||
| auto text_tensor = executorch::extension::from_blob( | ||
| tokens.data(), | ||
| {1, static_cast<aten::SizesType>(tokens.size())}, | ||
| ::executorch::aten::ScalarType::Long); | ||
|
|
||
| // Run token embedding | ||
| // Run text encoder (token embeddings) | ||
| auto token_embedding_outputs = | ||
| ET_UNWRAP(module_->execute(kTokenEmbeddingMethod, text_tensor)); | ||
|
|
||
| encoder_output = token_embedding_outputs[0]; | ||
| } else { | ||
| ET_LOG(Error, "Unsupported input type"); | ||
| // For all other input types (e.g., audio), return error | ||
| // For any other input types, return error | ||
| return ::executorch::runtime::Error::NotSupported; | ||
| } | ||
|
|
||
| auto outputs_res = | ||
| ET_UNWRAP(text_decoder_runner_->decode(encoder_output, start_pos)); | ||
| // 2. Run decoder model for prefill. | ||
| // `cache_position` goes from start_pos to start_pos + encoder_output.size(1). | ||
| // e.g. if start_pos = 2 and encoder_output.size(1) = 5, | ||
| // cache_position_tensor should be [2, 3, 4, 5, 6]. | ||
| int64_t seq_len = encoder_output.toTensor().size(1); | ||
jackzhxng marked this conversation as resolved.
Show resolved
Hide resolved
Comment on lines
+93
to
+97
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Didnt vision based multimodal need exactly the same thing? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Vision also takes this path |
||
| std::vector<int64_t> cache_positions(seq_len); | ||
| for (int64_t i = 0; i < seq_len; ++i) { | ||
| cache_positions[i] = start_pos + i; | ||
| } | ||
| auto cache_position_tensor = ::executorch::extension::from_blob( | ||
| cache_positions.data(), {seq_len}, executorch::aten::ScalarType::Long); | ||
| auto prefill_result = module_->execute( | ||
| kTextModelMethod, {cache_position_tensor, encoder_output}); | ||
| if (prefill_result.error() != ::executorch::runtime::Error::Ok) { | ||
| return prefill_result.error(); | ||
| } | ||
| auto prefill_outputs = prefill_result.get(); | ||
jackzhxng marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| auto outputs_res = prefill_outputs[0].toTensor(); | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. validate if outputs_res.numel() == 0 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why? I think adding so many validations for extremely unlikely outcomes makes things too long and hard to read. I think letting this one naturally error out below and returning that error directly is good enough |
||
|
|
||
| // Update the start_pos, which is only available inside this function. | ||
| // outputs_res can have only one logits. | ||
| // Update start_pos, tracking the current cache position. | ||
jackzhxng marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| start_pos += encoder_output.toTensor().size(1); | ||
|
|
||
| return static_cast<uint64_t>( | ||
|
|
@@ -103,6 +135,11 @@ ::executorch::runtime::Error MultimodalPrefiller::load() { | |
| if (methods.find(kImageEncoderMethod) != methods.end()) { | ||
| ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method(kImageEncoderMethod)); | ||
| } | ||
|
|
||
| if (methods.find(kAudioEncoderMethod) != methods.end()) { | ||
| ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method(kAudioEncoderMethod)); | ||
| } | ||
|
|
||
| return ::executorch::runtime::Error::Ok; | ||
| } | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Make it backwards compatible...
Keep token_embedding, not token_embeddings
And keep text_model instead of decoder
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There's nothing that needs to be kept backwards compatible, this isn't used anywhere atm. I'd like to match this to Optimum