Skip to content

Commit 822c33c

Browse files
committed
[llm] Beef up wav loader to read audio format 3 (float format)
This PR adds test coverage for WAV files using audio format 3 (IEEE float format), which allows direct reading of float values without normalization. The existing `**wav_loader.h**` implementation already supports this format (lines 179-185), but there was no test coverage for this code path. **Test additions in `**test_wav_loader.cpp**`:** 1. Added `append_float()` helper function to serialize float values into byte arrays 2. Added `make_float_wav_bytes()` helper function to generate WAV files with audio format 3 (IEEE float) 3. Added `LoadAudioDataFloatFormatReadsDirectly` test case that verifies: * WAV header correctly identifies audio format as 3 * Float samples are read directly without normalization * Output float values match the input exactly The new test case `LoadAudioDataFloatFormatReadsDirectly` validates that: * A WAV file with audio format 3 (IEEE float, 32-bit) is correctly parsed * The audio format is detected as 3 in the header * Float values [0.0f, 0.5f, -0.5f, 1.0f, -1.0f] are read directly without normalization * All float values match exactly using `EXPECT_FLOAT_EQ` Run the test with: `buck2 test //extension/llm/runner/test:test_wav_loader` The WAV loader already has logic to handle IEEE float format (audio format 3) differently from PCM integer formats, but this code path was not covered by tests. This test ensures the float format path works correctly and prevents regressions.
1 parent efc2be7 commit 822c33c

File tree

2 files changed

+76
-5
lines changed

2 files changed

+76
-5
lines changed

extension/llm/runner/test/test_wav_loader.cpp

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,13 @@ void append_le32(std::vector<uint8_t>& out, uint32_t value) {
5151
out.push_back(static_cast<uint8_t>((value >> 24) & 0xFF));
5252
}
5353

54+
void append_float(std::vector<uint8_t>& out, float value) {
55+
const uint8_t* bytes = reinterpret_cast<const uint8_t*>(&value);
56+
for (size_t i = 0; i < sizeof(float); ++i) {
57+
out.push_back(bytes[i]);
58+
}
59+
}
60+
5461
std::vector<uint8_t> make_pcm_wav_bytes(
5562
int bits_per_sample,
5663
const std::vector<int32_t>& samples,
@@ -91,6 +98,41 @@ std::vector<uint8_t> make_pcm_wav_bytes(
9198
return bytes;
9299
}
93100

101+
std::vector<uint8_t> make_float_wav_bytes(
102+
const std::vector<float>& samples,
103+
uint16_t num_channels = 1,
104+
uint32_t sample_rate = 16000) {
105+
const size_t bytes_per_sample = sizeof(float);
106+
const uint32_t subchunk2_size =
107+
static_cast<uint32_t>(samples.size() * bytes_per_sample);
108+
const uint32_t byte_rate = sample_rate * num_channels * bytes_per_sample;
109+
const uint16_t block_align = num_channels * bytes_per_sample;
110+
const uint32_t chunk_size = 36 + subchunk2_size;
111+
112+
std::vector<uint8_t> bytes;
113+
bytes.reserve(44 + subchunk2_size);
114+
115+
append_bytes(bytes, "RIFF");
116+
append_le32(bytes, chunk_size);
117+
append_bytes(bytes, "WAVE");
118+
append_bytes(bytes, "fmt ");
119+
append_le32(bytes, 16);
120+
append_le16(bytes, 3); // AudioFormat IEEE Float
121+
append_le16(bytes, num_channels);
122+
append_le32(bytes, sample_rate);
123+
append_le32(bytes, byte_rate);
124+
append_le16(bytes, block_align);
125+
append_le16(bytes, 32); // bits per sample
126+
append_bytes(bytes, "data");
127+
append_le32(bytes, subchunk2_size);
128+
129+
for (float sample : samples) {
130+
append_float(bytes, sample);
131+
}
132+
133+
return bytes;
134+
}
135+
94136
} // namespace
95137

96138
TEST_F(WavLoaderTest, LoadHeaderParsesPcmMetadata) {
@@ -153,3 +195,21 @@ TEST_F(WavLoaderTest, LoadHeaderReturnsNullWhenMagicMissing) {
153195
std::unique_ptr<WavHeader> header = load_wav_header(file.path());
154196
EXPECT_EQ(header, nullptr);
155197
}
198+
199+
TEST_F(WavLoaderTest, LoadAudioDataFloatFormatReadsDirectly) {
200+
const std::vector<float> samples = {0.0f, 0.5f, -0.5f, 1.0f, -1.0f};
201+
const std::vector<uint8_t> wav_bytes = make_float_wav_bytes(samples);
202+
TempFile file(wav_bytes.data(), wav_bytes.size());
203+
204+
std::unique_ptr<WavHeader> header = load_wav_header(file.path());
205+
ASSERT_NE(header, nullptr);
206+
EXPECT_EQ(header->AudioFormat, 3);
207+
EXPECT_EQ(header->bitsPerSample, 32);
208+
209+
std::vector<float> audio = load_wav_audio_data(file.path());
210+
ASSERT_EQ(audio.size(), samples.size());
211+
212+
for (size_t i = 0; i < samples.size(); ++i) {
213+
EXPECT_FLOAT_EQ(audio[i], samples[i]);
214+
}
215+
}

extension/llm/runner/wav_loader.h

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -168,18 +168,29 @@ inline std::vector<float> load_wav_audio_data(const std::string& fp) {
168168
size_t data_offset = header->dataOffset;
169169
size_t data_size = header->Subchunk2Size;
170170
int bits_per_sample = header->bitsPerSample;
171+
int audio_format = header->AudioFormat;
171172

172173
std::vector<float> audio_data;
173174

174175
if (bits_per_sample == 32) {
175176
size_t num_samples = data_size / 4;
176177
audio_data.resize(num_samples);
177-
const int32_t* input_buffer =
178-
reinterpret_cast<const int32_t*>(data + data_offset);
179178

180-
for (size_t i = 0; i < num_samples; ++i) {
181-
audio_data[i] = static_cast<float>(
182-
static_cast<double>(input_buffer[i]) * kOneOverIntMax);
179+
if (audio_format == 3) {
180+
// IEEE float format - read directly as floats
181+
const float* input_buffer =
182+
reinterpret_cast<const float*>(data + data_offset);
183+
for (size_t i = 0; i < num_samples; ++i) {
184+
audio_data[i] = input_buffer[i];
185+
}
186+
} else {
187+
// PCM integer format - normalize from int32
188+
const int32_t* input_buffer =
189+
reinterpret_cast<const int32_t*>(data + data_offset);
190+
for (size_t i = 0; i < num_samples; ++i) {
191+
audio_data[i] = static_cast<float>(
192+
static_cast<double>(input_buffer[i]) * kOneOverIntMax);
193+
}
183194
}
184195
} else if (bits_per_sample == 16) {
185196
size_t num_samples = data_size / 2;

0 commit comments

Comments
 (0)