diff --git a/extension/llm/runner/test/test_wav_loader.cpp b/extension/llm/runner/test/test_wav_loader.cpp index bc3ac0ff324..0ec19cc2758 100644 --- a/extension/llm/runner/test/test_wav_loader.cpp +++ b/extension/llm/runner/test/test_wav_loader.cpp @@ -19,6 +19,7 @@ using executorch::extension::llm::kOneOverIntMax; using executorch::extension::llm::kOneOverShortMax; +using executorch::extension::llm::kWavFormatIeeeFloat; using executorch::extension::llm::load_wav_audio_data; using executorch::extension::llm::load_wav_header; using executorch::extension::llm::WavHeader; @@ -26,6 +27,10 @@ using executorch::extension::testing::TempFile; namespace { +// WAV file format constants +constexpr uint32_t kWavHeaderSizeBeforeData = 36; +constexpr uint32_t kWavHeaderSizeWithData = 44; + // Test fixture to ensure PAL initialization class WavLoaderTest : public ::testing::Test { protected: @@ -51,20 +56,27 @@ void append_le32(std::vector& out, uint32_t value) { out.push_back(static_cast((value >> 24) & 0xFF)); } +void append_float(std::vector& out, float value) { + const auto* bytes = reinterpret_cast(&value); + for (size_t i = 0; i < sizeof(float); ++i) { + out.push_back(bytes[i]); + } +} + std::vector make_pcm_wav_bytes( int bits_per_sample, const std::vector& samples, uint16_t num_channels = 1, uint32_t sample_rate = 16000) { - const size_t bytes_per_sample = static_cast(bits_per_sample / 8); - const uint32_t subchunk2_size = + const auto bytes_per_sample = static_cast(bits_per_sample / 8); + const auto subchunk2_size = static_cast(samples.size() * bytes_per_sample); const uint32_t byte_rate = sample_rate * num_channels * bytes_per_sample; const uint16_t block_align = num_channels * bytes_per_sample; - const uint32_t chunk_size = 36 + subchunk2_size; + const auto chunk_size = kWavHeaderSizeBeforeData + subchunk2_size; std::vector bytes; - bytes.reserve(44 + subchunk2_size); + bytes.reserve(kWavHeaderSizeWithData + subchunk2_size); append_bytes(bytes, "RIFF"); append_le32(bytes, chunk_size); @@ -91,6 +103,75 @@ std::vector make_pcm_wav_bytes( return bytes; } +std::vector make_float_wav_bytes( + const std::vector& samples, + uint16_t num_channels = 1, + uint32_t sample_rate = 16000) { + const auto bytes_per_sample = sizeof(float); + const auto subchunk2_size = + static_cast(samples.size() * bytes_per_sample); + const uint32_t byte_rate = sample_rate * num_channels * bytes_per_sample; + const uint16_t block_align = num_channels * bytes_per_sample; + const auto chunk_size = kWavHeaderSizeBeforeData + subchunk2_size; + + std::vector bytes; + bytes.reserve(kWavHeaderSizeWithData + subchunk2_size); + + append_bytes(bytes, "RIFF"); + append_le32(bytes, chunk_size); + append_bytes(bytes, "WAVE"); + append_bytes(bytes, "fmt "); + append_le32(bytes, 16); + append_le16(bytes, 3); // AudioFormat IEEE Float + append_le16(bytes, num_channels); + append_le32(bytes, sample_rate); + append_le32(bytes, byte_rate); + append_le16(bytes, block_align); + append_le16(bytes, 32); // bits per sample + append_bytes(bytes, "data"); + append_le32(bytes, subchunk2_size); + + for (float sample : samples) { + append_float(bytes, sample); + } + + return bytes; +} + +std::vector make_wav_bytes_with_format( + uint16_t audio_format, + int bits_per_sample, + const std::vector& sample_data, + uint16_t num_channels = 1, + uint32_t sample_rate = 16000) { + const auto bytes_per_sample = static_cast(bits_per_sample / 8); + const auto subchunk2_size = static_cast(sample_data.size()); + const uint32_t byte_rate = sample_rate * num_channels * bytes_per_sample; + const uint16_t block_align = num_channels * bytes_per_sample; + const auto chunk_size = kWavHeaderSizeBeforeData + subchunk2_size; + + std::vector bytes; + bytes.reserve(kWavHeaderSizeWithData + subchunk2_size); + + append_bytes(bytes, "RIFF"); + append_le32(bytes, chunk_size); + append_bytes(bytes, "WAVE"); + append_bytes(bytes, "fmt "); + append_le32(bytes, 16); + append_le16(bytes, audio_format); + append_le16(bytes, num_channels); + append_le32(bytes, sample_rate); + append_le32(bytes, byte_rate); + append_le16(bytes, block_align); + append_le16(bytes, static_cast(bits_per_sample)); + append_bytes(bytes, "data"); + append_le32(bytes, subchunk2_size); + + bytes.insert(bytes.end(), sample_data.begin(), sample_data.end()); + + return bytes; +} + } // namespace TEST_F(WavLoaderTest, LoadHeaderParsesPcmMetadata) { @@ -153,3 +234,31 @@ TEST_F(WavLoaderTest, LoadHeaderReturnsNullWhenMagicMissing) { std::unique_ptr header = load_wav_header(file.path()); EXPECT_EQ(header, nullptr); } + +TEST_F(WavLoaderTest, LoadAudioDataFloatFormatReadsDirectly) { + const std::vector samples = {0.0f, 0.5f, -0.5f, 1.0f, -1.0f}; + const std::vector wav_bytes = make_float_wav_bytes(samples); + TempFile file(wav_bytes.data(), wav_bytes.size()); + + std::unique_ptr header = load_wav_header(file.path()); + ASSERT_NE(header, nullptr); + EXPECT_EQ(header->AudioFormat, kWavFormatIeeeFloat); + EXPECT_EQ(header->bitsPerSample, 32); + + std::vector audio = load_wav_audio_data(file.path()); + ASSERT_EQ(audio.size(), samples.size()); + + for (size_t i = 0; i < samples.size(); ++i) { + EXPECT_FLOAT_EQ(audio[i], samples[i]); + } +} + +TEST_F(WavLoaderTest, LoadAudioDataRejectsUnsupportedFormat) { + const std::vector sample_data = {0, 0, 0, 0}; + const std::vector wav_bytes = + make_wav_bytes_with_format(0x0006, 16, sample_data); + TempFile file(wav_bytes.data(), wav_bytes.size()); + + EXPECT_DEATH( + { load_wav_audio_data(file.path()); }, "Unsupported audio format"); +} diff --git a/extension/llm/runner/wav_loader.h b/extension/llm/runner/wav_loader.h index f49a4d1723e..3fc43a54392 100644 --- a/extension/llm/runner/wav_loader.h +++ b/extension/llm/runner/wav_loader.h @@ -22,6 +22,9 @@ #include namespace executorch::extension::llm { +// See https://www.mmsp.ece.mcgill.ca/Documents/AudioFormats/WAVE/WAVE.html +constexpr uint16_t kWavFormatPcm = 0x0001; +constexpr uint16_t kWavFormatIeeeFloat = 0x0003; constexpr float kOneOverIntMax = 1 / static_cast(INT32_MAX); constexpr float kOneOverShortMax = 1 / static_cast(INT16_MAX); @@ -168,24 +171,42 @@ inline std::vector load_wav_audio_data(const std::string& fp) { size_t data_offset = header->dataOffset; size_t data_size = header->Subchunk2Size; int bits_per_sample = header->bitsPerSample; + int audio_format = header->AudioFormat; + + if (audio_format != kWavFormatPcm && audio_format != kWavFormatIeeeFloat) { + ET_CHECK_MSG( + false, + "Unsupported audio format: 0x%04X. Only PCM (0x%04X) and IEEE Float (0x%04X) are supported.", + audio_format, + kWavFormatPcm, + kWavFormatIeeeFloat); + } std::vector audio_data; if (bits_per_sample == 32) { size_t num_samples = data_size / 4; - audio_data.resize(num_samples); - const int32_t* input_buffer = - reinterpret_cast(data + data_offset); - for (size_t i = 0; i < num_samples; ++i) { - audio_data[i] = static_cast( - static_cast(input_buffer[i]) * kOneOverIntMax); + if (audio_format == kWavFormatIeeeFloat) { + // IEEE float format - read directly as floats + const float* input_buffer = + reinterpret_cast(data + data_offset); + audio_data.assign(input_buffer, input_buffer + num_samples); + } else { + // PCM integer format - normalize from int32 + const int32_t* input_buffer = + reinterpret_cast(data + data_offset); + audio_data.resize(num_samples); + for (size_t i = 0; i < num_samples; ++i) { + audio_data[i] = static_cast( + static_cast(input_buffer[i]) * kOneOverIntMax); + } } } else if (bits_per_sample == 16) { size_t num_samples = data_size / 2; - audio_data.resize(num_samples); const int16_t* input_buffer = reinterpret_cast(data + data_offset); + audio_data.resize(num_samples); for (size_t i = 0; i < num_samples; ++i) { audio_data[i] = static_cast(