Skip to content

Commit 3bbb26b

Browse files
committed
clean up audio_helpers
1 parent cf4f5d2 commit 3bbb26b

File tree

4 files changed

+42
-133
lines changed

4 files changed

+42
-133
lines changed

tools/mtmd/mtmd-audio.cpp

Lines changed: 32 additions & 124 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#define _USE_MATH_DEFINES // for M_PI
2323
#include <cmath>
2424
#include <cstdint>
25+
#include <cstring>
2526
#include <thread>
2627
#include <vector>
2728
#include <fstream>
@@ -301,7 +302,7 @@ bool preprocess_audio(
301302
size_t n_samples,
302303
whisper_filters & filters,
303304
whisper_mel & output) {
304-
305+
305306
// a bit hacky, but we want to align the output to a multiple of WHISPER_N_FFT * proj_stack_factor
306307
// proj_stack_factor is 8, specifically for Ultravox (so this is a temporary solution)
307308

@@ -325,144 +326,51 @@ bool preprocess_audio(
325326
} // namespace whisper_preprocessor
326327

327328

328-
namespace wav_utils {
329+
namespace audio_helpers {
329330

330-
bool is_wav_buffer(const std::string buf) {
331-
// RIFF ref: https://en.wikipedia.org/wiki/Resource_Interchange_File_Format
332-
// WAV ref: https://www.mmsp.ece.mcgill.ca/Documents/AudioFormats/WAVE/WAVE.html
333-
if (buf.size() < 12 || buf.substr(0, 4) != "RIFF" || buf.substr(8, 4) != "WAVE") {
331+
bool is_audio_file(const char * buf, size_t len) {
332+
if (len < 12) {
334333
return false;
335334
}
336335

337-
// uint32_t chunk_size = *reinterpret_cast<const uint32_t*>(buf.data() + 4);
338-
// if (chunk_size + 8 != buf.size()) {
339-
// return false;
340-
// }
341-
342-
return true;
336+
// RIFF ref: https://en.wikipedia.org/wiki/Resource_Interchange_File_Format
337+
// WAV ref: https://www.mmsp.ece.mcgill.ca/Documents/AudioFormats/WAVE/WAVE.html
338+
bool is_wav = memcmp(buf, "RIFF", 4) == 0 && memcmp(buf + 8, "WAVE", 4) == 0;
339+
bool is_mp3 = len >= 3 && (
340+
memcmp(buf, "ID3", 3) == 0 ||
341+
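// Check for the MPEG audio frame sync word: first byte 0xFF and the top 3 bits of the second byte set (simplified check; the rest of the frame header is not validated)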
342+
((unsigned char)buf[0] == 0xFF && ((unsigned char)buf[1] & 0xE0) == 0xE0)
343+
);
344+
bool is_flac = memcmp(buf, "fLaC", 4) == 0;
345+
346+
return is_wav || is_mp3 || is_flac;
343347
}
344348

345-
// returns true if the buffer is a valid WAV file
346-
bool read_wav_from_buf(const unsigned char * buf_in, size_t len, int target_sampler_rate, std::vector<float> & pcmf32_mono) {
349+
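// decodes the audio file in the buffer into pcmf32_mono as mono f32 samples at target_sampler_rate; returns true if the buffer is a valid audio file, false if decoding fails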
350+
bool decode_audio_from_buf(const unsigned char * buf_in, size_t len, int target_sampler_rate, std::vector<float> & pcmf32_mono) {
347351
ma_result result;
348-
// Request f32 output from the decoder. Channel count and sample rate are determined from the file.
349-
ma_decoder_config decoder_config = ma_decoder_config_init(ma_format_f32, 0, 0);
352+
const int channels = 1;
353+
ma_decoder_config decoder_config = ma_decoder_config_init(ma_format_f32, channels, target_sampler_rate);
350354
ma_decoder decoder;
351355

352356
result = ma_decoder_init_memory(buf_in, len, &decoder_config, &decoder);
353357
if (result != MA_SUCCESS) {
354358
return false;
355359
}
356360

357-
// Decoder will output ma_format_f32.
358-
// We need to use the data converter if:
359-
// 1. The sample rate needs to be changed.
360-
// 2. The audio is not already mono (decoder.outputChannels != 1).
361-
bool needs_resampling = (decoder.outputSampleRate != (ma_uint32)target_sampler_rate);
362-
bool needs_channel_mixing = (decoder.outputChannels != 1);
363-
364-
if (!needs_resampling && !needs_channel_mixing) {
365-
// Already target sample rate, already mono, and decoder is outputting f32. Direct read.
366-
ma_uint64 frame_count_total;
367-
result = ma_decoder_get_length_in_pcm_frames(&decoder, &frame_count_total);
368-
if (result != MA_SUCCESS) {
369-
ma_decoder_uninit(&decoder);
370-
return false;
371-
}
372-
373-
pcmf32_mono.resize(frame_count_total); // Mono, so frames == samples
374-
ma_uint64 frames_read = 0;
375-
result = ma_decoder_read_pcm_frames(&decoder, pcmf32_mono.data(), frame_count_total, &frames_read);
376-
if (result != MA_SUCCESS || frames_read != frame_count_total) {
377-
ma_decoder_uninit(&decoder);
378-
return false;
379-
}
380-
} else {
381-
// Resampling and/or channel mixing is needed.
382-
ma_data_converter_config data_converter_config = ma_data_converter_config_init_default();
383-
data_converter_config.formatIn = decoder.outputFormat; // This will be ma_format_f32
384-
data_converter_config.formatOut = ma_format_f32; // Output is also f32
385-
data_converter_config.channelsIn = decoder.outputChannels;
386-
data_converter_config.channelsOut = 1; // MONO output
387-
data_converter_config.sampleRateIn = decoder.outputSampleRate;
388-
data_converter_config.sampleRateOut = (ma_uint32)target_sampler_rate;
389-
data_converter_config.resampling.algorithm = ma_resample_algorithm_linear; // Or other algorithm
390-
391-
ma_data_converter data_converter;
392-
result = ma_data_converter_init(&data_converter_config, NULL, &data_converter);
393-
if (result != MA_SUCCESS) {
394-
ma_decoder_uninit(&decoder);
395-
return false;
396-
}
397-
398-
ma_uint64 total_frames_expected_from_decoder;
399-
result = ma_decoder_get_length_in_pcm_frames(&decoder, &total_frames_expected_from_decoder);
400-
if (result != MA_SUCCESS) {
401-
ma_data_converter_uninit(&data_converter, NULL);
402-
ma_decoder_uninit(&decoder);
403-
return false;
404-
}
405-
406-
double resample_ratio = (double)target_sampler_rate / decoder.outputSampleRate;
407-
// Reserve for mono output
408-
pcmf32_mono.reserve(static_cast<size_t>(total_frames_expected_from_decoder * resample_ratio * 1.1) + 1);
409-
410-
// Buffer to hold data read from the decoder (multi-channel, original sample rate, f32 format)
411-
const ma_uint64 DECODE_BUFFER_SIZE_FRAMES = 1024;
412-
std::vector<float> temp_decode_buffer(DECODE_BUFFER_SIZE_FRAMES * decoder.outputChannels);
413-
414-
while (true) {
415-
ma_uint64 frames_decoded_this_iteration = 0;
416-
result = ma_decoder_read_pcm_frames(&decoder, temp_decode_buffer.data(), DECODE_BUFFER_SIZE_FRAMES, &frames_decoded_this_iteration);
417-
418-
if (result != MA_SUCCESS && result != MA_AT_END) {
419-
ma_data_converter_uninit(&data_converter, NULL);
420-
ma_decoder_uninit(&decoder);
421-
return false;
422-
}
423-
424-
if (frames_decoded_this_iteration == 0 && result == MA_AT_END) { // Ensure we process the last bit if MA_AT_END was from previous read
425-
break;
426-
}
427-
428-
ma_uint64 frame_count_in = frames_decoded_this_iteration;
429-
ma_uint64 frame_count_out_capacity;
430-
431-
result = ma_data_converter_get_expected_output_frame_count(&data_converter, frame_count_in, &frame_count_out_capacity);
432-
if (result != MA_SUCCESS) {
433-
ma_data_converter_uninit(&data_converter, NULL);
434-
ma_decoder_uninit(&decoder);
435-
return false;
436-
}
437-
438-
size_t current_pcmf32_sample_offset = pcmf32_mono.size();
439-
// Resize for mono output (channelsOut is 1)
440-
pcmf32_mono.resize(current_pcmf32_sample_offset + frame_count_out_capacity * data_converter.channelsOut);
441-
442-
ma_uint64 frames_actually_output = frame_count_out_capacity;
443-
444-
result = ma_data_converter_process_pcm_frames(
445-
&data_converter,
446-
temp_decode_buffer.data(),
447-
&frame_count_in,
448-
pcmf32_mono.data() + current_pcmf32_sample_offset,
449-
&frames_actually_output
450-
);
451-
452-
if (result != MA_SUCCESS) {
453-
ma_data_converter_uninit(&data_converter, NULL);
454-
ma_decoder_uninit(&decoder);
455-
return false;
456-
}
457-
458-
// Adjust size to actual frames output (mono)
459-
pcmf32_mono.resize(current_pcmf32_sample_offset + frames_actually_output * data_converter.channelsOut);
361+
ma_uint64 frame_count;
362+
ma_uint64 frames_read;
363+
result = ma_decoder_get_length_in_pcm_frames(&decoder, &frame_count);
364+
if (result != MA_SUCCESS) {
365+
ma_decoder_uninit(&decoder);
366+
return false;
367+
}
460368

461-
if (result == MA_AT_END) {
462-
if (frames_decoded_this_iteration == 0 || frame_count_in == 0) break; // No more input frames processed or decoded
463-
}
464-
}
465-
ma_data_converter_uninit(&data_converter, NULL);
369+
pcmf32_mono.resize(frame_count);
370+
result = ma_decoder_read_pcm_frames(&decoder, pcmf32_mono.data(), frame_count, &frames_read);
371+
if (result != MA_SUCCESS || frames_read != frame_count) {
372+
ma_decoder_uninit(&decoder);
373+
return false;
466374
}
467375

468376
ma_decoder_uninit(&decoder);

tools/mtmd/mtmd-audio.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,13 @@ extern bool preprocess_audio(
4242

4343

4444

45-
namespace wav_utils {
45+
namespace audio_helpers {
4646

47-
extern bool is_wav_buffer(const std::string buf);
47+
extern bool is_audio_file(const char * buf, size_t len);
4848

49-
extern bool read_wav_from_buf(const unsigned char * buf_in, size_t len, int target_sampler_rate, std::vector<float> & pcmf32_mono);
49+
extern bool decode_audio_from_buf(const unsigned char * buf_in, size_t len, int target_sampler_rate, std::vector<float> & pcmf32_mono);
5050

51-
} // namespace wav_utils
51+
} // namespace audio_helpers
5252

5353

5454

tools/mtmd/mtmd.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -642,10 +642,9 @@ bool mtmd_support_audio(mtmd_context * ctx) {
642642
// whichever library they want, and then use mtmd_bitmap_init() to create bitmap
643643

644644
mtmd_bitmap * mtmd_helper_bitmap_init_from_buf(const unsigned char * buf, size_t len) {
645-
if (len > 32 && wav_utils::is_wav_buffer(std::string((const char *)buf, 32))) {
646-
// WAV audio file
645+
if (audio_helpers::is_audio_file((const char *)buf, len)) {
647646
std::vector<float> pcmf32;
648-
if (!wav_utils::read_wav_from_buf(buf, len, COMMON_SAMPLE_RATE, pcmf32)) {
647+
if (!audio_helpers::decode_audio_from_buf(buf, len, COMMON_SAMPLE_RATE, pcmf32)) {
649648
LOG_ERR("Unable to read WAV audio file from buffer\n");
650649
return nullptr;
651650
}

tools/mtmd/mtmd.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -205,14 +205,16 @@ MTMD_API float * mtmd_get_output_embd(mtmd_context * ctx);
205205
//
206206

207207
// helper function to construct a mtmd_bitmap from a file
208+
// it calls mtmd_helper_bitmap_init_from_buf() internally
208209
// returns nullptr on failure
209210
// this function is thread-safe
210211
MTMD_API mtmd_bitmap * mtmd_helper_bitmap_init_from_file(const char * fname);
211212

212213
// helper function to construct a mtmd_bitmap from a buffer containing a file
213214
// supported formats:
214-
// image: format supported by stb_image (jpg, png, bmp, gif, etc.)
215-
// audio: wav
215+
// image: formats supported by stb_image: jpg, png, bmp, gif, etc.
216+
// audio: formats supported by miniaudio: wav, mp3, flac
217+
// note: audio files will be auto-detected based on magic bytes
216218
// returns nullptr on failure
217219
// this function is thread-safe
218220
MTMD_API mtmd_bitmap * mtmd_helper_bitmap_init_from_buf(const unsigned char * buf, size_t len);

0 commit comments

Comments
 (0)