52 changes: 52 additions & 0 deletions extension/llm/runner/audio.h
@@ -0,0 +1,52 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

// A simple audio struct.

#pragma once
#include <executorch/runtime/platform/compiler.h>
#include <cstdint>
#include <vector>

namespace executorch {
namespace extension {
namespace llm {

/**
* Audio inputs as a raw audio tensor, for use when the audio processing
* into a mel spectrogram is baked into the audio encoder with torch.export.
*/
struct ET_EXPERIMENTAL RawAudio {
  std::vector<uint8_t> data;
  int32_t batch_size;
  int32_t n_channels; // For mono, use n_channels = 1.
  int32_t n_samples;
};

/**
* Audio inputs as a mel spectrogram, ready to feed directly into an audio
* encoder.
*/
struct ET_EXPERIMENTAL Audio {
  std::vector<uint8_t> data;
  int32_t batch_size;
  int32_t n_bins;
  int32_t n_frames;
};

} // namespace llm
} // namespace extension
} // namespace executorch

namespace torch {
namespace executor {
// TODO(T197294990): Remove these deprecated aliases once all users have moved
// to the new `::executorch` namespaces.
using ::executorch::extension::llm::Audio;
} // namespace executor
} // namespace torch
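
For reference, a minimal sketch (not part of this PR) of how a caller might populate the new Audio struct from an already-computed mel spectrogram. The helper name and the float-to-byte packing are assumptions, though the prefiller below does interpret Audio::data as float data when building the encoder input tensor.

```cpp
#include <executorch/extension/llm/runner/audio.h>
#include <cstring>
#include <vector>

using executorch::extension::llm::Audio;

// Illustrative helper: wrap an [n_bins x n_frames] float spectrogram (batch of 1).
Audio make_mel_input(
    const std::vector<float>& mel, int32_t n_bins, int32_t n_frames) {
  Audio audio;
  audio.batch_size = 1;
  audio.n_bins = n_bins;
  audio.n_frames = n_frames;
  // Audio::data stores raw bytes; copy the float samples in byte form.
  audio.data.resize(mel.size() * sizeof(float));
  std::memcpy(audio.data.data(), mel.data(), audio.data.size());
  return audio;
}
```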
5 changes: 3 additions & 2 deletions extension/llm/runner/constants.h
@@ -21,7 +21,8 @@ inline constexpr auto kUseSDPAWithKVCache = "use_sdpa_with_kv_cache";

// Multimodal method name conventions
inline constexpr auto kImageEncoderMethod = "image_encoder";
inline constexpr auto kTokenEmbeddingMethod = "token_embedding";
inline constexpr auto kTextModelMethod = "text_model";
inline constexpr auto kAudioEncoderMethod = "audio_encoder";
inline constexpr auto kTokenEmbeddingMethod = "token_embeddings";
inline constexpr auto kTextModelMethod = "decoder";
Contributor:

Make it backwards compatible: keep token_embedding, not token_embeddings, and keep text_model instead of decoder.

Contributor Author (@jackzhxng, Aug 26, 2025):

There's nothing that needs to be kept backwards compatible; this isn't used anywhere at the moment. I'd like to match this to Optimum.

} // namespace executorch::extension::llm
161 changes: 153 additions & 8 deletions extension/llm/runner/multimodal_input.h
@@ -11,6 +11,7 @@

#pragma once

#include <executorch/extension/llm/runner/audio.h>
#include <executorch/extension/llm/runner/image.h>
#include <executorch/runtime/platform/compiler.h>
#include <string>
@@ -19,19 +20,31 @@
namespace executorch::extension::llm {

/**
* A generic class to hold either image or text data for multimodal inputs.
* This allows the generate() API to take a std::vector of these objects
* instead of separate image and text parameters.
* A generic class to hold either image, text, or audio data for multimodal
* inputs. This allows the generate() API to take a std::vector of these objects
* instead of separate image, text, and audio parameters.
*/
class ET_EXPERIMENTAL MultimodalInput {
public:
enum class Type { TEXT, IMAGE };
/// Type of multimodal input data
enum class Type {
TEXT, ///< Text string input
IMAGE, ///< Processed image input
AUDIO, ///< Processed audio input (post-mel spectrogram processing)
RAW_AUDIO, ///< Raw unprocessed audio input (straight from audio file)
UNSUPPORTED ///< Unsupported input type
};

// Constructors
explicit MultimodalInput(const std::string& text) : data_(text) {}
explicit MultimodalInput(std::string&& text) : data_(std::move(text)) {}
explicit MultimodalInput(const Image& image) : data_(image) {}
explicit MultimodalInput(Image&& image) : data_(std::move(image)) {}
explicit MultimodalInput(const Audio& audio) : data_(audio) {}
explicit MultimodalInput(Audio&& audio) : data_(std::move(audio)) {}
explicit MultimodalInput(const RawAudio& raw_audio) : data_(raw_audio) {}
explicit MultimodalInput(RawAudio&& raw_audio)
: data_(std::move(raw_audio)) {}

// Copy constructor and assignment
MultimodalInput(const MultimodalInput& other) = default;
@@ -60,12 +73,37 @@ class ET_EXPERIMENTAL MultimodalInput {
return std::holds_alternative<Image>(data_);
}

/**
* Check if this input contains audio data.
* @return true if this input contains audio, false otherwise.
*/
bool is_audio() const noexcept {
return std::holds_alternative<Audio>(data_);
}

/**
* Check if this input contains raw audio data.
* @return true if this input contains raw audio, false otherwise.
*/
bool is_raw_audio() const noexcept {
return std::holds_alternative<RawAudio>(data_);
}

/**
* Get the type of data stored in this input.
* @return Type::TEXT if text data, Type::IMAGE if image data.
* @return Type::TEXT if text data, Type::IMAGE if image data, Type::AUDIO if
* audio data, Type::RAW_AUDIO if raw audio data.
*/
Type get_type() const noexcept {
return is_text() ? Type::TEXT : Type::IMAGE;
if (is_text())
return Type::TEXT;
if (is_image())
return Type::IMAGE;
if (is_audio())
return Type::AUDIO;
if (is_raw_audio())
return Type::RAW_AUDIO;
return Type::UNSUPPORTED;
}

/**
@@ -122,6 +160,60 @@ class ET_EXPERIMENTAL MultimodalInput {
return std::get<Image>(std::move(data_));
}

/**
* Get the audio data from this input.
* @return Reference to the stored Audio object.
* @throws std::bad_variant_access if this input doesn't contain audio.
*/
const Audio& get_audio() const& {
return std::get<Audio>(data_);
}

/**
* Get the audio data from this input (mutable version).
* @return Mutable reference to the stored Audio object.
* @throws std::bad_variant_access if this input doesn't contain audio.
*/
Audio& get_audio() & {
return std::get<Audio>(data_);
}
Comment on lines +177 to +179

Contributor:

Is this needed? Do we ever return a mutable Audio?

Contributor Author (@jackzhxng, Aug 26, 2025):

Yeah, I was thinking the same. This follows an already established pattern; I'm thinking of getting rid of all of these get_ variants later.

/**
* Get the audio data from this input (rvalue version).
* @return Rvalue reference to the stored Audio object for efficient moves.
* @throws std::bad_variant_access if this input doesn't contain audio.
*/
Audio&& get_audio() && {
return std::get<Audio>(std::move(data_));
}

/**
* Get the raw audio data from this input.
* @return Reference to the stored RawAudio object.
* @throws std::bad_variant_access if this input doesn't contain raw audio.
*/
const RawAudio& get_raw_audio() const& {
return std::get<RawAudio>(data_);
}

/**
* Get the raw audio data from this input (mutable version).
* @return Mutable reference to the stored RawAudio object.
* @throws std::bad_variant_access if this input doesn't contain raw audio.
*/
RawAudio& get_raw_audio() & {
return std::get<RawAudio>(data_);
}

/**
* Get the raw audio data from this input (rvalue version).
* @return Rvalue reference to the stored RawAudio object for efficient moves.
* @throws std::bad_variant_access if this input doesn't contain raw audio.
*/
RawAudio&& get_raw_audio() && {
return std::get<RawAudio>(std::move(data_));
}

/**
* Try to get the text data from this input safely.
* @return Pointer to the text string if this input contains text, nullptr
@@ -158,8 +250,44 @@ class ET_EXPERIMENTAL MultimodalInput {
return std::get_if<Image>(&data_);
}

/**
* Try to get the audio data from this input safely.
* @return Pointer to the Audio object if this input contains audio,
* nullptr otherwise.
*/
const Audio* try_get_audio() const noexcept {
return std::get_if<Audio>(&data_);
}

/**
* Try to get the audio data from this input safely (mutable version).
* @return Pointer to the Audio object if this input contains audio,
* nullptr otherwise.
*/
Audio* try_get_audio() noexcept {
return std::get_if<Audio>(&data_);
}

/**
* Try to get the raw audio data from this input safely.
* @return Pointer to the RawAudio object if this input contains raw audio,
* nullptr otherwise.
*/
const RawAudio* try_get_raw_audio() const noexcept {
return std::get_if<RawAudio>(&data_);
}

/**
* Try to get the raw audio data from this input safely (mutable version).
* @return Pointer to the RawAudio object if this input contains raw audio,
* nullptr otherwise.
*/
RawAudio* try_get_raw_audio() noexcept {
return std::get_if<RawAudio>(&data_);
}

private:
std::variant<std::string, Image> data_;
std::variant<std::string, Image, Audio, RawAudio> data_;
};
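
For reference, a minimal sketch (not part of this PR) of how a caller could branch on the new alternatives without exceptions, using the try_get_* accessors. The get_text() accessor and the log header are assumed from the existing runner code.

```cpp
#include <executorch/extension/llm/runner/multimodal_input.h>
#include <executorch/runtime/platform/log.h>

using executorch::extension::llm::Audio;
using executorch::extension::llm::MultimodalInput;
using executorch::extension::llm::RawAudio;

// Illustrative only: log a short description of whatever the input holds.
void describe(const MultimodalInput& input) {
  if (const Audio* audio = input.try_get_audio()) {
    ET_LOG(Info, "audio: %d bins x %d frames", audio->n_bins, audio->n_frames);
  } else if (const RawAudio* raw = input.try_get_raw_audio()) {
    ET_LOG(Info, "raw audio: %d samples, %d channels", raw->n_samples, raw->n_channels);
  } else if (input.is_text()) {
    ET_LOG(Info, "text: %s", input.get_text().c_str());
  }
}
```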

// Convenience factory functions
@@ -179,4 +307,21 @@ inline MultimodalInput make_image_input(Image&& image) noexcept {
return MultimodalInput(std::move(image));
}

} // namespace executorch::extension::llm
inline MultimodalInput make_audio_input(const Audio& audio) noexcept {
return MultimodalInput(audio);
}

inline MultimodalInput make_audio_input(Audio&& audio) noexcept {
return MultimodalInput(std::move(audio));
}

inline MultimodalInput make_raw_audio_input(
const RawAudio& raw_audio) noexcept {
return MultimodalInput(raw_audio);
}

inline MultimodalInput make_raw_audio_input(RawAudio&& raw_audio) noexcept {
return MultimodalInput(std::move(raw_audio));
}

} // namespace executorch::extension::llm
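
For reference, a minimal sketch (not part of this PR) of building a mixed text-plus-audio prompt with the new factory function; the helper name and prompt string are illustrative only.

```cpp
#include <executorch/extension/llm/runner/multimodal_input.h>
#include <string>
#include <utility>
#include <vector>

using executorch::extension::llm::Audio;
using executorch::extension::llm::MultimodalInput;
using executorch::extension::llm::make_audio_input;

// Illustrative helper: a text instruction followed by a preprocessed clip.
std::vector<MultimodalInput> build_inputs(Audio&& mel) {
  std::vector<MultimodalInput> inputs;
  inputs.emplace_back(std::string("Transcribe the following audio:"));
  inputs.push_back(make_audio_input(std::move(mel)));
  return inputs;
}
```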
68 changes: 58 additions & 10 deletions extension/llm/runner/multimodal_prefiller.cpp
@@ -37,7 +37,7 @@ MultimodalPrefiller::MultimodalPrefiller(
Result<uint64_t> MultimodalPrefiller::prefill(
const MultimodalInput& input,
int64_t& start_pos) {
// Check if input is image
// 1. Run encoder model.
::executorch::runtime::EValue encoder_output;
if (input.is_image()) {
Image image = input.get_image();
@@ -51,34 +51,77 @@ Result<uint64_t> MultimodalPrefiller::prefill(
ET_UNWRAP(module_->execute(kImageEncoderMethod, image_tensor));

encoder_output = image_encoder_outputs[0];
} else if (input.is_audio()) {
Audio audio = input.get_audio();

// Build the mel spectrogram tensor with shape {batch_size, n_bins, n_frames}.
auto audio_tensor = executorch::extension::from_blob(
audio.data.data(),
{audio.batch_size, audio.n_bins, audio.n_frames},
::executorch::aten::ScalarType::Float);

// Run audio encoder
auto audio_encoder_result =
module_->execute(kAudioEncoderMethod, audio_tensor);
if (audio_encoder_result.error() != ::executorch::runtime::Error::Ok) {
return ::executorch::runtime::Error::Internal;
}
auto audio_encoder_outputs = audio_encoder_result.get();

encoder_output = audio_encoder_outputs[0];
} else if (input.is_text()) {
// For text input, we don't need to run the image encoder.
// Instead, we run the text encoder to get the encoder output.
auto& text = input.get_text();
std::vector<uint64_t> tokens =
ET_UNWRAP_TOKENIZER(tokenizer_->encode(text));

auto text_tensor = executorch::extension::from_blob(
tokens.data(),
{1, static_cast<aten::SizesType>(tokens.size())},
::executorch::aten::ScalarType::Long);

// Run token embedding
// Run text encoder (token embeddings)
auto token_embedding_outputs =
ET_UNWRAP(module_->execute(kTokenEmbeddingMethod, text_tensor));

encoder_output = token_embedding_outputs[0];
} else {
ET_LOG(Error, "Unsupported input type");
// For all other input types (e.g., audio), return error
// For any other input types, return error
return ::executorch::runtime::Error::NotSupported;
}

auto outputs_res =
ET_UNWRAP(text_decoder_runner_->decode(encoder_output, start_pos));
// 2. Run decoder model for prefill.
// `cache_position` goes from start_pos to start_pos + encoder_output.size(1).
// e.g. if start_pos = 2 and encoder_output.size(1) = 5,
// cache_position_tensor should be [2, 3, 4, 5, 6].
int64_t seq_len = encoder_output.toTensor().size(1);
Comment on lines +93 to +97

Contributor:

Didn't vision-based multimodal need exactly the same thing?

Contributor Author:

Vision also takes this path.

if (seq_len == 0) {
ET_LOG(Error, "The encoder returned an empty output.");
return ::executorch::runtime::Error::InvalidState;
}
std::vector<int64_t> cache_positions(seq_len);
for (int64_t i = 0; i < seq_len; ++i) {
cache_positions[i] = start_pos + i;
}
auto cache_position_tensor = ::executorch::extension::from_blob(
cache_positions.data(), {seq_len}, executorch::aten::ScalarType::Long);
auto prefill_result = module_->execute(
kTextModelMethod, {cache_position_tensor, encoder_output});
if (prefill_result.error() != ::executorch::runtime::Error::Ok) {
return prefill_result.error();
}
// If prefill_outputs is empty, log that the encoder returned empty results
// when used to prefill the decoder, and return an error.
auto prefill_outputs = prefill_result.get();
if (prefill_outputs.empty()) {
ET_LOG(
Error, "Encoder returned empty results when used to prefill decoder");
return ::executorch::runtime::Error::InvalidState;
}
auto outputs_res = prefill_outputs[0].toTensor();
Contributor:

Validate if outputs_res.numel() == 0.

Contributor Author:

Why? I think adding so many validations for extremely unlikely outcomes makes things too long and hard to read. Letting this one error out naturally below and returning that error directly is good enough.

// Update the start_pos, which is only available inside this function.
// outputs_res can have only one logits.
start_pos += encoder_output.toTensor().size(1);
// Update start_pos, tracking the current cache position.
start_pos += seq_len;

return static_cast<uint64_t>(
text_decoder_runner_->logits_to_token(outputs_res));
@@ -103,6 +146,11 @@ ::executorch::runtime::Error MultimodalPrefiller::load() {
if (methods.find(kImageEncoderMethod) != methods.end()) {
ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method(kImageEncoderMethod));
}

if (methods.find(kAudioEncoderMethod) != methods.end()) {
ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method(kAudioEncoderMethod));
}

return ::executorch::runtime::Error::Ok;
}
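
For reference, a minimal sketch (not part of this PR) of a caller driving prefill() over a sequence of multimodal inputs; the header path, namespace, and helper name are assumptions based on this diff.

```cpp
#include <executorch/extension/llm/runner/multimodal_input.h>
#include <executorch/extension/llm/runner/multimodal_prefiller.h>  // assumed header path
#include <cstdint>
#include <vector>

using executorch::extension::llm::MultimodalInput;
using executorch::extension::llm::MultimodalPrefiller;  // assumed namespace
using executorch::runtime::Result;

// Illustrative helper: prefill every input in order; prefill() advances
// start_pos by each chunk's sequence length. Returns the token predicted
// after the last chunk, or the first error encountered.
Result<uint64_t> prefill_all(
    MultimodalPrefiller& prefiller,
    const std::vector<MultimodalInput>& inputs,
    int64_t& start_pos) {
  uint64_t next_token = 0;
  for (const MultimodalInput& input : inputs) {
    auto token = prefiller.prefill(input, start_pos);
    if (!token.ok()) {
      return token.error();
    }
    next_token = token.get();
  }
  return next_token;
}
```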
