diff --git a/examples/models/llava/export_llava.py b/examples/models/llava/export_llava.py index d95bd7fb054..9ba6a510736 100644 --- a/examples/models/llava/export_llava.py +++ b/examples/models/llava/export_llava.py @@ -226,11 +226,11 @@ def export_all(llava_model: LlavaModel): { "image_encoder": image_encoder_ep, "token_embedding": token_embedding_ep, - "text_model": text_model_ep, + "text_decoder": text_model_ep, }, partitioner={ "image_encoder": [XnnpackPartitioner()], - "text_model": [ + "text_decoder": [ # First partition the DQLinear nodes, then partition the rest of the nodes, # to avoid multiple DQLinear nodes in the same partition, # to avoid holding multiple unpacked and packed weight buffers in memory, @@ -254,7 +254,7 @@ def export_all(llava_model: LlavaModel): memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False), sym_shape_eval_pass={ "image_encoder": ConstraintBasedSymShapeEvalPass(), - "text_model": ConstraintBasedSymShapeEvalPass(), + "text_decoder": ConstraintBasedSymShapeEvalPass(), "token_embedding": HintBasedSymShapeEvalPass(), }, ) diff --git a/examples/models/llava/runner/llava_text_decoder_runner.h b/examples/models/llava/runner/llava_text_decoder_runner.h index 09b8e82d49d..cfa92e0c253 100644 --- a/examples/models/llava/runner/llava_text_decoder_runner.h +++ b/examples/models/llava/runner/llava_text_decoder_runner.h @@ -89,7 +89,7 @@ class ET_EXPERIMENTAL LlavaTextDecoderRunner } inline static const std::string kTokenEmbeddingMethod = "token_embedding"; - inline static const std::string kTextModelMethod = "text_model"; + inline static const std::string kTextModelMethod = "text_decoder"; }; } // namespace example diff --git a/examples/models/llava/test/test_llava.py b/examples/models/llava/test/test_llava.py index 05cfd5b1497..def9eaa02bd 100644 --- a/examples/models/llava/test/test_llava.py +++ b/examples/models/llava/test/test_llava.py @@ -96,7 +96,7 @@ def test_llava_export(self): "token_embedding", (prompt_before_image,) )[0] 
llava_module.run_method( - "text_model", + "text_decoder", (torch.tensor([start_pos], dtype=torch.int64), pte_embeds_before_img), ) @@ -107,7 +107,7 @@ def test_llava_export(self): # pte prefill image pte_embeds_img = llava_module.run_method("image_encoder", (resized,))[0] llava_module.run_method( - "text_model", + "text_decoder", ( torch.tensor([start_pos], dtype=torch.int64), pte_embeds_img, @@ -122,7 +122,7 @@ def test_llava_export(self): "token_embedding", (prompt_after_image,) )[0] pte_prefill_after_img = llava_module.run_method( - "text_model", + "text_decoder", (torch.tensor([start_pos], dtype=torch.int64), pte_embeds_after_img), )[0] @@ -139,7 +139,7 @@ def test_llava_export(self): "token_embedding", (torch.tensor([[new_tokens[i]]], dtype=torch.int64),) )[0] logits = llava_module.run_method( - "text_model", + "text_decoder", (torch.tensor([start_pos + i], dtype=torch.int64), token_embeds), )[0] new_tokens.append(torch.argmax(logits).item()) diff --git a/examples/models/llava/test/test_pte.py b/examples/models/llava/test/test_pte.py index f12d72f854b..1f4aaa9938c 100644 --- a/examples/models/llava/test/test_pte.py +++ b/examples/models/llava/test/test_pte.py @@ -47,7 +47,7 @@ def main(): "token_embedding", (prompt_before_image,) )[0] pte_prefill_before_img = llava_module.run_method( - "text_model", + "text_decoder", (torch.tensor([start_pos], dtype=torch.int64), pte_embeds_before_img), )[0] print(pte_prefill_before_img) @@ -60,7 +60,7 @@ def main(): logging.warning("Image encoder finished") logging.warning("Image token prefill started") pte_prefill_img = llava_module.run_method( - "text_model", + "text_decoder", ( torch.tensor([start_pos], dtype=torch.int64), pte_embeds_img, @@ -77,7 +77,7 @@ def main(): "token_embedding", (prompt_after_image,) )[0] pte_prefill_after_img = llava_module.run_method( - "text_model", + "text_decoder", (torch.tensor([start_pos], dtype=torch.int64), pte_embeds_after_img), )[0] logging.warning("Text token prefill finished") @@ 
-91,7 +91,7 @@ def main(): "token_embedding", (torch.tensor([[new_tokens[i]]], dtype=torch.int64),) )[0] logits = llava_module.run_method( - "text_model", + "text_decoder", (torch.tensor([start_pos + i], dtype=torch.int64), token_embeds), )[0] new_tokens.append(torch.argmax(logits[..., -1, :]).item()) diff --git a/examples/models/voxtral/CMakeLists.txt b/examples/models/voxtral/CMakeLists.txt new file mode 100644 index 00000000000..1a5faf3d350 --- /dev/null +++ b/examples/models/voxtral/CMakeLists.txt @@ -0,0 +1,99 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# +# Simple CMake build system for voxtral runner. +# +cmake_minimum_required(VERSION 3.24) +project(voxtral) + +set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../..) + +include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake) + +if(CMAKE_TOOLCHAIN_FILE MATCHES ".*(iOS|ios\.toolchain)\.cmake$") + set(CMAKE_TOOLCHAIN_IOS ON) +else() + set(CMAKE_TOOLCHAIN_IOS OFF) +endif() + +# Let files say "include " +set(_common_include_directories ${EXECUTORCH_ROOT}/..) + +# Need this for gflags for some reason +set(gflags_DIR ${CMAKE_CURRENT_BINARY_DIR}/../../../third-party/gflags) +find_package(gflags REQUIRED) + +# Find `executorch` libraries, same as for gflags +list(APPEND CMAKE_FIND_ROOT_PATH ${CMAKE_CURRENT_BINARY_DIR}/../../..) 
+find_package(executorch CONFIG REQUIRED FIND_ROOT_PATH_BOTH) +executorch_target_link_options_shared_lib(executorch) + +set(LINK_LIBS executorch gflags) +set(link_libraries ${LINK_LIBS}) +set(_srcs multimodal.cpp) + +list( + APPEND + link_libraries + optimized_native_cpu_ops_lib + quantized_ops_lib + custom_ops + cpublas + eigen_blas +) +executorch_target_link_options_shared_lib(optimized_native_cpu_ops_lib) +executorch_target_link_options_shared_lib(quantized_ops_lib) +executorch_target_link_options_shared_lib(custom_ops) + +# XNNPACK +if(TARGET xnnpack_backend) + set(xnnpack_backend_libs xnnpack_backend XNNPACK xnnpack-microkernels-prod) + if(TARGET kleidiai) + list(APPEND xnnpack_backend_libs kleidiai) + endif() + list(APPEND link_libraries ${xnnpack_backend_libs}) + executorch_target_link_options_shared_lib(xnnpack_backend) +endif() + +# Add LLM runner and extension module +if(NOT TARGET extension_llm_runner) + message( + FATAL_ERROR + "ExecuTorch must be installed with EXECUTORCH_BUILD_EXTENSION_LLM_RUNNER enabled." 
+ ) +endif() + +# Needed for cpuinfo where it uses android specific log lib +if(ANDROID) + list(APPEND link_libraries log) +endif() + +# Add the required ExecuTorch extensions for multimodal LLM runner +list( + APPEND + link_libraries + extension_llm_runner + extension_module + extension_data_loader + extension_tensor + extension_flat_tensor +) + +# Add tokenizers +list(APPEND link_libraries tokenizers::tokenizers) + +add_executable(voxtral_runner ${_srcs}) +if(NOT CMAKE_BUILD_TYPE STREQUAL "Debug") + target_link_options_gc_sections(voxtral_runner) + if(NOT APPLE) + target_link_options(voxtral_runner PRIVATE "LINKER:-s") + endif() +endif() + +target_include_directories(voxtral_runner PUBLIC ${_common_include_directories}) +target_link_libraries(voxtral_runner PUBLIC ${link_libraries}) +target_compile_options(voxtral_runner PUBLIC ${_common_compile_options}) diff --git a/examples/models/voxtral/multimodal.cpp b/examples/models/voxtral/multimodal.cpp new file mode 100644 index 00000000000..d7183f3c662 --- /dev/null +++ b/examples/models/voxtral/multimodal.cpp @@ -0,0 +1,217 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include + +#include + +#include +#include +#include +#include +#include +#include +#include + +#if defined(ET_USE_THREADPOOL) +#include +#include +#endif + +DEFINE_string( + model_path, + "multimodal.pte", + "Model serialized in flatbuffer format."); + +DEFINE_string(tokenizer_path, "tekken.json", "Tokenizer stuff."); + +DEFINE_string(prompt, "What is happening in this audio?", "Text prompt."); + +DEFINE_string(audio_path, "", "Path to input audio file."); + +DEFINE_double( + temperature, + 0.8f, + "Temperature; Default is 0.8f. 0 = greedy argmax sampling (deterministic). 
Lower temperature = more deterministic"); + +DEFINE_int32( + cpu_threads, + -1, + "Number of CPU threads for inference. Defaults to -1, which implies we'll use a heuristic to derive the # of performant cores for a specific device."); + +DEFINE_bool(warmup, false, "Whether to run a warmup run."); + +namespace { + +using ::executorch::extension::llm::Image; +using ::executorch::extension::llm::make_image_input; +using ::executorch::extension::llm::make_text_input; +using ::executorch::extension::llm::MultimodalInput; + +bool ends_with(const std::string& str, const std::string& suffix) { + return str.size() >= suffix.size() && + str.compare(str.size() - suffix.size(), suffix.size(), suffix) == 0; +} + +/** + * @brief Loads preprocessed audio data from a binary file + * + * Reads mel spectrogram features that have been pre-computed and saved as a + * binary file. The audio data is expected to be stored as float values in + * binary format, typically saved using: + * with open("tensor.bin", "wb") as f: + * f.write(t.numpy().tobytes()) + * + * @param audio_path Path to the binary audio file (.bin) + * @return MultimodalInput containing the loaded audio data + */ +MultimodalInput loadPreprocessedAudio(const std::string& audio_path) { + std::ifstream f(audio_path, std::ios::binary | std::ios::ate); + int32_t n_bins = 128; + int32_t n_frames = 3000; + std::size_t n_floats = + f.tellg() / sizeof(float); // Number of floats in the audio file. + f.seekg(0, std::ios::beg); + int32_t batch_size = ceil( + n_floats / + (n_bins * n_frames)); // Batch in increments of n_frames, rounding up. 
+ std::vector audio_data(batch_size * n_bins * n_frames); + f.read( + reinterpret_cast(audio_data.data()), + audio_data.size() * sizeof(float)); + + ET_LOG(Info, "audio_data len = %d", audio_data.size()); + + auto audio = std::make_unique<::executorch::extension::llm::Audio>(); + audio->batch_size = batch_size; + audio->n_bins = n_bins; + audio->n_frames = n_frames; + audio->data.resize(audio_data.size() * sizeof(float)); + std::memcpy( + audio->data.data(), audio_data.data(), audio_data.size() * sizeof(float)); + return ::executorch::extension::llm::make_audio_input(std::move(*audio)); +} + +/** + * @brief Processes audio files for multimodal input + * + * Dispatches audio file processing based on file extension: + * - .bin files: Loads preprocessed mel spectrogram features directly + * - .wav/.mp3 files: Currently unsupported, throws runtime_error + * + * This function provides an interface for different audio input formats + * and can be extended to support raw audio processing in the future. + * + * @param audio_path Path to the audio file + * @return MultimodalInput containing the processed audio data + * @throws std::runtime_error if file format is unsupported or processing fails + */ +MultimodalInput processAudioFile(const std::string& audio_path) { + if (ends_with(audio_path, ".bin")) { + // Current behavior - load preprocessed audio stored as a binary file. 
+ return loadPreprocessedAudio(audio_path); + } else if (ends_with(audio_path, ".wav") || ends_with(audio_path, ".mp3")) { + // New: Process raw audio files - unsupported for now + ET_LOG(Error, "Raw audio file processing (.wav/.mp3) is not yet supported"); + throw std::runtime_error("Raw audio file processing not supported"); + } else { + ET_LOG(Error, "Unsupported audio file format: %s", audio_path.c_str()); + throw std::runtime_error("Unsupported audio file format"); + } +} + +} // namespace + +int32_t main(int32_t argc, char** argv) { + gflags::ParseCommandLineFlags(&argc, &argv, true); + + const char* model_path = FLAGS_model_path.c_str(); + + const char* tokenizer_path = FLAGS_tokenizer_path.c_str(); + const char* prompt = FLAGS_prompt.c_str(); + const char* audio_path = FLAGS_audio_path.c_str(); + float temperature = FLAGS_temperature; + int32_t cpu_threads = FLAGS_cpu_threads; + bool warmup = FLAGS_warmup; + +#if defined(ET_USE_THREADPOOL) + uint32_t num_performant_cores = cpu_threads == -1 + ? 
::executorch::extension::cpuinfo::get_num_performant_cores() + : static_cast(cpu_threads); + ET_LOG( + Info, "Resetting threadpool with num threads = %d", num_performant_cores); + if (num_performant_cores > 0) { + ::executorch::extension::threadpool::get_threadpool() + ->_unsafe_reset_threadpool(num_performant_cores); + } +#endif + + // Load tokenizer + std::unique_ptr<::tokenizers::Tokenizer> tokenizer = + ::executorch::extension::llm::load_tokenizer(tokenizer_path); + if (tokenizer == nullptr) { + ET_LOG(Error, "Failed to load tokenizer from: %s", tokenizer_path); + return 1; + } + + // Create multimodal runner + std::unique_ptr<::executorch::extension::llm::MultimodalRunner> runner = + ::executorch::extension::llm::create_multimodal_runner( + model_path, std::move(tokenizer)); + if (runner == nullptr) { + ET_LOG(Error, "Failed to create multimodal runner"); + return 1; + } + + // Load runner + auto load_error = runner->load(); + if (load_error != ::executorch::runtime::Error::Ok) { + ET_LOG(Error, "Failed to load multimodal runner"); + return 1; + } + + // Prepare inputs + std::vector inputs; + + // 1. Add start bos-related text inputs and modality start token. + inputs.emplace_back(make_text_input("[INST][BEGIN_AUDIO]")); + + // 2. Add audio input + inputs.emplace_back(processAudioFile(audio_path)); + + // 3. 
Add text input (the actual user-submitted prompt) + inputs.emplace_back(make_text_input(std::string(prompt) + "[/INST]")); + + ::executorch::extension::llm::GenerationConfig config; + config.max_new_tokens = 100; + config.temperature = temperature; + + // Run warmup if requested + if (warmup) { + ET_LOG(Info, "Running warmup..."); + auto warmup_error = runner->generate(inputs, config); + if (warmup_error != ::executorch::runtime::Error::Ok) { + ET_LOG(Error, "Failed to run warmup"); + return 1; + } + runner->reset(); + } + + // Generate + ET_LOG(Info, "Starting generation..."); + auto error = runner->generate(inputs, config); + if (error != ::executorch::runtime::Error::Ok) { + ET_LOG(Error, "Failed to generate with multimodal runner"); + return 1; + } + + printf("\n"); + return 0; +} diff --git a/extension/llm/runner/audio.h b/extension/llm/runner/audio.h new file mode 100644 index 00000000000..868765950af --- /dev/null +++ b/extension/llm/runner/audio.h @@ -0,0 +1,52 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +// A simple audio struct. + +#pragma once +#include +#include +#include + +namespace executorch { +namespace extension { +namespace llm { + +/** + * Audio inputs as a raw audio tensor, for use when the audio processing + * into a mel spectrogram is baked into the audio encoder with torch.export. + */ +struct ET_EXPERIMENTAL RawAudio { + std::vector data; + int32_t batch_size; + int32_t n_channels; // For mono, use n_channels = 1. + int32_t n_samples; +}; + +/** + * Pre-processed audio inputs, ready to feed directly into an audio + * encoder. 
+ */ +struct ET_EXPERIMENTAL Audio { + std::vector data; + int32_t batch_size; + int32_t n_bins; + int32_t n_frames; +}; + +} // namespace llm +} // namespace extension +} // namespace executorch + +namespace torch { +namespace executor { +// TODO(T197294990): Remove these deprecated aliases once all users have moved +// to the new `::executorch` namespaces. +using ::executorch::extension::llm::Audio; +} // namespace executor +} // namespace torch diff --git a/extension/llm/runner/constants.h b/extension/llm/runner/constants.h index fc6ddcb451c..4ba88203c50 100644 --- a/extension/llm/runner/constants.h +++ b/extension/llm/runner/constants.h @@ -21,7 +21,8 @@ inline constexpr auto kUseSDPAWithKVCache = "use_sdpa_with_kv_cache"; // Multimodal method name conventions inline constexpr auto kImageEncoderMethod = "image_encoder"; +inline constexpr auto kAudioEncoderMethod = "audio_encoder"; inline constexpr auto kTokenEmbeddingMethod = "token_embedding"; -inline constexpr auto kTextModelMethod = "text_model"; +inline constexpr auto kTextModelMethod = "text_decoder"; } // namespace executorch::extension::llm diff --git a/extension/llm/runner/llm_runner_helper.cpp b/extension/llm/runner/llm_runner_helper.cpp index ec2e335b7d6..200dba695da 100644 --- a/extension/llm/runner/llm_runner_helper.cpp +++ b/extension/llm/runner/llm_runner_helper.cpp @@ -17,10 +17,12 @@ #include #include #include +#include #include #include #include #include +#include #include namespace executorch::extension::llm { @@ -35,6 +37,18 @@ std::unique_ptr load_tokenizer( size_t bos_token_index, size_t eos_token_index) { runtime::runtime_init(); + auto tekken_tokenizer = std::make_unique(); + // Prevent the case where tekken tokenizer accidentally successfully loads a + // HuggingFace tokenizer, which is also .json. 
+ const std::string tekken_name = "tekken.json"; + if (tokenizer_path.size() >= tekken_name.size() && + tokenizer_path.rfind(tekken_name) == + tokenizer_path.size() - tekken_name.size()) { + if (tekken_tokenizer->load(tokenizer_path) == ::tokenizers::Error::Ok) { + ET_LOG(Info, "Loaded tekken tokenizer"); + return tekken_tokenizer; + } + } auto json_tokenizer = std::make_unique(); if (json_tokenizer->load(tokenizer_path) == ::tokenizers::Error::Ok) { ET_LOG(Info, "Loaded json tokenizer"); @@ -73,9 +87,8 @@ std::unique_ptr load_tokenizer( return nullptr; } -std::unordered_map get_llm_metadata( - tokenizers::Tokenizer* tokenizer, - Module* module) { +::executorch::runtime::Result> +get_llm_metadata(tokenizers::Tokenizer* tokenizer, Module* module) { // Initialize metadata with default values std::unordered_map metadata({ {llm::kEnableDynamicShape, false}, @@ -89,10 +102,20 @@ std::unordered_map get_llm_metadata( auto method_names_result = module->method_names(); if (method_names_result.error() != Error::Ok) { ET_LOG(Error, "Failed reading method names"); - return metadata; + return ::executorch::runtime::Error::InvalidArgument; } const auto& method_names = method_names_result.get(); + // Error out if the max seq len metadata method is not present, since + // it is hard to figure out from just the .pte itself. + if (!method_names.count(llm::kMaxSeqLen)) { + ET_LOG( + Error, + "Required metadata method %s not found in model", + llm::kMaxSeqLen); + return ::executorch::runtime::Error::InvalidArgument; + } + for (auto& pair : metadata) { const auto& method_name = pair.first; auto& value = pair.second; @@ -109,6 +132,18 @@ std::unordered_map get_llm_metadata( } ET_LOG(Info, "Metadata: %s = %" PRId64, method_name.c_str(), value); } + + // If kMaxContextLen method not found but kMaxSeqLen is + // available, set kMaxContextLen to the value of kMaxSeqLen. 
+ if (!method_names.count(llm::kMaxContextLen) && + method_names.count(llm::kMaxSeqLen)) { + metadata[llm::kMaxContextLen] = metadata[llm::kMaxSeqLen]; + ET_LOG( + Info, + "Setting kMaxContextLen to kMaxSeqLen value: %" PRId64, + metadata[llm::kMaxContextLen]); + } + // Set tokenizer-related metadata metadata[llm::kBosId] = tokenizer->bos_tok(); metadata[llm::kVocabSize] = tokenizer->vocab_size(); @@ -165,7 +200,12 @@ std::unique_ptr create_text_llm_runner( // Get metadata from Module ET_LOG(Info, "Reading metadata from model"); - auto metadata = llm::get_llm_metadata(tokenizer.get(), module.get()); + auto metadata_result = llm::get_llm_metadata(tokenizer.get(), module.get()); + if (metadata_result.error() != Error::Ok) { + ET_LOG(Error, "Failed to get metadata from model"); + return nullptr; + } + auto metadata = metadata_result.get(); auto eos_ids = std::make_unique>( llm::get_eos_ids(tokenizer.get(), module.get())); @@ -228,7 +268,12 @@ std::unique_ptr create_multimodal_runner( // Get metadata from Module ET_LOG(Info, "Reading metadata from model"); - auto metadata = get_llm_metadata(tokenizer.get(), module.get()); + auto metadata_result = get_llm_metadata(tokenizer.get(), module.get()); + if (metadata_result.error() != Error::Ok) { + ET_LOG(Error, "Failed to get metadata from model"); + return nullptr; + } + auto metadata = metadata_result.get(); auto eos_ids = std::make_unique>( get_eos_ids(tokenizer.get(), module.get())); diff --git a/extension/llm/runner/llm_runner_helper.h b/extension/llm/runner/llm_runner_helper.h index 5ca96b3bb96..191ea3ab090 100644 --- a/extension/llm/runner/llm_runner_helper.h +++ b/extension/llm/runner/llm_runner_helper.h @@ -19,6 +19,7 @@ #include #include +#include #include #include @@ -59,11 +60,13 @@ ET_EXPERIMENTAL std::unique_ptr load_tokenizer( * * @param tokenizer Initialized tokenizer instance * @param module The model module - * @return std::unordered_map Metadata key-value pairs + * @return Result> Metadata key-value + * 
pairs on success, or Error::InvalidArgument if required metadata (e.g., + * kMaxSeqLen) is missing from the model */ -ET_EXPERIMENTAL std::unordered_map get_llm_metadata( - tokenizers::Tokenizer* tokenizer, - Module* module); +ET_EXPERIMENTAL ::executorch::runtime::Result< + std::unordered_map> +get_llm_metadata(tokenizers::Tokenizer* tokenizer, Module* module); /** * @brief Gets EOS token IDs from the model and tokenizer diff --git a/extension/llm/runner/multimodal_input.h b/extension/llm/runner/multimodal_input.h index ae243992fec..728d8aef08f 100644 --- a/extension/llm/runner/multimodal_input.h +++ b/extension/llm/runner/multimodal_input.h @@ -11,6 +11,7 @@ #pragma once +#include #include #include #include @@ -19,19 +20,31 @@ namespace executorch::extension::llm { /** - * A generic class to hold either image or text data for multimodal inputs. - * This allows the generate() API to take a std::vector of these objects - * instead of separate image and text parameters. + * A generic class to hold either image, text, or audio data for multimodal + * inputs. This allows the generate() API to take a std::vector of these objects + * instead of separate image, text, and audio parameters. 
*/ class ET_EXPERIMENTAL MultimodalInput { public: - enum class Type { TEXT, IMAGE }; + /// Type of multimodal input data + enum class Type { + TEXT, ///< Text string input + IMAGE, ///< Processed image input + AUDIO, ///< Processed audio input + RAW_AUDIO, ///< Raw unprocessed audio input (straight from audio file) + UNSUPPORTED ///< Unsupported input type + }; // Constructors explicit MultimodalInput(const std::string& text) : data_(text) {} explicit MultimodalInput(std::string&& text) : data_(std::move(text)) {} explicit MultimodalInput(const Image& image) : data_(image) {} explicit MultimodalInput(Image&& image) : data_(std::move(image)) {} + explicit MultimodalInput(const Audio& audio) : data_(audio) {} + explicit MultimodalInput(Audio&& audio) : data_(std::move(audio)) {} + explicit MultimodalInput(const RawAudio& raw_audio) : data_(raw_audio) {} + explicit MultimodalInput(RawAudio&& raw_audio) + : data_(std::move(raw_audio)) {} // Copy constructor and assignment MultimodalInput(const MultimodalInput& other) = default; @@ -60,12 +73,37 @@ class ET_EXPERIMENTAL MultimodalInput { return std::holds_alternative(data_); } + /** + * Check if this input contains audio data. + * @return true if this input contains audio, false otherwise. + */ + bool is_audio() const noexcept { + return std::holds_alternative