Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions examples/models/llava/export_llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,11 +226,11 @@ def export_all(llava_model: LlavaModel):
{
"image_encoder": image_encoder_ep,
"token_embedding": token_embedding_ep,
"text_model": text_model_ep,
"text_decoder": text_model_ep,
},
partitioner={
"image_encoder": [XnnpackPartitioner()],
"text_model": [
"text_decoder": [
# First partition the DQLinear nodes, then partition the rest of the nodes,
# to avoid multiple DQLinear nodes in the same partition,
# to avoid holding multiple unpacked and packed weight buffers in memory,
Expand All @@ -254,7 +254,7 @@ def export_all(llava_model: LlavaModel):
memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
sym_shape_eval_pass={
"image_encoder": ConstraintBasedSymShapeEvalPass(),
"text_model": ConstraintBasedSymShapeEvalPass(),
"text_decoder": ConstraintBasedSymShapeEvalPass(),
"token_embedding": HintBasedSymShapeEvalPass(),
},
)
Expand Down
2 changes: 1 addition & 1 deletion examples/models/llava/runner/llava_text_decoder_runner.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ class ET_EXPERIMENTAL LlavaTextDecoderRunner
}

inline static const std::string kTokenEmbeddingMethod = "token_embedding";
// Method name used to look up the text decoder.
inline static const std::string kTextModelMethod = "text_decoder";
};

} // namespace example
8 changes: 4 additions & 4 deletions examples/models/llava/test/test_llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def test_llava_export(self):
"token_embedding", (prompt_before_image,)
)[0]
llava_module.run_method(
"text_model",
"text_decoder",
(torch.tensor([start_pos], dtype=torch.int64), pte_embeds_before_img),
)

Expand All @@ -107,7 +107,7 @@ def test_llava_export(self):
# pte prefill image
pte_embeds_img = llava_module.run_method("image_encoder", (resized,))[0]
llava_module.run_method(
"text_model",
"text_decoder",
(
torch.tensor([start_pos], dtype=torch.int64),
pte_embeds_img,
Expand All @@ -122,7 +122,7 @@ def test_llava_export(self):
"token_embedding", (prompt_after_image,)
)[0]
pte_prefill_after_img = llava_module.run_method(
"text_model",
"text_decoder",
(torch.tensor([start_pos], dtype=torch.int64), pte_embeds_after_img),
)[0]

Expand All @@ -139,7 +139,7 @@ def test_llava_export(self):
"token_embedding", (torch.tensor([[new_tokens[i]]], dtype=torch.int64),)
)[0]
logits = llava_module.run_method(
"text_model",
"text_decoder",
(torch.tensor([start_pos + i], dtype=torch.int64), token_embeds),
)[0]
new_tokens.append(torch.argmax(logits).item())
Expand Down
8 changes: 4 additions & 4 deletions examples/models/llava/test/test_pte.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def main():
"token_embedding", (prompt_before_image,)
)[0]
pte_prefill_before_img = llava_module.run_method(
"text_model",
"text_decoder",
(torch.tensor([start_pos], dtype=torch.int64), pte_embeds_before_img),
)[0]
print(pte_prefill_before_img)
Expand All @@ -60,7 +60,7 @@ def main():
logging.warning("Image encoder finished")
logging.warning("Image token prefill started")
pte_prefill_img = llava_module.run_method(
"text_model",
"text_decoder",
(
torch.tensor([start_pos], dtype=torch.int64),
pte_embeds_img,
Expand All @@ -77,7 +77,7 @@ def main():
"token_embedding", (prompt_after_image,)
)[0]
pte_prefill_after_img = llava_module.run_method(
"text_model",
"text_decoder",
(torch.tensor([start_pos], dtype=torch.int64), pte_embeds_after_img),
)[0]
logging.warning("Text token prefill finished")
Expand All @@ -91,7 +91,7 @@ def main():
"token_embedding", (torch.tensor([[new_tokens[i]]], dtype=torch.int64),)
)[0]
logits = llava_module.run_method(
"text_model",
"text_decoder",
(torch.tensor([start_pos + i], dtype=torch.int64), token_embeds),
)[0]
new_tokens.append(torch.argmax(logits[..., -1, :]).item())
Expand Down
52 changes: 52 additions & 0 deletions extension/llm/runner/audio.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

// A simple audio struct.

#pragma once
#include <executorch/runtime/platform/compiler.h>
#include <cstdint>
#include <vector>

namespace executorch {
namespace extension {
namespace llm {

/**
* Audio inputs as a raw audio tensor, for use when the audio processing
* into a mel spectrogram is baked into the audio encoder with torch.export.
*/
struct ET_EXPERIMENTAL RawAudio {
  /// Raw audio payload as bytes.
  /// NOTE(review): sample format and layout are not specified here — confirm
  /// against the producer of this struct.
  std::vector<uint8_t> data;
  /// Batch dimension. Zero until the caller sets it.
  int32_t batch_size = 0;
  /// Number of channels. For mono, use n_channels = 1.
  int32_t n_channels = 0;
  /// Number of samples. NOTE(review): presumably per channel — confirm.
  int32_t n_samples = 0;
};

/**
* Pre-processed audio inputs, ready to feed directly into an audio
* encoder.
*/
struct ET_EXPERIMENTAL Audio {
  /// Pre-processed audio payload as bytes.
  /// NOTE(review): element type and layout are not specified here — confirm
  /// against the audio encoder's expected input.
  std::vector<uint8_t> data;
  /// Batch dimension. Zero until the caller sets it.
  int32_t batch_size = 0;
  /// Number of bins. NOTE(review): presumably mel-frequency bins — confirm.
  int32_t n_bins = 0;
  /// Number of frames. NOTE(review): presumably time frames — confirm.
  int32_t n_frames = 0;
};

} // namespace llm
} // namespace extension
} // namespace executorch

namespace torch {
namespace executor {
// TODO(T197294990): Remove these deprecated aliases once all users have moved
// to the new `::executorch` namespaces.
using ::executorch::extension::llm::Audio;
} // namespace executor
} // namespace torch
3 changes: 2 additions & 1 deletion extension/llm/runner/constants.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ inline constexpr auto kUseSDPAWithKVCache = "use_sdpa_with_kv_cache";

// Multimodal method name conventions
inline constexpr auto kImageEncoderMethod = "image_encoder";
inline constexpr auto kAudioEncoderMethod = "audio_encoder";
inline constexpr auto kTokenEmbeddingMethod = "token_embedding";
inline constexpr auto kTextModelMethod = "text_decoder";

} // namespace executorch::extension::llm
161 changes: 153 additions & 8 deletions extension/llm/runner/multimodal_input.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

#pragma once

#include <executorch/extension/llm/runner/audio.h>
#include <executorch/extension/llm/runner/image.h>
#include <executorch/runtime/platform/compiler.h>
#include <string>
Expand All @@ -19,19 +20,31 @@
namespace executorch::extension::llm {

/**
* A generic class to hold either image, text, or audio data for multimodal
* inputs. This allows the generate() API to take a std::vector of these objects
* instead of separate image, text, and audio parameters.
*/
class ET_EXPERIMENTAL MultimodalInput {
public:
/// Type of multimodal input data
enum class Type {
  TEXT, ///< Text string input
  IMAGE, ///< Processed image input
  AUDIO, ///< Processed audio input
  RAW_AUDIO, ///< Raw unprocessed audio input (straight from audio file)
  UNSUPPORTED ///< Unsupported input type
};

// Constructors
explicit MultimodalInput(const std::string& text) : data_(text) {}
explicit MultimodalInput(std::string&& text) : data_(std::move(text)) {}
explicit MultimodalInput(const Image& image) : data_(image) {}
explicit MultimodalInput(Image&& image) : data_(std::move(image)) {}
explicit MultimodalInput(const Audio& audio) : data_(audio) {}
explicit MultimodalInput(Audio&& audio) : data_(std::move(audio)) {}
explicit MultimodalInput(const RawAudio& raw_audio) : data_(raw_audio) {}
explicit MultimodalInput(RawAudio&& raw_audio)
: data_(std::move(raw_audio)) {}

// Copy constructor and assignment
MultimodalInput(const MultimodalInput& other) = default;
Expand Down Expand Up @@ -60,12 +73,37 @@ class ET_EXPERIMENTAL MultimodalInput {
return std::holds_alternative<Image>(data_);
}

/**
 * Report whether the stored payload is pre-processed audio.
 * @return true when the active alternative is Audio, false otherwise.
 */
bool is_audio() const noexcept {
  return std::get_if<Audio>(&data_) != nullptr;
}

/**
 * Report whether the stored payload is raw audio.
 * @return true when the active alternative is RawAudio, false otherwise.
 */
bool is_raw_audio() const noexcept {
  return std::get_if<RawAudio>(&data_) != nullptr;
}

/**
 * Get the type of data stored in this input.
 * @return Type::TEXT if text data, Type::IMAGE if image data, Type::AUDIO if
 * audio data, Type::RAW_AUDIO if raw audio data, and Type::UNSUPPORTED if
 * none of those alternatives is active.
 */
Type get_type() const noexcept {
  if (is_text()) {
    return Type::TEXT;
  }
  if (is_image()) {
    return Type::IMAGE;
  }
  if (is_audio()) {
    return Type::AUDIO;
  }
  if (is_raw_audio()) {
    return Type::RAW_AUDIO;
  }
  // Only reachable when no known alternative is active, e.g. a variant left
  // valueless by an exception.
  return Type::UNSUPPORTED;
}

/**
Expand Down Expand Up @@ -122,6 +160,60 @@ class ET_EXPERIMENTAL MultimodalInput {
return std::get<Image>(std::move(data_));
}

/**
 * Access the stored audio (read-only, lvalue receiver).
 * @return Const reference to the Audio held by this input.
 * @throws std::bad_variant_access if this input doesn't contain audio.
 */
const Audio& get_audio() const& {
  const Audio& held = std::get<Audio>(data_);
  return held;
}

/**
 * Access the stored audio (mutable, lvalue receiver).
 * @return Mutable reference to the Audio held by this input.
 * @throws std::bad_variant_access if this input doesn't contain audio.
 */
Audio& get_audio() & {
  Audio& held = std::get<Audio>(data_);
  return held;
}

/**
 * Access the stored audio on an rvalue receiver so it can be moved out.
 * @return Rvalue reference to the Audio held by this input.
 * @throws std::bad_variant_access if this input doesn't contain audio.
 */
Audio&& get_audio() && {
  Audio& held = std::get<Audio>(data_);
  return std::move(held);
}

/**
 * Access the stored raw audio (read-only, lvalue receiver).
 * @return Const reference to the RawAudio held by this input.
 * @throws std::bad_variant_access if this input doesn't contain raw audio.
 */
const RawAudio& get_raw_audio() const& {
  const RawAudio& held = std::get<RawAudio>(data_);
  return held;
}

/**
 * Access the stored raw audio (mutable, lvalue receiver).
 * @return Mutable reference to the RawAudio held by this input.
 * @throws std::bad_variant_access if this input doesn't contain raw audio.
 */
RawAudio& get_raw_audio() & {
  RawAudio& held = std::get<RawAudio>(data_);
  return held;
}

/**
 * Access the stored raw audio on an rvalue receiver so it can be moved out.
 * @return Rvalue reference to the RawAudio held by this input.
 * @throws std::bad_variant_access if this input doesn't contain raw audio.
 */
RawAudio&& get_raw_audio() && {
  RawAudio& held = std::get<RawAudio>(data_);
  return std::move(held);
}

/**
* Try to get the text data from this input safely.
* @return Pointer to the text string if this input contains text, nullptr
Expand Down Expand Up @@ -158,8 +250,44 @@ class ET_EXPERIMENTAL MultimodalInput {
return std::get_if<Image>(&data_);
}

/**
 * Non-throwing access to the stored audio (read-only).
 * @return Pointer to the Audio object when audio is the active payload,
 * nullptr otherwise.
 */
const Audio* try_get_audio() const noexcept {
  if (!std::holds_alternative<Audio>(data_)) {
    return nullptr;
  }
  // The guard above guarantees std::get cannot throw here.
  return &std::get<Audio>(data_);
}

/**
 * Non-throwing access to the stored audio (mutable).
 * @return Pointer to the Audio object when audio is the active payload,
 * nullptr otherwise.
 */
Audio* try_get_audio() noexcept {
  if (!std::holds_alternative<Audio>(data_)) {
    return nullptr;
  }
  // The guard above guarantees std::get cannot throw here.
  return &std::get<Audio>(data_);
}

/**
 * Non-throwing access to the stored raw audio (read-only).
 * @return Pointer to the RawAudio object when raw audio is the active
 * payload, nullptr otherwise.
 */
const RawAudio* try_get_raw_audio() const noexcept {
  if (!std::holds_alternative<RawAudio>(data_)) {
    return nullptr;
  }
  // The guard above guarantees std::get cannot throw here.
  return &std::get<RawAudio>(data_);
}

/**
 * Non-throwing access to the stored raw audio (mutable).
 * @return Pointer to the RawAudio object when raw audio is the active
 * payload, nullptr otherwise.
 */
RawAudio* try_get_raw_audio() noexcept {
  if (!std::holds_alternative<RawAudio>(data_)) {
    return nullptr;
  }
  // The guard above guarantees std::get cannot throw here.
  return &std::get<RawAudio>(data_);
}

private:
// Payload storage: holds one of the supported input alternatives.
std::variant<std::string, Image, Audio, RawAudio> data_;
};

// Convenience factory functions
Expand All @@ -179,4 +307,21 @@ inline MultimodalInput make_image_input(Image&& image) noexcept {
return MultimodalInput(std::move(image));
}

/// Build a MultimodalInput that holds a copy of `audio`.
/// Declared noexcept, so an exception thrown while copying the audio buffer
/// terminates the program instead of propagating to the caller.
inline MultimodalInput make_audio_input(const Audio& audio) noexcept {
  return MultimodalInput(audio);
}

/// Build a MultimodalInput that takes over the contents of `audio`.
inline MultimodalInput make_audio_input(Audio&& audio) noexcept {
  return MultimodalInput(std::move(audio));
}

/// Build a MultimodalInput that holds a copy of `raw_audio`.
/// Declared noexcept, so an exception thrown while copying the audio buffer
/// terminates the program instead of propagating to the caller.
inline MultimodalInput make_raw_audio_input(
    const RawAudio& raw_audio) noexcept {
  return MultimodalInput(raw_audio);
}

/// Build a MultimodalInput that takes over the contents of `raw_audio`.
inline MultimodalInput make_raw_audio_input(RawAudio&& raw_audio) noexcept {
  return MultimodalInput(std::move(raw_audio));
}

} // namespace executorch::extension::llm
Loading
Loading