Skip to content

Commit 544cece

Browse files
Include audio preprocessing for raw audio tensor (#13951)
This PR was created by the merge bot to help merge the original PR into the main branch. ghstack PR number: #13873 by @jackzhxng ^ Please use this as the source of truth for the PR details, comments, and reviews ghstack PR base: https://github.com/pytorch/executorch/tree/gh/jackzhxng/37/base ghstack PR head: https://github.com/pytorch/executorch/tree/gh/jackzhxng/37/head Merge bot PR base: https://github.com/pytorch/executorch/tree/gh/jackzhxng/36/orig Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/jackzhxng/37/orig @diff-train-skip-merge --------- Co-authored-by: Jack Zhang <[email protected]>
1 parent 497f59c commit 544cece

File tree

1 file changed

+170
-27
lines changed

1 file changed

+170
-27
lines changed

examples/models/voxtral/multimodal.cpp

Lines changed: 170 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@
1212

1313
#include <gflags/gflags.h>
1414

15+
#include <executorch/extension/module/module.h>
16+
#include <executorch/extension/tensor/tensor_ptr_maker.h>
17+
#include <executorch/runtime/core/evalue.h>
18+
1519
#include <executorch/extension/llm/runner/audio.h>
1620
#include <executorch/extension/llm/runner/image.h>
1721
#include <executorch/extension/llm/runner/llm_runner_helper.h>
@@ -36,6 +40,11 @@ DEFINE_string(prompt, "What is happening in this audio?", "Text prompt.");
3640

3741
DEFINE_string(audio_path, "", "Path to input audio file.");
3842

43+
DEFINE_string(
44+
processor_path,
45+
"",
46+
"Path to processor .pte file for raw audio processing.");
47+
3948
DEFINE_double(
4049
temperature,
4150
0.8f,
@@ -50,10 +59,13 @@ DEFINE_bool(warmup, false, "Whether to run a warmup run.");
5059

5160
namespace {
5261

62+
using ::executorch::extension::from_blob;
63+
using ::executorch::extension::Module;
5364
using ::executorch::extension::llm::Image;
5465
using ::executorch::extension::llm::make_image_input;
5566
using ::executorch::extension::llm::make_text_input;
5667
using ::executorch::extension::llm::MultimodalInput;
68+
using ::executorch::runtime::EValue;
5769

5870
bool ends_with(const std::string& str, const std::string& suffix) {
5971
return str.size() >= suffix.size() &&
@@ -74,55 +86,185 @@ bool ends_with(const std::string& str, const std::string& suffix) {
7486
*/
7587
MultimodalInput loadPreprocessedAudio(const std::string& audio_path) {
7688
std::ifstream f(audio_path, std::ios::binary | std::ios::ate);
89+
if (!f.is_open()) {
90+
ET_LOG(Error, "Failed to open audio file: %s", audio_path.c_str());
91+
throw std::runtime_error("Failed to open audio file");
92+
}
93+
94+
std::size_t n_floats = f.tellg() / sizeof(float);
95+
f.seekg(0, std::ios::beg);
96+
7797
int32_t n_bins = 128;
7898
int32_t n_frames = 3000;
79-
std::size_t n_floats =
80-
f.tellg() / sizeof(float); // Number of floats in the audio file.
81-
f.seekg(0, std::ios::beg);
99+
82100
int32_t batch_size = ceil(
83101
n_floats /
84102
(n_bins * n_frames)); // Batch in increments of n_frames, rounding up.
85-
std::vector<float> audio_data(batch_size * n_bins * n_frames);
86-
f.read(
87-
reinterpret_cast<char*>(audio_data.data()),
88-
audio_data.size() * sizeof(float));
89103

90-
ET_LOG(Info, "audio_data len = %d", audio_data.size());
104+
ET_LOG(Info, "audio_data len = %zu", n_floats);
91105

106+
// Create Audio multimodal input
92107
auto audio = std::make_unique<::executorch::extension::llm::Audio>();
93108
audio->batch_size = batch_size;
94109
audio->n_bins = n_bins;
95110
audio->n_frames = n_frames;
96-
audio->data.resize(audio_data.size() * sizeof(float));
97-
std::memcpy(
98-
audio->data.data(), audio_data.data(), audio_data.size() * sizeof(float));
111+
audio->data.resize(n_floats * sizeof(float));
112+
f.read(reinterpret_cast<char*>(audio->data.data()), n_floats * sizeof(float));
113+
f.close();
99114
return ::executorch::extension::llm::make_audio_input(std::move(*audio));
100115
}
101116

102117
/**
103-
* @brief Processes audio files for multimodal input
118+
* @brief Loads a .bin file into a tensor and processes it using a .pte
119+
* processor
104120
*
105-
* Dispatches audio file processing based on file extension:
106-
* - .bin files: Loads preprocessed mel spectrogram features directly
107-
* - .wav/.mp3 files: Currently unsupported, throws runtime_error
121+
* This function loads raw audio data from a .bin file (similar to
122+
* loadPreprocessedAudio), creates a tensor from it, and then passes it through
123+
* a processor module loaded from a .pte file to generate processed audio
124+
* features.
108125
*
109-
* This function provides a interface for different audio input formats
110-
* and can be extended to support raw audio processing in the future.
126+
* @param audio_path Path to the .bin audio file
127+
* @param processor_path Path to the .pte processor file
128+
* @return MultimodalInput containing the processed audio data
129+
* @throws std::runtime_error if file loading or processing fails
130+
*/
131+
MultimodalInput processRawAudioFile(
132+
const std::string& audio_path,
133+
const std::string& processor_path) {
134+
if (processor_path.empty()) {
135+
ET_LOG(Error, "Processor path is required for raw audio processing");
136+
throw std::runtime_error(
137+
"Processor path is required for raw audio processing");
138+
}
139+
140+
// Load the audio processor .pte.
141+
std::unique_ptr<Module> processor_module;
142+
try {
143+
processor_module =
144+
std::make_unique<Module>(processor_path, Module::LoadMode::File);
145+
auto load_error = processor_module->load();
146+
if (load_error != ::executorch::runtime::Error::Ok) {
147+
ET_LOG(
148+
Error,
149+
"Failed to load processor module from: %s",
150+
processor_path.c_str());
151+
throw std::runtime_error("Failed to load processor module");
152+
}
153+
} catch (const std::exception& e) {
154+
ET_LOG(Error, "Exception while loading processor module: %s", e.what());
155+
throw std::runtime_error("Exception while loading processor module");
156+
}
157+
158+
// Load the audio data from file.
159+
std::ifstream f(audio_path, std::ios::binary | std::ios::ate);
160+
if (!f.is_open()) {
161+
ET_LOG(Error, "Failed to open audio file: %s", audio_path.c_str());
162+
throw std::runtime_error("Failed to open audio file");
163+
}
164+
165+
std::size_t n_floats = f.tellg() / sizeof(float);
166+
f.seekg(0, std::ios::beg);
167+
168+
std::vector<float> audio_data(n_floats);
169+
f.read(
170+
reinterpret_cast<char*>(audio_data.data()),
171+
audio_data.size() * sizeof(float));
172+
f.close();
173+
174+
ET_LOG(
175+
Info, "Loaded .bin file: %s, %zu floats", audio_path.c_str(), n_floats);
176+
177+
// Execute the processor
178+
std::vector<executorch::aten::SizesType> tensor_shape = {
179+
static_cast<executorch::aten::SizesType>(audio_data.size())};
180+
auto input_tensor = from_blob(
181+
audio_data.data(), tensor_shape, ::executorch::aten::ScalarType::Float);
182+
183+
ET_LOG(Info, "Processing audio through processor module...");
184+
auto result = processor_module->execute("forward", input_tensor);
185+
if (!result.ok()) {
186+
ET_LOG(Error, "Failed to execute processor's forward method");
187+
throw std::runtime_error("Failed to execute processor forward method");
188+
}
189+
190+
auto outputs = result.get();
191+
if (outputs.empty()) {
192+
ET_LOG(Error, "Processor returned no outputs");
193+
throw std::runtime_error("Processor returned no outputs");
194+
}
195+
196+
// Extract processed audio features
197+
const auto& processed_tensor = outputs[0].toTensor();
198+
const float* processed_data = processed_tensor.const_data_ptr<float>();
199+
const auto& sizes = processed_tensor.sizes();
200+
201+
ET_LOG(
202+
Info,
203+
"Processed audio tensor shape: [%d, %d, %d]",
204+
static_cast<int>(sizes[0]),
205+
static_cast<int>(sizes[1]),
206+
static_cast<int>(sizes[2]));
207+
208+
// Create Audio multimodal input from processed features
209+
auto processed_audio =
210+
std::make_unique<::executorch::extension::llm::Audio>();
211+
processed_audio->batch_size =
212+
static_cast<int32_t>(sizes[0]); // Note: batching for s > 30 doesn't work
213+
// yet, so this will just be = 1.
214+
processed_audio->n_bins = static_cast<int32_t>(sizes[1]);
215+
processed_audio->n_frames =
216+
static_cast<int32_t>(sizes[2]); // And this will just be = 3000.
217+
218+
size_t total_elements = processed_audio->batch_size *
219+
processed_audio->n_bins * processed_audio->n_frames;
220+
processed_audio->data.resize(total_elements * sizeof(float));
221+
std::memcpy(
222+
processed_audio->data.data(),
223+
processed_data,
224+
total_elements * sizeof(float));
225+
226+
ET_LOG(
227+
Info,
228+
"Created processed Audio: batch_size=%d, n_bins=%d, n_frames=%d",
229+
processed_audio->batch_size,
230+
processed_audio->n_bins,
231+
processed_audio->n_frames);
232+
233+
return ::executorch::extension::llm::make_audio_input(
234+
std::move(*processed_audio));
235+
}
236+
237+
/**
238+
* @brief Processes audio files for multimodal input
239+
*
240+
* Dispatches audio file processing based on file extension and processor
241+
* availability:
242+
* - .bin files with processor: Loads raw audio from .bin and processes through
243+
* processor
244+
* - .bin files without processor: Loads preprocessed mel spectrogram features
245+
* directly
111246
*
112-
* @param audio_path Path to the audio file
247+
* @param audio_path Path to the audio file (.bin)
248+
* @param processor_path Path to the processor .pte file (optional)
113249
* @return MultimodalInput containing the processed audio data
114250
* @throws std::runtime_error if file format is unsupported or processing fails
115251
*/
116-
MultimodalInput processAudioFile(const std::string& audio_path) {
252+
MultimodalInput processAudioFile(
253+
const std::string& audio_path,
254+
const std::string& processor_path = "") {
117255
if (ends_with(audio_path, ".bin")) {
118-
// Current behavior - load preprocessed audio stored as a binary file.
119-
return loadPreprocessedAudio(audio_path);
120-
} else if (ends_with(audio_path, ".wav") || ends_with(audio_path, ".mp3")) {
121-
// New: Process raw audio files - unsupported for now
122-
ET_LOG(Error, "Raw audio file processing (.wav/.mp3) is not yet supported");
123-
throw std::runtime_error("Raw audio file processing not supported");
256+
if (!processor_path.empty()) {
257+
// Process raw audio from .bin file through the processor
258+
return processRawAudioFile(audio_path, processor_path);
259+
} else {
260+
// Load preprocessed audio stored as a binary file (existing behavior)
261+
return loadPreprocessedAudio(audio_path);
262+
}
124263
} else {
125-
ET_LOG(Error, "Unsupported audio file format: %s", audio_path.c_str());
264+
ET_LOG(
265+
Error,
266+
"Unsupported audio file format: %s (only .bin files are supported)",
267+
audio_path.c_str());
126268
throw std::runtime_error("Unsupported audio file format");
127269
}
128270
}
@@ -137,6 +279,7 @@ int32_t main(int32_t argc, char** argv) {
137279
const char* tokenizer_path = FLAGS_tokenizer_path.c_str();
138280
const char* prompt = FLAGS_prompt.c_str();
139281
const char* audio_path = FLAGS_audio_path.c_str();
282+
const char* processor_path = FLAGS_processor_path.c_str();
140283
float temperature = FLAGS_temperature;
141284
int32_t cpu_threads = FLAGS_cpu_threads;
142285
bool warmup = FLAGS_warmup;
@@ -180,7 +323,7 @@ int32_t main(int32_t argc, char** argv) {
180323
// Prepare inputs
181324
std::vector<MultimodalInput> inputs = {
182325
make_text_input("<s>[INST][BEGIN_AUDIO]"),
183-
processAudioFile(audio_path),
326+
processAudioFile(audio_path, processor_path),
184327
make_text_input(std::string(prompt) + "[/INST]"),
185328
};
186329

0 commit comments

Comments
 (0)