diff --git a/examples/models/voxtral/README.md b/examples/models/voxtral/README.md index 8cac4264bba..4e9ddcf34a4 100644 --- a/examples/models/voxtral/README.md +++ b/examples/models/voxtral/README.md @@ -41,8 +41,8 @@ To run the model, we will use the Voxtral runner, which utilizes ExecuTorch's Mu The Voxtral runner will do the following things: - Audio Input: - - Option A: Pass the raw audio tensor into exported preprocessor to produce a mel spectrogram tensor. - - Option B: If starting directly with an already processed audio input tensor, format the inputs to the multimodal runner (metadata tokens, audio tokens, text tokens, etc.). + - Option A: Pass raw audio data from a `.wav` file into the exported preprocessor to produce a mel spectrogram tensor. + - Option B: If starting directly with an already processed audio input tensor (preprocessed mel spectrogram), format the inputs to the multimodal runner (metadata tokens, audio tokens, text tokens, etc.). - Feed the formatted inputs to the multimodal modal runner. @@ -66,13 +66,26 @@ cmake -DCMAKE_INSTALL_PREFIX=cmake-out -DBUILD_TESTING=OFF -DCMAKE_BUILD_TYPE=Re ## Running the model You can download the `tekken.json` tokenizer from [Voxtral's HuggingFace repo](https://huggingface.co/mistralai/Voxtral-Mini-3B-2507). + +### Running with raw audio (.wav file) +For raw audio files (`.wav`), you must provide a preprocessor to convert the audio into mel spectrogram format: +``` +./cmake-out/examples/models/voxtral/voxtral_runner \ + --model_path path/to/model.pte \ + --tokenizer_path path/to/tekken.json \ + --prompt "What can you tell me about this audio?" \ + --audio_path path/to/audio_input.wav \ + --processor_path path/to/voxtral_preprocessor.pte +``` + +### Running with preprocessed audio (.bin file) +If you already have a preprocessed mel spectrogram saved as a `.bin` file, you can skip the preprocessor: ``` ./cmake-out/examples/models/voxtral/voxtral_runner \ --model_path path/to/model.pte \ --tokenizer_path path/to/tekken.json \ --prompt "What can you tell me about this audio?" \ - --audio_path path/to/audio_input.bin \ - --processor_path path/to/voxtral_preprocessor.pte # If you're passing raw audio file in audio_path + --audio_path path/to/preprocessed_audio.bin ``` Example output: diff --git a/examples/models/voxtral/multimodal.cpp b/examples/models/voxtral/multimodal.cpp index 081df27cd67..b3dd5e3ab68 100644 --- a/examples/models/voxtral/multimodal.cpp +++ b/examples/models/voxtral/multimodal.cpp @@ -21,6 +21,7 @@ #include #include #include +#include #include #include @@ -34,6 +35,7 @@ DEFINE_string( "multimodal.pte", "Model serialized in flatbuffer format."); +DEFINE_string(data_path, "", "Path to data file."); DEFINE_string(tokenizer_path, "tekken.json", "Tokenizer stuff."); DEFINE_string(prompt, "What is happening in this audio?", "Text prompt."); @@ -113,15 +115,15 @@ MultimodalInput loadPreprocessedAudio(const std::string& audio_path) { } /** - * @brief Loads a .bin file into a tensor and processes it using a .pte - * processor + * @brief Loads raw audio from a .bin or .wav file and processes it using a + * .pte processor * - * This function loads raw audio data from a .bin file (similar to - * loadPreprocessedAudio), creates a tensor from it, and then passes it through - * a processor module loaded from a .pte file to generate processed audio - * features. + * This function loads raw audio data from either a .bin file (raw float array) + * or a .wav file (WAV format with headers), creates a tensor from it, and then + * passes it through a processor module loaded from a .pte file to generate + * processed audio features. * - * @param audio_path Path to the .bin audio file + * @param audio_path Path to the .bin or .wav audio file * @param processor_path Path to the .pte processor file * @return MultimodalInput containing the processed audio data * @throws std::runtime_error if file loading or processing fails @@ -135,6 +137,41 @@ MultimodalInput processRawAudioFile( "Processor path is required for raw audio processing"); } + // Load the audio data from file (.bin or .wav) + std::vector audio_data; + if (ends_with(audio_path, ".wav")) { + audio_data = ::executorch::extension::llm::load_wav_audio_data(audio_path); + ET_LOG( + Info, + "Loaded WAV file: %s, %zu samples", + audio_path.c_str(), + audio_data.size()); + } else if (ends_with(audio_path, ".bin")) { + std::ifstream f(audio_path, std::ios::binary | std::ios::ate); + if (!f.is_open()) { + ET_LOG(Error, "Failed to open audio file: %s", audio_path.c_str()); + throw std::runtime_error("Failed to open audio file"); + } + + std::size_t n_floats = f.tellg() / sizeof(float); + f.seekg(0, std::ios::beg); + + audio_data.resize(n_floats); + f.read( + reinterpret_cast(audio_data.data()), + audio_data.size() * sizeof(float)); + f.close(); + + ET_LOG( + Info, "Loaded .bin file: %s, %zu floats", audio_path.c_str(), n_floats); + } else { + ET_LOG( + Error, + "Unsupported audio file format: %s (only .bin and .wav files are supported)", + audio_path.c_str()); + throw std::runtime_error("Unsupported audio file format"); + } + // Load the audio processor .pte. std::unique_ptr processor_module; try { @@ -153,25 +190,6 @@ MultimodalInput processRawAudioFile( throw std::runtime_error("Exception while loading processor module"); } - // Load the audio data from file. - std::ifstream f(audio_path, std::ios::binary | std::ios::ate); - if (!f.is_open()) { - ET_LOG(Error, "Failed to open audio file: %s", audio_path.c_str()); - throw std::runtime_error("Failed to open audio file"); - } - - std::size_t n_floats = f.tellg() / sizeof(float); - f.seekg(0, std::ios::beg); - - std::vector audio_data(n_floats); - f.read( - reinterpret_cast(audio_data.data()), - audio_data.size() * sizeof(float)); - f.close(); - - ET_LOG( - Info, "Loaded .bin file: %s, %zu floats", audio_path.c_str(), n_floats); - // Execute the processor std::vector tensor_shape = { static_cast(audio_data.size())}; @@ -226,33 +244,39 @@ MultimodalInput processRawAudioFile( * * Dispatches audio file processing based on file extension and processor * availability: + * - .wav files: Requires processor, processes raw audio through processor * - .bin files with processor: Loads raw audio from .bin and processes through * processor * - .bin files without processor: Loads preprocessed mel spectrogram features * directly * - * @param audio_path Path to the audio file (.bin) - * @param processor_path Path to the processor .pte file (optional) + * @param audio_path Path to the audio file (.bin or .wav) + * @param processor_path Path to the processor .pte file (optional for .bin, + * required for .wav) * @return MultimodalInput containing the processed audio data * @throws std::runtime_error if file format is unsupported or processing fails */ MultimodalInput processAudioFile( const std::string& audio_path, const std::string& processor_path = "") { - if (ends_with(audio_path, ".bin")) { - if (!processor_path.empty()) { - // Process raw audio from .bin file through the processor - return processRawAudioFile(audio_path, processor_path); - } else { - // Load preprocessed audio stored as a binary file (existing behavior) - return loadPreprocessedAudio(audio_path); + if (ends_with(audio_path, ".wav") || ends_with(audio_path, ".bin")) { + if (processor_path.empty()) { + if (ends_with(audio_path, ".wav")) { + ET_CHECK_MSG( + false, + "Processor path is required for .wav file processing: %s", + audio_path.c_str()); + } else { + // Load preprocessed audio stored as a binary file (existing behavior) + return loadPreprocessedAudio(audio_path); + } } + return processRawAudioFile(audio_path, processor_path); } else { - ET_LOG( - Error, - "Unsupported audio file format: %s (only .bin files are supported)", + ET_CHECK_MSG( + false, + "Unsupported audio file format: %s (only .bin and .wav files are supported)", audio_path.c_str()); - throw std::runtime_error("Unsupported audio file format"); } } @@ -267,6 +291,7 @@ int32_t main(int32_t argc, char** argv) { const char* prompt = FLAGS_prompt.c_str(); const char* audio_path = FLAGS_audio_path.c_str(); const char* processor_path = FLAGS_processor_path.c_str(); + const char* data_path = FLAGS_data_path.c_str(); float temperature = FLAGS_temperature; int32_t cpu_threads = FLAGS_cpu_threads; bool warmup = FLAGS_warmup; @@ -294,7 +319,7 @@ int32_t main(int32_t argc, char** argv) { // Create multimodal runner std::unique_ptr<::executorch::extension::llm::MultimodalRunner> runner = ::executorch::extension::llm::create_multimodal_runner( - model_path, std::move(tokenizer)); + model_path, std::move(tokenizer), data_path); if (runner == nullptr) { ET_LOG(Error, "Failed to create multimodal runner"); return 1; diff --git a/extension/llm/runner/targets.bzl b/extension/llm/runner/targets.bzl index 242860a195a..e001e8fc154 100644 --- a/extension/llm/runner/targets.bzl +++ b/extension/llm/runner/targets.bzl @@ -105,6 +105,7 @@ def define_common_targets(): exported_headers = [ "audio.h", "image.h", + "wav_loader.h", "multimodal_input.h", "multimodal_runner.h", "multimodal_prefiller.h", diff --git a/extension/llm/runner/test/CMakeLists.txt b/extension/llm/runner/test/CMakeLists.txt index 2aa18000831..934a5797da1 100644 --- a/extension/llm/runner/test/CMakeLists.txt +++ b/extension/llm/runner/test/CMakeLists.txt @@ -19,7 +19,7 @@ include(${EXECUTORCH_ROOT}/tools/cmake/Test.cmake) set(_test_srcs test_generation_config.cpp test_text_llm_runner.cpp test_text_prefiller.cpp - test_text_decoder_runner.cpp test_multimodal_input.cpp + test_text_decoder_runner.cpp test_multimodal_input.cpp test_wav_loader.cpp ) # Add LSan stub for Apple platforms diff --git a/extension/llm/runner/test/targets.bzl b/extension/llm/runner/test/targets.bzl index 3339b3b8584..0571b39ccdb 100644 --- a/extension/llm/runner/test/targets.bzl +++ b/extension/llm/runner/test/targets.bzl @@ -44,3 +44,13 @@ def define_common_targets(): "//executorch/extension/llm/runner:multimodal_runner_lib", ], ) + + runtime.cxx_test( + name = "test_wav_loader", + srcs = ["test_wav_loader.cpp"], + deps = [ + "//executorch/extension/testing_util:temp_file", + "//executorch/extension/llm/runner:multimodal_runner_lib", + "//executorch/runtime/platform:platform", + ], + ) diff --git a/extension/llm/runner/test/test_wav_loader.cpp b/extension/llm/runner/test/test_wav_loader.cpp new file mode 100644 index 00000000000..bc3ac0ff324 --- /dev/null +++ b/extension/llm/runner/test/test_wav_loader.cpp @@ -0,0 +1,155 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include + +#include +#include +#include +#include + +#include + +using executorch::extension::llm::kOneOverIntMax; +using executorch::extension::llm::kOneOverShortMax; +using executorch::extension::llm::load_wav_audio_data; +using executorch::extension::llm::load_wav_header; +using executorch::extension::llm::WavHeader; +using executorch::extension::testing::TempFile; + +namespace { + +// Test fixture to ensure PAL initialization +class WavLoaderTest : public ::testing::Test { + protected: + void SetUp() override { + // Ensure PAL is initialized before tests run + executorch::runtime::runtime_init(); + } +}; + +void append_bytes(std::vector& out, const char* literal) { + out.insert(out.end(), literal, literal + 4); +} + +void append_le16(std::vector& out, uint16_t value) { + out.push_back(static_cast(value & 0xFF)); + out.push_back(static_cast((value >> 8) & 0xFF)); +} + +void append_le32(std::vector& out, uint32_t value) { + out.push_back(static_cast(value & 0xFF)); + out.push_back(static_cast((value >> 8) & 0xFF)); + out.push_back(static_cast((value >> 16) & 0xFF)); + out.push_back(static_cast((value >> 24) & 0xFF)); +} + +std::vector make_pcm_wav_bytes( + int bits_per_sample, + const std::vector& samples, + uint16_t num_channels = 1, + uint32_t sample_rate = 16000) { + const size_t bytes_per_sample = static_cast(bits_per_sample / 8); + const uint32_t subchunk2_size = + static_cast(samples.size() * bytes_per_sample); + const uint32_t byte_rate = sample_rate * num_channels * bytes_per_sample; + const uint16_t block_align = num_channels * bytes_per_sample; + const uint32_t chunk_size = 36 + subchunk2_size; + + std::vector bytes; + bytes.reserve(44 + subchunk2_size); + + append_bytes(bytes, "RIFF"); + append_le32(bytes, chunk_size); + append_bytes(bytes, "WAVE"); + append_bytes(bytes, "fmt "); + append_le32(bytes, 16); // PCM + append_le16(bytes, 1); // AudioFormat PCM + append_le16(bytes, num_channels); + append_le32(bytes, sample_rate); + append_le32(bytes, byte_rate); + append_le16(bytes, block_align); + append_le16(bytes, static_cast(bits_per_sample)); + append_bytes(bytes, "data"); + append_le32(bytes, subchunk2_size); + + for (int32_t sample : samples) { + const uint32_t encoded = + static_cast(static_cast(sample)); + for (size_t byte_idx = 0; byte_idx < bytes_per_sample; ++byte_idx) { + bytes.push_back(static_cast((encoded >> (8 * byte_idx)) & 0xFF)); + } + } + + return bytes; +} + +} // namespace + +TEST_F(WavLoaderTest, LoadHeaderParsesPcmMetadata) { + const std::vector wav_bytes = + make_pcm_wav_bytes(16, {0, 32767, -32768}); + TempFile file(wav_bytes.data(), wav_bytes.size()); + + std::unique_ptr header = load_wav_header(file.path()); + ASSERT_NE(header, nullptr); + + EXPECT_EQ(header->AudioFormat, 1); + EXPECT_EQ(header->NumOfChan, 1); + EXPECT_EQ(header->SamplesPerSec, 16000); + EXPECT_EQ(header->bitsPerSample, 16); + EXPECT_EQ(header->blockAlign, 2); + EXPECT_EQ(header->bytesPerSec, 32000); + EXPECT_EQ(header->dataOffset, 44); + EXPECT_EQ(header->Subchunk2Size, 6); +} + +TEST_F(WavLoaderTest, LoadAudioData16BitNormalizesSamples) { + const std::vector samples = {0, 32767, -32768}; + const std::vector wav_bytes = make_pcm_wav_bytes(16, samples); + TempFile file(wav_bytes.data(), wav_bytes.size()); + + std::vector audio = load_wav_audio_data(file.path()); + ASSERT_EQ(audio.size(), samples.size()); + + EXPECT_NEAR(audio[0], 0.0f, 1e-6f); + EXPECT_NEAR(audio[1], 32767.0f * kOneOverShortMax, 1e-6f); + EXPECT_NEAR(audio[2], -32768.0f * kOneOverShortMax, 1e-6f); +} + +TEST_F(WavLoaderTest, LoadAudioData32BitNormalizesSamples) { + const std::vector samples = { + 0, + std::numeric_limits::max(), + std::numeric_limits::min()}; + const std::vector wav_bytes = make_pcm_wav_bytes(32, samples); + TempFile file(wav_bytes.data(), wav_bytes.size()); + + std::vector audio = load_wav_audio_data(file.path()); + ASSERT_EQ(audio.size(), samples.size()); + + EXPECT_NEAR(audio[0], 0.0f, 1e-8f); + EXPECT_NEAR( + audio[1], + static_cast(static_cast(samples[1]) * kOneOverIntMax), + 1e-6f); + EXPECT_NEAR( + audio[2], + static_cast(static_cast(samples[2]) * kOneOverIntMax), + 1e-6f); +} + +TEST_F(WavLoaderTest, LoadHeaderReturnsNullWhenMagicMissing) { + const std::string bogus_contents = "not a wav file"; + TempFile file(bogus_contents); + + std::unique_ptr header = load_wav_header(file.path()); + EXPECT_EQ(header, nullptr); +} diff --git a/extension/llm/runner/wav_loader.h b/extension/llm/runner/wav_loader.h new file mode 100644 index 00000000000..f49a4d1723e --- /dev/null +++ b/extension/llm/runner/wav_loader.h @@ -0,0 +1,210 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +// A simple WAV file loader. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace executorch::extension::llm { + +constexpr float kOneOverIntMax = 1 / static_cast(INT32_MAX); +constexpr float kOneOverShortMax = 1 / static_cast(INT16_MAX); + +struct WavHeader { + /* RIFF Chunk Descriptor */ + uint8_t RIFF[4]; + uint32_t ChunkSize; + uint8_t WAVE[4]; + /* "fmt" sub-chunk */ + uint8_t fmt[4]; + uint32_t Subchunk1Size; + uint16_t AudioFormat; + uint16_t NumOfChan; + uint32_t SamplesPerSec; + uint32_t bytesPerSec; + uint16_t blockAlign; + uint16_t bitsPerSample; + /* "data" sub-chunk */ + uint32_t dataOffset; + uint32_t Subchunk2Size; +}; + +inline std::unique_ptr load_wav_header(const std::string& fp) { + std::ifstream file(fp, std::ios::binary); + if (!file.is_open()) { + ET_CHECK_MSG(false, "Failed to open WAV file: %s", fp.c_str()); + } + + file.seekg(0, std::ios::end); + size_t file_size = file.tellg(); + file.seekg(0, std::ios::beg); + + std::vector buffer(file_size); + file.read(buffer.data(), file_size); + file.close(); + + const char* data = buffer.data(); + size_t data_size = buffer.size(); + + bool has_riff = false; + bool has_wave = false; + + if (data_size >= 4 && std::memcmp(data, "RIFF", 4) == 0) { + has_riff = true; + } + + if (data_size >= 12 && std::memcmp(data + 8, "WAVE", 4) == 0) { + has_wave = true; + } + + bool is_wav_file = has_riff && has_wave; + std::unique_ptr header; + + if (is_wav_file) { + header = std::make_unique(); + size_t default_header_size = sizeof(WavHeader); + + size_t data_offset = 0; + for (size_t i = 0; i + 4 < data_size; i++) { + if (std::memcmp(data + i, "data", 4) == 0) { + data_offset = i; + break; + } + } + + if (data_size >= default_header_size) { + std::memcpy( + reinterpret_cast(header.get()), data, default_header_size); + + ET_LOG(Info, "WAV header detected, getting raw audio data."); + ET_LOG( + Info, + "RIFF Header: %c%c%c%c", + header->RIFF[0], + header->RIFF[1], + header->RIFF[2], + header->RIFF[3]); + ET_LOG(Info, "Chunk Size: %d", header->ChunkSize); + ET_LOG( + Info, + "WAVE Header: %c%c%c%c", + header->WAVE[0], + header->WAVE[1], + header->WAVE[2], + header->WAVE[3]); + ET_LOG( + Info, + "Format Header: %c%c%c%c", + header->fmt[0], + header->fmt[1], + header->fmt[2], + header->fmt[3]); + ET_LOG(Info, "Format Chunk Size: %d", header->Subchunk1Size); + ET_LOG(Info, "Audio Format: %d", header->AudioFormat); + ET_LOG(Info, "Number of Channels: %d", header->NumOfChan); + ET_LOG(Info, "Sample Rate: %d", header->SamplesPerSec); + ET_LOG(Info, "Byte Rate: %d", header->bytesPerSec); + ET_LOG(Info, "Block Align: %d", header->blockAlign); + ET_LOG(Info, "Bits per Sample: %d", header->bitsPerSample); + + if (data_offset != 0) { + header->Subchunk2Size = + *reinterpret_cast(data + data_offset + 4); + ET_LOG(Info, "Subchunk2Size: %d", header->Subchunk2Size); + header->dataOffset = static_cast(data_offset + 8); + } else { + ET_LOG( + Error, + "WAV file structure is invalid, missing Subchunk2ID 'data' field."); + throw std::runtime_error("Invalid WAV file structure"); + } + } else { + ET_CHECK_MSG( + false, + "WAV header detected but file is too small to contain a complete header"); + } + } + + return header; +} + +inline std::vector load_wav_audio_data(const std::string& fp) { + std::ifstream file(fp, std::ios::binary); + if (!file.is_open()) { + ET_CHECK_MSG(false, "Failed to open WAV file: %s", fp.c_str()); + } + + file.seekg(0, std::ios::end); + size_t file_size = file.tellg(); + file.seekg(0, std::ios::beg); + + std::vector buffer(file_size); + file.read(buffer.data(), file_size); + file.close(); + + auto header = load_wav_header(fp); + + if (header.get() == nullptr) { + ET_CHECK_MSG(false, "WAV header not detected in file: %s", fp.c_str()); + } + + const char* data = buffer.data(); + size_t data_offset = header->dataOffset; + size_t data_size = header->Subchunk2Size; + int bits_per_sample = header->bitsPerSample; + + std::vector audio_data; + + if (bits_per_sample == 32) { + size_t num_samples = data_size / 4; + audio_data.resize(num_samples); + const int32_t* input_buffer = + reinterpret_cast(data + data_offset); + + for (size_t i = 0; i < num_samples; ++i) { + audio_data[i] = static_cast( + static_cast(input_buffer[i]) * kOneOverIntMax); + } + } else if (bits_per_sample == 16) { + size_t num_samples = data_size / 2; + audio_data.resize(num_samples); + const int16_t* input_buffer = + reinterpret_cast(data + data_offset); + + for (size_t i = 0; i < num_samples; ++i) { + audio_data[i] = static_cast( + static_cast(input_buffer[i]) * kOneOverShortMax); + } + } else { + ET_CHECK_MSG( + false, + "Unsupported bits per sample: %d. Only support 32 and 16.", + bits_per_sample); + } + + ET_LOG( + Info, + "Loaded %zu audio samples from WAV file: %s", + audio_data.size(), + fp.c_str()); + + return audio_data; +} + +} // namespace executorch::extension::llm diff --git a/extension/testing_util/targets.bzl b/extension/testing_util/targets.bzl index 05b825645e8..a5ad1fb9b8c 100644 --- a/extension/testing_util/targets.bzl +++ b/extension/testing_util/targets.bzl @@ -14,6 +14,7 @@ def define_common_targets(): visibility = [ "//executorch/devtools/etdump/tests/...", "//executorch/extension/data_loader/test/...", + "//executorch/extension/llm/runner/test/...", "//executorch/extension/testing_util/test/...", "//executorch/extension/fb/ptez/decompression_methods/test/...", "//executorch/extension/fb/ptez/test/...",