Skip to content

Commit a5c5f0b

Browse files
committed
Reject other formats
1 parent 822c33c commit a5c5f0b

File tree

2 files changed

+57
-1
lines changed

2 files changed

+57
-1
lines changed

extension/llm/runner/test/test_wav_loader.cpp

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,40 @@ std::vector<uint8_t> make_float_wav_bytes(
133133
return bytes;
134134
}
135135

136+
std::vector<uint8_t> make_wav_bytes_with_format(
137+
uint16_t audio_format,
138+
int bits_per_sample,
139+
const std::vector<uint8_t>& sample_data,
140+
uint16_t num_channels = 1,
141+
uint32_t sample_rate = 16000) {
142+
const size_t bytes_per_sample = static_cast<size_t>(bits_per_sample / 8);
143+
const uint32_t subchunk2_size = static_cast<uint32_t>(sample_data.size());
144+
const uint32_t byte_rate = sample_rate * num_channels * bytes_per_sample;
145+
const uint16_t block_align = num_channels * bytes_per_sample;
146+
const uint32_t chunk_size = 36 + subchunk2_size;
147+
148+
std::vector<uint8_t> bytes;
149+
bytes.reserve(44 + subchunk2_size);
150+
151+
append_bytes(bytes, "RIFF");
152+
append_le32(bytes, chunk_size);
153+
append_bytes(bytes, "WAVE");
154+
append_bytes(bytes, "fmt ");
155+
append_le32(bytes, 16);
156+
append_le16(bytes, audio_format);
157+
append_le16(bytes, num_channels);
158+
append_le32(bytes, sample_rate);
159+
append_le32(bytes, byte_rate);
160+
append_le16(bytes, block_align);
161+
append_le16(bytes, static_cast<uint16_t>(bits_per_sample));
162+
append_bytes(bytes, "data");
163+
append_le32(bytes, subchunk2_size);
164+
165+
bytes.insert(bytes.end(), sample_data.begin(), sample_data.end());
166+
167+
return bytes;
168+
}
169+
136170
} // namespace
137171

138172
TEST_F(WavLoaderTest, LoadHeaderParsesPcmMetadata) {
@@ -203,7 +237,7 @@ TEST_F(WavLoaderTest, LoadAudioDataFloatFormatReadsDirectly) {
203237

204238
std::unique_ptr<WavHeader> header = load_wav_header(file.path());
205239
ASSERT_NE(header, nullptr);
206-
EXPECT_EQ(header->AudioFormat, 3);
240+
EXPECT_EQ(header->AudioFormat, kWavFormatIeeeFloat);
207241
EXPECT_EQ(header->bitsPerSample, 32);
208242

209243
std::vector<float> audio = load_wav_audio_data(file.path());
@@ -213,3 +247,13 @@ TEST_F(WavLoaderTest, LoadAudioDataFloatFormatReadsDirectly) {
213247
EXPECT_FLOAT_EQ(audio[i], samples[i]);
214248
}
215249
}
250+
251+
// load_wav_audio_data must abort with an "Unsupported audio format" message
// when the header's format tag is neither PCM nor IEEE float.
TEST_F(WavLoaderTest, LoadAudioDataRejectsUnsupportedFormat) {
  // 0x0006 (A-law in the WAVE format-tag registry) is a valid tag that this
  // loader does not handle.
  constexpr uint16_t kUnsupportedFormatTag = 0x0006;
  constexpr int kBitsPerSample = 16;

  // Four zero bytes of payload; the content is irrelevant to the check.
  const std::vector<uint8_t> payload(4, uint8_t{0});
  const std::vector<uint8_t> encoded = make_wav_bytes_with_format(
      kUnsupportedFormatTag, kBitsPerSample, payload);
  TempFile file(encoded.data(), encoded.size());

  EXPECT_DEATH(load_wav_audio_data(file.path()), "Unsupported audio format");
}

extension/llm/runner/wav_loader.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@
2222
#include <executorch/runtime/platform/log.h>
2323

2424
namespace executorch::extension::llm {
25+
// WAVE `fmt ` chunk format tags accepted by this loader.
// See https://www.mmsp.ece.mcgill.ca/Documents/AudioFormats/WAVE/WAVE.html
// Declared `inline` so every translation unit including this header shares a
// single definition instead of getting its own internal-linkage copy.
inline constexpr uint16_t kWavFormatPcm = 0x0001;
inline constexpr uint16_t kWavFormatIeeeFloat = 0x0003;

// Reciprocals of the largest signed 32-bit and 16-bit integer values.
// NOTE(review): presumably used to scale integer samples into floats —
// confirm against the conversion code below.
inline constexpr float kOneOverIntMax = 1 / static_cast<float>(INT32_MAX);
inline constexpr float kOneOverShortMax = 1 / static_cast<float>(INT16_MAX);
@@ -170,6 +173,15 @@ inline std::vector<float> load_wav_audio_data(const std::string& fp) {
170173
int bits_per_sample = header->bitsPerSample;
171174
int audio_format = header->AudioFormat;
172175

176+
if (audio_format != kWavFormatPcm && audio_format != kWavFormatIeeeFloat) {
177+
ET_CHECK_MSG(
178+
false,
179+
"Unsupported audio format: 0x%04X. Only PCM (0x%04X) and IEEE Float (0x%04X) are supported.",
180+
audio_format,
181+
kWavFormatPcm,
182+
kWavFormatIeeeFloat);
183+
}
184+
173185
std::vector<float> audio_data;
174186

175187
if (bits_per_sample == 32) {

0 commit comments

Comments
 (0)