- 
                Notifications
    You must be signed in to change notification settings 
- Fork 698
Add audio to multimodal runner #13662
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
127ff2e
              9520ed2
              9d82591
              4640124
              b9feedf
              7207d1d
              be6eb00
              fb87bbf
              File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
| @@ -0,0 +1,52 @@ | ||
| /* | ||
| * Copyright (c) Meta Platforms, Inc. and affiliates. | ||
| * All rights reserved. | ||
| * | ||
| * This source code is licensed under the BSD-style license found in the | ||
| * LICENSE file in the root directory of this source tree. | ||
| */ | ||
|  | ||
| // Simple audio input structs: raw audio and mel spectrogram. | ||
|  | ||
| #pragma once | ||
| #include <executorch/runtime/platform/compiler.h> | ||
| #include <cstdint> | ||
| #include <vector> | ||
|  | ||
| namespace executorch { | ||
| namespace extension { | ||
| namespace llm { | ||
|  | ||
| /** | ||
| * Audio inputs as a raw audio tensor, for use when the audio processing | ||
| * into a mel spectrogram is baked into the audio encoder with torch.export. | ||
| * | ||
| * `data` carries the payload as raw bytes; the element type and memory layout | ||
| * are not fixed by this struct (TODO confirm against the consumer). | ||
| * | ||
| * The dimension fields default to 0 so that a default-constructed RawAudio | ||
| * never exposes indeterminate values. Aggregate initialization, e.g. | ||
| * RawAudio{bytes, 1, 1, n}, keeps working unchanged. | ||
| */ | ||
| struct ET_EXPERIMENTAL RawAudio { | ||
| std::vector<uint8_t> data; | ||
| int32_t batch_size = 0; | ||
| int32_t n_channels = 0; // For mono, use n_channels = 1. | ||
| int32_t n_samples = 0; | ||
| }; | ||
|  | ||
| /** | ||
| * Audio inputs as a mel spectrogram, ready to feed directly into an audio | ||
| * encoder. | ||
| * | ||
| * MultimodalPrefiller::prefill() views `data` as a Float tensor of shape | ||
| * {batch_size, n_bins, n_frames}, so `data` is presumably expected to hold | ||
| * batch_size * n_bins * n_frames float values stored as bytes (TODO confirm; | ||
| * no size check is visible at the call site). | ||
| * | ||
| * The dimension fields default to 0 so that a default-constructed Audio never | ||
| * exposes indeterminate values. Aggregate initialization, e.g. | ||
| * Audio{bytes, 1, n_bins, n_frames}, keeps working unchanged. | ||
| */ | ||
| struct ET_EXPERIMENTAL Audio { | ||
| std::vector<uint8_t> data; | ||
| int32_t batch_size = 0; | ||
| int32_t n_bins = 0; | ||
| int32_t n_frames = 0; | ||
| }; | ||
|  | ||
| } // namespace llm | ||
| } // namespace extension | ||
| } // namespace executorch | ||
|  | ||
| namespace torch { | ||
| namespace executor { | ||
| // TODO(T197294990): Remove these deprecated aliases once all users have moved | ||
| // to the new `::executorch` namespaces. | ||
| using ::executorch::extension::llm::Audio; | ||
| } // namespace executor | ||
| } // namespace torch | ||
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
|  | @@ -21,7 +21,8 @@ inline constexpr auto kUseSDPAWithKVCache = "use_sdpa_with_kv_cache"; | |
|  | ||
| // Multimodal method name conventions | ||
| inline constexpr auto kImageEncoderMethod = "image_encoder"; | ||
| inline constexpr auto kTokenEmbeddingMethod = "token_embedding"; | ||
| inline constexpr auto kTextModelMethod = "text_model"; | ||
| inline constexpr auto kAudioEncoderMethod = "audio_encoder"; | ||
| inline constexpr auto kTokenEmbeddingMethod = "token_embeddings"; | ||
| inline constexpr auto kTextModelMethod = "decoder"; | ||
|          | ||
|  | ||
| } // namespace executorch::extension::llm | ||
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
|  | @@ -11,6 +11,7 @@ | |
|  | ||
| #pragma once | ||
|  | ||
| #include <executorch/extension/llm/runner/audio.h> | ||
|         
                  jackzhxng marked this conversation as resolved.
              Show resolved
            Hide resolved | ||
| #include <executorch/extension/llm/runner/image.h> | ||
| #include <executorch/runtime/platform/compiler.h> | ||
| #include <string> | ||
|  | @@ -19,19 +20,31 @@ | |
| namespace executorch::extension::llm { | ||
|  | ||
| /** | ||
| * A generic class to hold either image or text data for multimodal inputs. | ||
| * This allows the generate() API to take a std::vector of these objects | ||
| * instead of separate image and text parameters. | ||
| * A generic class to hold either image, text, or audio data for multimodal | ||
| * inputs. This allows the generate() API to take a std::vector of these objects | ||
| * instead of separate image, text, and audio parameters. | ||
| */ | ||
| class ET_EXPERIMENTAL MultimodalInput { | ||
| public: | ||
| enum class Type { TEXT, IMAGE }; | ||
| /// Type of multimodal input data | ||
| enum class Type { | ||
| TEXT, ///< Text string input | ||
| IMAGE, ///< Processed image input | ||
| AUDIO, ///< Processed audio input (post-mel spectrogram processing) | ||
|         
                  jackzhxng marked this conversation as resolved.
              Outdated
          
            Show resolved
            Hide resolved | ||
| RAW_AUDIO, ///< Raw unprocessed audio input (straight from audio file) | ||
| UNSUPPORTED ///< Unsupported input type | ||
| }; | ||
|  | ||
| // Constructors | ||
| explicit MultimodalInput(const std::string& text) : data_(text) {} | ||
| explicit MultimodalInput(std::string&& text) : data_(std::move(text)) {} | ||
| explicit MultimodalInput(const Image& image) : data_(image) {} | ||
| explicit MultimodalInput(Image&& image) : data_(std::move(image)) {} | ||
| explicit MultimodalInput(const Audio& audio) : data_(audio) {} | ||
| explicit MultimodalInput(Audio&& audio) : data_(std::move(audio)) {} | ||
| explicit MultimodalInput(const RawAudio& raw_audio) : data_(raw_audio) {} | ||
| explicit MultimodalInput(RawAudio&& raw_audio) | ||
| : data_(std::move(raw_audio)) {} | ||
|  | ||
| // Copy constructor and assignment | ||
| MultimodalInput(const MultimodalInput& other) = default; | ||
|  | @@ -60,12 +73,37 @@ class ET_EXPERIMENTAL MultimodalInput { | |
| return std::holds_alternative<Image>(data_); | ||
| } | ||
|  | ||
| /** | ||
| * Report whether the stored alternative is an Audio object. | ||
| * @return true when this input holds audio, false otherwise. | ||
| */ | ||
| bool is_audio() const noexcept { | ||
| const Audio* const maybe_audio = std::get_if<Audio>(&data_); | ||
| return maybe_audio != nullptr; | ||
| } | ||
|  | ||
| /** | ||
| * Report whether the stored alternative is a RawAudio object. | ||
| * @return true when this input holds raw audio, false otherwise. | ||
| */ | ||
| bool is_raw_audio() const noexcept { | ||
| const RawAudio* const maybe_raw_audio = std::get_if<RawAudio>(&data_); | ||
| return maybe_raw_audio != nullptr; | ||
| } | ||
|  | ||
| /** | ||
| * Get the type of data stored in this input. | ||
| * @return Type::TEXT if text data, Type::IMAGE if image data. | ||
| * @return Type::TEXT if text data, Type::IMAGE if image data, Type::AUDIO if | ||
| * audio data, Type::RAW_AUDIO if raw audio data. | ||
| */ | ||
| Type get_type() const noexcept { | ||
| return is_text() ? Type::TEXT : Type::IMAGE; | ||
| if (is_text()) | ||
| return Type::TEXT; | ||
| if (is_image()) | ||
| return Type::IMAGE; | ||
| if (is_audio()) | ||
| return Type::AUDIO; | ||
| if (is_raw_audio()) | ||
| return Type::RAW_AUDIO; | ||
| return Type::UNSUPPORTED; | ||
| } | ||
|  | ||
| /** | ||
|  | @@ -122,6 +160,60 @@ class ET_EXPERIMENTAL MultimodalInput { | |
| return std::get<Image>(std::move(data_)); | ||
| } | ||
|  | ||
| /** | ||
| * Access the stored audio (read-only, lvalue overload). | ||
| * @return Const reference to the Audio held by this input. | ||
| * @throws std::bad_variant_access if this input doesn't contain audio. | ||
| */ | ||
| const Audio& get_audio() const& { | ||
| const Audio& stored_audio = std::get<Audio>(data_); | ||
| return stored_audio; | ||
| } | ||
|  | ||
| /** | ||
| * Access the stored audio (mutable, lvalue overload). | ||
| * @return Mutable reference to the Audio held by this input. | ||
| * @throws std::bad_variant_access if this input doesn't contain audio. | ||
| */ | ||
| Audio& get_audio() & { | ||
| Audio& stored_audio = std::get<Audio>(data_); | ||
| return stored_audio; | ||
| } | ||
| 
      Comment on lines
    
      +177
     to 
      +179
    
   There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is this needed? Do we ever return a mutable Audio? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yeah, I was thinking the same. This follows the already established pattern; I'm thinking of getting rid of all of these get_ variants later. | ||
|  | ||
| /** | ||
| * Access the stored audio on an expiring input (rvalue overload). | ||
| * @return Rvalue reference to the Audio held by this input, so callers can | ||
| * move it out instead of copying. | ||
| * @throws std::bad_variant_access if this input doesn't contain audio. | ||
| */ | ||
| Audio&& get_audio() && { | ||
| Audio& stored_audio = std::get<Audio>(data_); | ||
| return std::move(stored_audio); | ||
| } | ||
|  | ||
| /** | ||
| * Access the stored raw audio (read-only, lvalue overload). | ||
| * @return Const reference to the RawAudio held by this input. | ||
| * @throws std::bad_variant_access if this input doesn't contain raw audio. | ||
| */ | ||
| const RawAudio& get_raw_audio() const& { | ||
| const RawAudio& stored_raw_audio = std::get<RawAudio>(data_); | ||
| return stored_raw_audio; | ||
| } | ||
|  | ||
| /** | ||
| * Access the stored raw audio (mutable, lvalue overload). | ||
| * @return Mutable reference to the RawAudio held by this input. | ||
| * @throws std::bad_variant_access if this input doesn't contain raw audio. | ||
| */ | ||
| RawAudio& get_raw_audio() & { | ||
| RawAudio& stored_raw_audio = std::get<RawAudio>(data_); | ||
| return stored_raw_audio; | ||
| } | ||
|  | ||
| /** | ||
| * Access the stored raw audio on an expiring input (rvalue overload). | ||
| * @return Rvalue reference to the RawAudio held by this input, so callers can | ||
| * move it out instead of copying. | ||
| * @throws std::bad_variant_access if this input doesn't contain raw audio. | ||
| */ | ||
| RawAudio&& get_raw_audio() && { | ||
| RawAudio& stored_raw_audio = std::get<RawAudio>(data_); | ||
| return std::move(stored_raw_audio); | ||
| } | ||
|  | ||
| /** | ||
| * Try to get the text data from this input safely. | ||
| * @return Pointer to the text string if this input contains text, nullptr | ||
|  | @@ -158,8 +250,44 @@ class ET_EXPERIMENTAL MultimodalInput { | |
| return std::get_if<Image>(&data_); | ||
| } | ||
|  | ||
| /** | ||
| * Non-throwing, read-only lookup of the stored audio. | ||
| * @return Pointer to the Audio object when this input holds audio, | ||
| * nullptr otherwise. | ||
| */ | ||
| const Audio* try_get_audio() const noexcept { | ||
| const Audio* const maybe_audio = std::get_if<Audio>(&data_); | ||
| return maybe_audio; | ||
| } | ||
|  | ||
| /** | ||
| * Non-throwing, mutable lookup of the stored audio. | ||
| * @return Pointer to the Audio object when this input holds audio, | ||
| * nullptr otherwise. | ||
| */ | ||
| Audio* try_get_audio() noexcept { | ||
| Audio* const maybe_audio = std::get_if<Audio>(&data_); | ||
| return maybe_audio; | ||
| } | ||
|  | ||
| /** | ||
| * Non-throwing, read-only lookup of the stored raw audio. | ||
| * @return Pointer to the RawAudio object when this input holds raw audio, | ||
| * nullptr otherwise. | ||
| */ | ||
| const RawAudio* try_get_raw_audio() const noexcept { | ||
| const RawAudio* const maybe_raw_audio = std::get_if<RawAudio>(&data_); | ||
| return maybe_raw_audio; | ||
| } | ||
|  | ||
| /** | ||
| * Non-throwing, mutable lookup of the stored raw audio. | ||
| * @return Pointer to the RawAudio object when this input holds raw audio, | ||
| * nullptr otherwise. | ||
| */ | ||
| RawAudio* try_get_raw_audio() noexcept { | ||
| RawAudio* const maybe_raw_audio = std::get_if<RawAudio>(&data_); | ||
| return maybe_raw_audio; | ||
| } | ||
|  | ||
| private: | ||
| std::variant<std::string, Image> data_; | ||
| std::variant<std::string, Image, Audio, RawAudio> data_; | ||
| }; | ||
|  | ||
| // Convenience factory functions | ||
|  | @@ -179,4 +307,21 @@ inline MultimodalInput make_image_input(Image&& image) noexcept { | |
| return MultimodalInput(std::move(image)); | ||
| } | ||
|  | ||
| } // namespace executorch::extension::llm | ||
| /** Build a MultimodalInput that holds a copy of `audio`. */ | ||
| // NOTE(review): the copy duplicates audio.data; under noexcept an allocation | ||
| // failure would terminate rather than propagate. | ||
| inline MultimodalInput make_audio_input(const Audio& audio) noexcept { | ||
| return MultimodalInput{audio}; | ||
| } | ||
|  | ||
| /** Build a MultimodalInput that takes ownership of `audio` by move. */ | ||
| inline MultimodalInput make_audio_input(Audio&& audio) noexcept { | ||
| return MultimodalInput{std::move(audio)}; | ||
| } | ||
|  | ||
| /** Build a MultimodalInput that holds a copy of `raw_audio`. */ | ||
| // NOTE(review): the copy duplicates raw_audio.data; under noexcept an | ||
| // allocation failure would terminate rather than propagate. | ||
| inline MultimodalInput make_raw_audio_input( | ||
| const RawAudio& raw_audio) noexcept { | ||
| return MultimodalInput{raw_audio}; | ||
| } | ||
|  | ||
| /** Build a MultimodalInput that takes ownership of `raw_audio` by move. */ | ||
| inline MultimodalInput make_raw_audio_input(RawAudio&& raw_audio) noexcept { | ||
| return MultimodalInput{std::move(raw_audio)}; | ||
| } | ||
|  | ||
| } // namespace executorch::extension::llm | ||
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
|  | @@ -37,7 +37,7 @@ MultimodalPrefiller::MultimodalPrefiller( | |
| Result<uint64_t> MultimodalPrefiller::prefill( | ||
| const MultimodalInput& input, | ||
| int64_t& start_pos) { | ||
| // Check if input is image | ||
| // 1. Run encoder model. | ||
| ::executorch::runtime::EValue encoder_output; | ||
| if (input.is_image()) { | ||
| Image image = input.get_image(); | ||
|  | @@ -51,34 +51,77 @@ Result<uint64_t> MultimodalPrefiller::prefill( | |
| ET_UNWRAP(module_->execute(kImageEncoderMethod, image_tensor)); | ||
|  | ||
| encoder_output = image_encoder_outputs[0]; | ||
| } else if (input.is_audio()) { | ||
| Audio audio = input.get_audio(); | ||
|  | ||
| // Use the original tensor shape as intended | ||
| auto audio_tensor = executorch::extension::from_blob( | ||
| audio.data.data(), | ||
| {audio.batch_size, audio.n_bins, audio.n_frames}, | ||
| ::executorch::aten::ScalarType::Float); | ||
|  | ||
|         
                  jackzhxng marked this conversation as resolved.
              Show resolved
            Hide resolved | ||
| // Run audio encoder | ||
| auto audio_encoder_result = | ||
| module_->execute(kAudioEncoderMethod, audio_tensor); | ||
| if (audio_encoder_result.error() != ::executorch::runtime::Error::Ok) { | ||
| return ::executorch::runtime::Error::Internal; | ||
| } | ||
| auto audio_encoder_outputs = audio_encoder_result.get(); | ||
|  | ||
| encoder_output = audio_encoder_outputs[0]; | ||
| } else if (input.is_text()) { | ||
| // For text input, no modality encoder (image or audio) is run. | ||
| // Instead, the token embedding method produces the encoder output. | ||
| auto& text = input.get_text(); | ||
| std::vector<uint64_t> tokens = | ||
| ET_UNWRAP_TOKENIZER(tokenizer_->encode(text)); | ||
|  | ||
| auto text_tensor = executorch::extension::from_blob( | ||
| tokens.data(), | ||
| {1, static_cast<aten::SizesType>(tokens.size())}, | ||
| ::executorch::aten::ScalarType::Long); | ||
|  | ||
| // Run token embedding | ||
| // Run text encoder (token embeddings) | ||
| auto token_embedding_outputs = | ||
| ET_UNWRAP(module_->execute(kTokenEmbeddingMethod, text_tensor)); | ||
|  | ||
| encoder_output = token_embedding_outputs[0]; | ||
| } else { | ||
| ET_LOG(Error, "Unsupported input type"); | ||
| // For all other input types (e.g., audio), return error | ||
| // For any other input types, return error | ||
| return ::executorch::runtime::Error::NotSupported; | ||
| } | ||
|  | ||
| auto outputs_res = | ||
| ET_UNWRAP(text_decoder_runner_->decode(encoder_output, start_pos)); | ||
| // 2. Run decoder model for prefill. | ||
| // `cache_position` goes from start_pos to start_pos + encoder_output.size(1). | ||
| // e.g. if start_pos = 2 and encoder_output.size(1) = 5, | ||
| // cache_position_tensor should be [2, 3, 4, 5, 6]. | ||
| int64_t seq_len = encoder_output.toTensor().size(1); | ||
|         
                  jackzhxng marked this conversation as resolved.
              Show resolved
            Hide resolved 
      Comment on lines
    
      +93
     to 
      +97
    
   There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Didn't vision-based multimodal need exactly the same thing? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Vision also takes this path. | ||
| if (seq_len == 0) { | ||
| ET_LOG(Error, "The encoder returned an empty output."); | ||
| return ::executorch::runtime::Error::InvalidState; | ||
| } | ||
| std::vector<int64_t> cache_positions(seq_len); | ||
| for (int64_t i = 0; i < seq_len; ++i) { | ||
| cache_positions[i] = start_pos + i; | ||
| } | ||
| auto cache_position_tensor = ::executorch::extension::from_blob( | ||
| cache_positions.data(), {seq_len}, executorch::aten::ScalarType::Long); | ||
| auto prefill_result = module_->execute( | ||
| kTextModelMethod, {cache_position_tensor, encoder_output}); | ||
| if (prefill_result.error() != ::executorch::runtime::Error::Ok) { | ||
| return prefill_result.error(); | ||
| } | ||
| // Check if prefill_outputs is empty, if it is return error and log that the | ||
| // specified encoder returned empty results when used to prefill decoder. | ||
| auto prefill_outputs = prefill_result.get(); | ||
|         
                  jackzhxng marked this conversation as resolved.
              Show resolved
            Hide resolved | ||
| if (prefill_outputs.empty()) { | ||
| ET_LOG( | ||
| Error, "Encoder returned empty results when used to prefill decoder"); | ||
| return ::executorch::runtime::Error::InvalidState; | ||
| } | ||
| auto outputs_res = prefill_outputs[0].toTensor(); | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Validate if outputs_res.numel() == 0. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why? I think adding so many validations for extremely unlikely outcomes makes things too long and hard to read. I think letting this one naturally error out below and returning that error directly is good enough. | ||
|  | ||
| // Update the start_pos, which is only available inside this function. | ||
| // outputs_res can have only one logits. | ||
| start_pos += encoder_output.toTensor().size(1); | ||
| // Update start_pos, tracking the current cache position. | ||
|         
                  jackzhxng marked this conversation as resolved.
              Show resolved
            Hide resolved | ||
| start_pos += seq_len; | ||
|  | ||
| return static_cast<uint64_t>( | ||
| text_decoder_runner_->logits_to_token(outputs_res)); | ||
|  | @@ -103,6 +146,11 @@ ::executorch::runtime::Error MultimodalPrefiller::load() { | |
| if (methods.find(kImageEncoderMethod) != methods.end()) { | ||
| ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method(kImageEncoderMethod)); | ||
| } | ||
|  | ||
| if (methods.find(kAudioEncoderMethod) != methods.end()) { | ||
| ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method(kAudioEncoderMethod)); | ||
| } | ||
|  | ||
| return ::executorch::runtime::Error::Ok; | ||
| } | ||
|  | ||
|  | ||
Uh oh!
There was an error while loading. Please reload this page.